1091 lines
32 KiB
Python
1091 lines
32 KiB
Python
"""Lightweight generation job/event tracking."""
|
||
|
||
from __future__ import annotations
|
||
|
||
from datetime import datetime, timedelta, timezone
|
||
from typing import Any
|
||
|
||
from fastapi import HTTPException
|
||
from sqlalchemy import desc, select, update
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.config import settings
|
||
from app.core.logging import get_logger
|
||
from app.db.models import (
|
||
GenerationJob,
|
||
GenerationJobEvent,
|
||
Story,
|
||
)
|
||
|
||
# Module-level logger obtained from the application's logging helper, named after this module.
logger = get_logger(__name__)
|
||
|
||
|
||
def _is_terminal_status(status: str) -> bool:
    """Return True when *status* is one of the final job states.

    The final states are ``completed``, ``degraded_completed``, ``failed``
    and ``canceled``; every other value is treated as non-terminal.
    """
    terminal_states = frozenset(
        ("completed", "degraded_completed", "failed", "canceled")
    )
    return status in terminal_states
|
||
|
||
|
||
def _job_supports_queue_control(job: GenerationJob) -> bool:
    """Return True when the job's ``output_mode`` is eligible for cancel/retry checks.

    Only the ``story``, ``storybook`` and ``asset_generation`` modes qualify.
    """
    controllable_modes = frozenset(("story", "storybook", "asset_generation"))
    return job.output_mode in controllable_modes
|
||
|
||
|
||
def generation_job_can_cancel(job: GenerationJob) -> bool:
    """Return True when a cancel request is currently allowed for *job*.

    The job must be in a queue-controllable output mode, still ``running``,
    and must not already have a cancellation pending.
    """
    if not _job_supports_queue_control(job):
        return False
    if job.status != "running":
        return False
    # A job whose current step is already "cancel_requested" cannot be canceled again.
    return job.current_step != "cancel_requested"
|
||
|
||
|
||
def generation_job_can_retry(job: GenerationJob) -> bool:
    """Return True when *job* may be retried.

    Retry is offered only for queue-controllable output modes whose status is
    ``failed`` or ``canceled``.
    """
    if not _job_supports_queue_control(job):
        return False
    retryable_states = frozenset(("failed", "canceled"))
    return job.status in retryable_states
|
||
|
||
|
||
def _story_snapshot(story: Story | None) -> dict[str, Any]:
    """Copy the status-related fields of *story* into a plain dict.

    Returns an empty dict when no story is given. Otherwise the dict holds
    ``story_id`` (taken from ``story.id``) followed by the story's mode,
    per-asset statuses, retryable assets and last error.
    """
    if story is None:
        return {}

    copied_fields = (
        "mode",
        "generation_status",
        "text_status",
        "image_status",
        "audio_status",
        "retryable_assets",
        "last_error",
    )
    snapshot: dict[str, Any] = {"story_id": story.id}
    for field_name in copied_fields:
        snapshot[field_name] = getattr(story, field_name)
    return snapshot
|
||
|
||
|
||
def _job_status_from_story(story: Story) -> str:
    """Derive the final job status from the story's ``generation_status``.

    ``failed`` and ``degraded_completed`` are passed through unchanged; any
    other story status maps to ``completed``.
    """
    outcome = story.generation_status
    if outcome == "failed" or outcome == "degraded_completed":
        return outcome
    return "completed"
|
||
|
||
|
||
# Labels for jobs whose status is already final; such jobs always report 100%.
_TERMINAL_STATUS_LABELS: dict[str, str] = {
    "failed": "生成失败",
    "canceled": "已取消",
    "completed": "已完成",
    "degraded_completed": "降级完成",
}

# (percent, label) reported for a non-final job, keyed by its current step.
_STEP_PROGRESS: dict[str, tuple[int, str]] = {
    "request_accepted": (5, "已接收请求"),
    "workflow_planned": (8, "工作流已规划"),
    "retry_queued": (8, "重新排队中"),
    "worker_started": (12, "后台任务已开始"),
    "cancel_requested": (15, "已请求取消"),
    "context_prepared": (20, "上下文已准备"),
    "narrative_generated": (45, "正文已生成"),
    "evaluation_completed": (52, "内容评测已完成"),
    "story_saved": (60, "主记录已保存"),
    "provider_call_started": (65, "Provider 调用中"),
    "provider_call_succeeded": (72, "Provider 调用成功"),
    "provider_call_failed": (72, "Provider 调用失败,尝试恢复"),
    "cover_image_started": (75, "封面生成中"),
    "cover_image_succeeded": (88, "封面已生成"),
    "cover_image_failed": (88, "封面生成失败"),
    "storybook_images_started": (75, "绘本插图生成中"),
    "storybook_cover_image_succeeded": (82, "绘本封面已生成"),
    "storybook_cover_image_failed": (82, "绘本封面生成失败"),
    "storybook_page_image_succeeded": (86, "分页插图已生成"),
    "storybook_page_image_failed": (86, "分页插图生成失败"),
    "storybook_images_completed": (92, "绘本插图已完成"),
    "audio_started": (75, "音频生成中"),
    "audio_cache_hit": (88, "音频缓存已复用"),
    "audio_succeeded": (88, "音频已生成"),
    "audio_failed": (88, "音频生成失败"),
    "asset_retry_started": (25, "资源重试中"),
    "postprocessing_queued": (90, "后处理已排队"),
    "asset_generation_completed": (100, "资源已完成"),
    "asset_retry_completed": (100, "资源重试完成"),
    "generation_canceled": (100, "任务已取消"),
    "generation_completed": (100, "生成完成"),
    "generation_stale_failed": (100, "任务超时已收敛"),
}

# Used when the current step is not listed in the table above.
_UNKNOWN_STEP_PROGRESS: tuple[int, str] = (10, "生成处理中")


def _job_progress(job: GenerationJob) -> dict[str, Any]:
    """Resolve a compact progress summary for polling-oriented clients.

    The lookup tables live at module level so they are built once at import
    time instead of on every call.

    Returns a dict with ``progress_percent``, ``progress_label`` and
    ``is_terminal``. A job whose status is already final always reports 100%
    and ``is_terminal=True``. Otherwise the values come from the job's current
    step, and ``is_terminal`` is True only when that step maps to 100%.
    """
    terminal_label = _TERMINAL_STATUS_LABELS.get(job.status)
    if terminal_label is not None:
        return {
            "progress_percent": 100,
            "progress_label": terminal_label,
            "is_terminal": True,
        }

    percent, label = _STEP_PROGRESS.get(job.current_step, _UNKNOWN_STEP_PROGRESS)
    return {
        "progress_percent": percent,
        "progress_label": label,
        # NOTE(review): a job still marked running reports is_terminal=True once its
        # step maps to 100% — kept as in the original; confirm this is intended.
        "is_terminal": percent >= 100,
    }
|
||
|
||
|
||
def _normalize_datetime(value: datetime) -> datetime:
    """Return *value* as a timezone-aware UTC datetime.

    A naive datetime is labelled as UTC without shifting its clock fields;
    an aware datetime is converted to UTC.
    """
    is_naive = value.tzinfo is None
    return (
        value.replace(tzinfo=timezone.utc)
        if is_naive
        else value.astimezone(timezone.utc)
    )
|
||
|
||
|
||
def _is_stale_job(job: GenerationJob, *, stale_after_minutes: int) -> bool:
    """Return True when a running job was last updated at or before the stale cutoff.

    The cutoff is "now (UTC) minus *stale_after_minutes*". Jobs that are not
    ``running`` are never stale.
    """
    deadline = datetime.now(timezone.utc) - timedelta(minutes=stale_after_minutes)
    if job.status != "running":
        return False
    return _normalize_datetime(job.updated_at) <= deadline
|
||
|
||
|
||
def _failure_label(job: GenerationJob) -> str:
    """Pick the label shown for a job in failure listings.

    Precedence: canceled status, then the stale-timeout step, then a label
    specific to the asset output modes, and finally a generic failure label.
    """
    if job.status == "canceled":
        return "任务已取消"
    if job.current_step == "generation_stale_failed":
        return "任务超时"

    labels_by_mode = {
        "asset_retry": "资源重试失败",
        "asset_generation": "资源生成失败",
    }
    return labels_by_mode.get(job.output_mode, "生成失败")
|
||
|
||
|
||
async def create_generation_job(
    db: AsyncSession,
    *,
    user_id: str,
    output_mode: str,
    input_type: str,
    request_payload: dict[str, Any],
    story_id: int | None = None,
) -> GenerationJob:
    """Persist a new running job together with its ``request_accepted`` event.

    The job starts with status ``running``, step ``request_accepted`` and an
    empty result snapshot. Job and first event are committed together, and
    the job is refreshed before being returned.
    """
    new_job = GenerationJob(
        user_id=user_id,
        story_id=story_id,
        output_mode=output_mode,
        input_type=input_type,
        status="running",
        current_step="request_accepted",
        request_payload=request_payload,
        result_snapshot={},
    )
    db.add(new_job)
    # Flush before recording the event, which reads job.id
    # (presumably the id is populated at flush time — confirm against the model).
    await db.flush()

    accepted_metadata = {"output_mode": output_mode, "input_type": input_type}
    await record_generation_event(
        db,
        job=new_job,
        story_id=story_id,
        event_type="request_accepted",
        status="succeeded",
        message="Generation request accepted.",
        metadata=accepted_metadata,
        commit=False,  # job and event are committed together just below
    )

    await db.commit()
    await db.refresh(new_job)
    return new_job
|
||
|
||
|
||
async def record_generation_event(
    db: AsyncSession,
    *,
    job: GenerationJob,
    event_type: str,
    status: str,
    story_id: int | None = None,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
    commit: bool = True,
) -> GenerationJobEvent:
    """Add one event row for *job* and advance the job's current step.

    Side effects on *job*: ``current_step`` becomes *event_type*, and
    ``story_id`` is overwritten when an explicit *story_id* is passed.
    The event uses the explicit *story_id* if given, else the job's own.
    The session is committed only when *commit* is true.
    """
    job.current_step = event_type

    if story_id is None:
        linked_story_id = job.story_id
    else:
        job.story_id = story_id
        linked_story_id = story_id

    entry = GenerationJobEvent(
        job_id=job.id,
        story_id=linked_story_id,
        event_type=event_type,
        status=status,
        message=message,
        event_metadata=metadata or {},
    )
    db.add(entry)

    if commit:
        await db.commit()
    return entry
|
||
|
||
|
||
async def claim_generation_job_for_worker(
    db: AsyncSession,
    *,
    job_id: str,
) -> GenerationJob | None:
    """Try to claim a queued job for a worker; return it, or None if not claimable.

    The claim is a conditional UPDATE: only a row that is still ``running``
    and sitting at ``request_accepted`` or ``retry_queued`` is moved to
    ``worker_started``. When no row was updated, nothing is returned.
    On success a ``worker_started`` event is recorded (and committed).
    """
    queued_steps = ["request_accepted", "retry_queued"]
    claim_stmt = (
        update(GenerationJob)
        .where(
            GenerationJob.id == job_id,
            GenerationJob.status == "running",
            GenerationJob.current_step.in_(queued_steps),
        )
        .values(current_step="worker_started")
    )
    outcome = await db.execute(claim_stmt)
    await db.commit()

    # Zero affected rows means the job was not in a claimable state.
    if not outcome.rowcount:
        return None

    lookup = await db.execute(select(GenerationJob).where(GenerationJob.id == job_id))
    claimed_job = lookup.scalar_one_or_none()
    if claimed_job is None:
        return None

    await record_generation_event(
        db,
        job=claimed_job,
        event_type="worker_started",
        status="running",
        message="Generation worker started processing the accepted request.",
    )
    return claimed_job
|
||
|
||
|
||
async def finish_generation_job(
    db: AsyncSession,
    *,
    job: GenerationJob,
    story: Story | None,
    status: str | None = None,
    current_step: str,
    error_message: str | None = None,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> GenerationJob:
    """Write the final state of *job* and record its closing event.

    Status resolution: an explicit (truthy) *status* wins; otherwise it is
    derived from *story*; with neither, the job is marked ``failed``.
    The job's result snapshot is rebuilt from *story* (empty when there is
    none) and also embedded in the event metadata under ``result_snapshot``.
    Everything is committed once, and the job is refreshed before returning.
    """
    if story is not None:
        job.story_id = story.id

    if status:
        resolved_status = status
    elif story is not None:
        resolved_status = _job_status_from_story(story)
    else:
        resolved_status = "failed"

    job.status = resolved_status
    job.current_step = current_step
    job.error_message = error_message
    job.result_snapshot = _story_snapshot(story)

    closing_metadata = dict(metadata or {})
    closing_metadata["result_snapshot"] = job.result_snapshot

    await record_generation_event(
        db,
        job=job,
        story_id=job.story_id,
        event_type=current_step,
        status=job.status,
        message=message,
        metadata=closing_metadata,
        commit=False,  # committed below together with the job changes
    )

    await db.commit()
    await db.refresh(job)
    return job
|
||
|
||
|
||
def generation_event_to_response(event: GenerationJobEvent) -> dict[str, Any]:
    """Build the API response dict for one generation event.

    Fields are copied as-is, except ``event_metadata``, which falls back to
    an empty dict when the stored value is falsy.
    """
    direct_fields = ("id", "job_id", "story_id", "event_type", "status", "message")
    response: dict[str, Any] = {
        field_name: getattr(event, field_name) for field_name in direct_fields
    }
    response["event_metadata"] = event.event_metadata or {}
    response["created_at"] = event.created_at
    return response
|
||
|
||
|
||
# Allow-list of event metadata keys that may be copied into user-facing responses.
_PUBLIC_EVENT_METADATA_KEYS = {
    "adapter", "artifact", "asset", "assets", "attempted_cover",
    "audio_status", "blocks_main_result", "capability", "completed_pages",
    "cover_prompt_present", "estimated_cost_usd", "failed_pages",
    "failure_category", "generation_status", "has_memory_context",
    "image_status", "input_type", "latency_ms", "mode", "output_mode",
    "page_count", "page_number", "recoverable", "requested_from_step",
    "retryable", "scope", "stale_after_minutes", "status", "step",
    "strategy", "text_status",
}
|
||
|
||
# Allow-list of request payload keys that may be copied into user-facing job details.
_PUBLIC_REQUEST_PAYLOAD_KEYS = {
    "assets", "child_profile_id", "generate_images",
    "input_type", "output_mode", "page_count",
    "story_id", "type", "universe_id",
}
|
||
|
||
|
||
def _public_metadata_value(value: Any) -> Any:
    """Reduce *value* to something safe to expose, or None if it is not.

    Scalars (``str``, ``int``, ``float``, ``bool``) and ``None`` pass through
    unchanged. A list is returned with only its scalar/``None`` items kept.
    Every other type (dicts, tuples, objects, ...) yields ``None``.
    """
    scalar_types = (str, int, float, bool)

    if value is None or isinstance(value, scalar_types):
        return value
    if not isinstance(value, list):
        return None
    # Non-scalar list items (nested lists, dicts, ...) are dropped.
    return [
        element
        for element in value
        if element is None or isinstance(element, scalar_types)
    ]
|
||
|
||
|
||
def public_generation_request_payload(job: GenerationJob) -> dict[str, Any]:
    """Extract the user-visible part of the job's request payload.

    Only allow-listed keys present in the payload are considered, in sorted
    key order. A key is omitted when its sanitized value is ``None``.
    """
    raw_payload = job.request_payload or {}
    return {
        key: cleaned
        for key in sorted(_PUBLIC_REQUEST_PAYLOAD_KEYS)
        if key in raw_payload
        and (cleaned := _public_metadata_value(raw_payload[key])) is not None
    }
|
||
|
||
|
||
def _public_plan_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
    """Summarize the ``plan`` entry of event metadata without exposing its details.

    Returns an empty dict when ``plan`` is missing or not a dict. Otherwise:
    ``plan_mode`` is set when the plan's ``mode`` is a string, and the two
    task counters are set when ``tasks`` is a list. A task counts as
    recoverable only when it is a dict whose ``recoverable`` is exactly True.
    """
    plan = metadata.get("plan")
    if not isinstance(plan, dict):
        return {}

    exposed: dict[str, Any] = {}

    plan_mode = plan.get("mode")
    if isinstance(plan_mode, str):
        exposed["plan_mode"] = plan_mode

    task_list = plan.get("tasks")
    if isinstance(task_list, list):
        recoverable_tasks = [
            entry
            for entry in task_list
            if isinstance(entry, dict) and entry.get("recoverable") is True
        ]
        exposed["planned_task_count"] = len(task_list)
        exposed["recoverable_task_count"] = len(recoverable_tasks)

    return exposed
|
||
|
||
|
||
def public_generation_event_metadata(event: GenerationJobEvent) -> dict[str, Any]:
    """Build the user-visible metadata dict for one event.

    Allow-listed keys are copied in sorted order when present and when their
    sanitized value is not ``None``. For ``workflow_planned`` events the
    coarse plan summary is merged in on top.
    """
    raw_metadata = event.event_metadata or {}
    exposed: dict[str, Any] = {
        key: cleaned
        for key in sorted(_PUBLIC_EVENT_METADATA_KEYS)
        if key in raw_metadata
        and (cleaned := _public_metadata_value(raw_metadata[key])) is not None
    }

    if event.event_type == "workflow_planned":
        exposed.update(_public_plan_metadata(raw_metadata))

    return exposed
|
||
|
||
|
||
def public_generation_event_to_response(event: GenerationJobEvent) -> dict[str, Any] | None:
    """Return the user-facing response for *event*, or None when it is hidden.

    ``evaluation_completed`` and ``executor_completed`` events are hidden.
    For all others the regular response is produced with ``event_metadata``
    replaced by its public subset.
    """
    hidden_event_types = {"evaluation_completed", "executor_completed"}
    if event.event_type in hidden_event_types:
        return None

    return {
        **generation_event_to_response(event),
        "event_metadata": public_generation_event_metadata(event),
    }
|
||
|
||
|
||
def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
    """Build the API summary dict for one generation job.

    Combines the job's identifying fields, the computed progress triple
    (percent, label, terminal flag), the cancel/retry capability flags, the
    result snapshot (empty dict when falsy), the error message and timestamps.
    """
    summary: dict[str, Any] = {
        "id": job.id,
        "story_id": job.story_id,
        "output_mode": job.output_mode,
        "input_type": job.input_type,
        "status": job.status,
        "current_step": job.current_step,
    }
    summary.update(_job_progress(job))
    summary.update(
        can_cancel=generation_job_can_cancel(job),
        can_retry=generation_job_can_retry(job),
        result_snapshot=job.result_snapshot or {},
        error_message=job.error_message,
        created_at=job.created_at,
        updated_at=job.updated_at,
    )
    return summary
|
||
|
||
|
||
def public_generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
    """Build the job summary for user-facing APIs, masking internal steps.

    When the current step is ``evaluation_completed`` or
    ``executor_completed``, the step, percent and label are replaced by those
    of a public stand-in step and ``is_terminal`` is forced to False.
    """
    # internal step -> (public step, progress percent, progress label)
    public_stand_ins = {
        "evaluation_completed": ("narrative_generated", 45, "正文已生成"),
        "executor_completed": ("workflow_planned", 8, "工作流已规划"),
    }

    summary = generation_job_to_summary(job)
    stand_in = public_stand_ins.get(summary["current_step"])
    if stand_in is not None:
        public_step, percent, label = stand_in
        summary["current_step"] = public_step
        summary["progress_percent"] = percent
        summary["progress_label"] = label
        summary["is_terminal"] = False
    return summary
|
||
|
||
|
||
async def get_generation_job_for_user(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> GenerationJob:
    """Fetch the job with *job_id* that belongs to *user_id*.

    Raises:
        HTTPException: 404 when no job matches both the id and the owner.
    """
    owned_job_stmt = select(GenerationJob).where(
        GenerationJob.id == job_id,
        GenerationJob.user_id == user_id,
    )
    found = (await db.execute(owned_job_stmt)).scalar_one_or_none()
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Generation job not found")
|
||
|
||
|
||
async def request_generation_job_cancel(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> dict[str, Any]:
    """Handle a user's cancel request and return the resulting public job summary.

    Outcomes, in order of evaluation:
      * unsupported output mode -> HTTP 409;
      * already canceled -> current summary, nothing changes;
      * any other terminal status -> HTTP 409;
      * cancellation already pending -> current summary, nothing changes;
      * still queued (``request_accepted`` / ``retry_queued``) -> the job is
        finished immediately as ``canceled``;
      * otherwise -> a ``cancel_requested`` event is recorded and the job
        stays ``running``.

    Raises:
        HTTPException: 404 when the job is not found for this user, 409 as above.
    """
    job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)

    if not _job_supports_queue_control(job):
        raise HTTPException(status_code=409, detail="当前任务不支持取消")

    already_canceled = job.status == "canceled"
    if not already_canceled and _is_terminal_status(job.status):
        raise HTTPException(status_code=409, detail="当前任务已终止,无法取消")

    # Repeated requests are no-ops: report the current state unchanged.
    if already_canceled or job.current_step == "cancel_requested":
        return public_generation_job_to_summary(job)

    still_queued = job.current_step in {"request_accepted", "retry_queued"}
    if still_queued:
        linked_story = None
        if job.story_id is not None:
            story_lookup = await db.execute(
                select(Story).where(
                    Story.id == job.story_id,
                    Story.user_id == job.user_id,
                )
            )
            linked_story = story_lookup.scalar_one_or_none()

        await finish_generation_job(
            db,
            job=job,
            story=linked_story,
            status="canceled",
            current_step="generation_canceled",
            error_message="Generation canceled by user before worker execution started.",
            message="Generation job was canceled before worker execution started.",
        )
        return public_generation_job_to_summary(job)

    # Remember the step before record_generation_event overwrites current_step.
    step_before_request = job.current_step
    job.error_message = "Cancellation requested by user."
    await record_generation_event(
        db,
        job=job,
        story_id=job.story_id,
        event_type="cancel_requested",
        status="running",
        message="Cancellation requested; worker will stop at the next safe checkpoint.",
        metadata={"requested_from_step": step_before_request},
        commit=False,
    )
    await db.commit()
    await db.refresh(job)
    return public_generation_job_to_summary(job)
|
||
|
||
|
||
async def get_generation_job_detail(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> dict[str, Any]:
    """Return a user-owned generation job with its ordered event stream.

    The ownership lookup is delegated to ``get_generation_job_for_user`` so
    the query and its 404 behaviour live in one place.

    Returns:
        The public job summary extended with ``request_payload`` (public
        subset) and ``events`` (public responses ordered by event id, with
        hidden events left out).

    Raises:
        HTTPException: 404 when the job does not exist for this user.
    """
    job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)

    events_stmt = (
        select(GenerationJobEvent)
        .where(GenerationJobEvent.job_id == job.id)
        .order_by(GenerationJobEvent.id)
    )
    events = (await db.execute(events_stmt)).scalars().all()

    return {
        **public_generation_job_to_summary(job),
        "request_payload": public_generation_request_payload(job),
        "events": [
            response
            for event in events
            # Hidden event types come back as None and are skipped.
            if (response := public_generation_event_to_response(event)) is not None
        ],
    }
|
||
|
||
|
||
async def list_story_generation_jobs(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> list[dict[str, Any]]:
    """List public summaries of the jobs for one story owned by *user_id*.

    Jobs are ordered newest first by creation time, with the id (descending)
    as tie-breaker.
    """
    story_jobs_stmt = (
        select(GenerationJob)
        .where(
            GenerationJob.story_id == story_id,
            GenerationJob.user_id == user_id,
        )
        .order_by(desc(GenerationJob.created_at), desc(GenerationJob.id))
    )
    rows = await db.execute(story_jobs_stmt)
    return [public_generation_job_to_summary(found) for found in rows.scalars().all()]
|
||
|
||
|
||
async def get_active_story_generation_job(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> GenerationJob | None:
    """Return the latest ``running`` job of the user's story, or None.

    "Latest" means most recently updated, with the id (descending) as
    tie-breaker; at most one row is fetched.
    """
    running_job_stmt = (
        select(GenerationJob)
        .where(
            GenerationJob.story_id == story_id,
            GenerationJob.user_id == user_id,
            GenerationJob.status == "running",
        )
        .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
        .limit(1)
    )
    rows = await db.execute(running_job_stmt)
    return rows.scalar_one_or_none()
|
||
|
||
|
||
async def ensure_no_active_story_generation_job(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> None:
    """Reject new work while the story still has a running job.

    Raises:
        HTTPException: 409 when a running job exists; the detail message
            includes that job's current progress label.
    """
    running_job = await get_active_story_generation_job(
        db, story_id=story_id, user_id=user_id
    )
    if running_job is not None:
        current_label = _job_progress(running_job)["progress_label"]
        raise HTTPException(
            status_code=409,
            detail=(
                f"当前故事已有运行中的任务({current_label}),"
                "请等待当前任务完成后再试。"
            ),
        )
|
||
|
||
|
||
def _as_float(value: Any) -> float | None:
    """Return *value* as a float when it is a real number, else None.

    Accepts ``int`` and ``float``. Booleans are rejected explicitly: ``bool``
    is a subclass of ``int``, so without the guard a stray ``True`` in
    telemetry metadata would be read as ``1.0``. Strings and every other
    type also yield ``None``.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        return float(value)
    return None
|
||
|
||
|
||
def _sorted_buckets(counts: dict[str, int]) -> list[dict[str, Any]]:
    """Turn a name->count mapping into a list of ``{"name", "count"}`` dicts.

    Entries are ordered by count descending, then by name ascending.
    """
    ordered_names = sorted(counts, key=lambda name: (-counts[name], name))
    return [{"name": name, "count": counts[name]} for name in ordered_names]
|
||
|
||
|
||
def _aggregate_trace_events(events: list[GenerationJobEvent]) -> dict[str, Any]:
    """Count events per step, per artifact and per failure category.

    ``evaluation_completed`` and ``executor_completed`` events are ignored
    entirely. Steps and artifacts are counted only when they are non-empty
    strings (and the artifact is not the literal ``"none"``). Events whose
    status is ``failed`` are bucketed by their ``failure_category``, falling
    back to ``unknown_error`` when that is missing or not a non-empty string.
    """
    ignored_event_types = {"evaluation_completed", "executor_completed"}
    step_counts: dict[str, int] = {}
    artifact_counts: dict[str, int] = {}
    category_counts: dict[str, int] = {}
    counted_events = 0
    failure_total = 0

    def _bump(counter: dict[str, int], key: str) -> None:
        counter[key] = counter.get(key, 0) + 1

    for event in events:
        if event.event_type in ignored_event_types:
            continue

        counted_events += 1
        details = event.event_metadata or {}

        step_name = details.get("step")
        if isinstance(step_name, str) and step_name:
            _bump(step_counts, step_name)

        artifact_name = details.get("artifact")
        if isinstance(artifact_name, str) and artifact_name not in ("", "none"):
            _bump(artifact_counts, artifact_name)

        if event.status != "failed":
            continue

        failure_total += 1
        raw_category = details.get("failure_category")
        if isinstance(raw_category, str) and raw_category:
            _bump(category_counts, raw_category)
        else:
            _bump(category_counts, "unknown_error")

    return {
        "total_events": counted_events,
        "failed_events": failure_total,
        "by_step": _sorted_buckets(step_counts),
        "by_artifact": _sorted_buckets(artifact_counts),
        "failure_categories": _sorted_buckets(category_counts),
    }
|
||
|
||
|
||
def _aggregate_provider_events(
    events: list[GenerationJobEvent],
    *,
    capability: str | None = None,
) -> dict[str, Any]:
    """Summarize provider call events overall and per (capability, adapter).

    Events are optionally restricted to one *capability*. Missing capability
    or adapter values are reported as ``"unknown"``. An event counts as a
    success only when its type is ``provider_call_succeeded``; every other
    event counts as a failure and is bucketed by its ``error`` metadata
    (``unknown_error`` when absent). Cost is accumulated for successes only;
    latency is accumulated for every event that carries a numeric value.
    """
    per_provider: dict[tuple[str, str], dict[str, Any]] = {}
    reason_counts: dict[str, int] = {}
    latency_sum = 0.0
    latency_samples = 0
    cost_sum = 0.0
    succeeded = 0
    failed = 0

    for event in events:
        details = event.event_metadata or {}
        capability_name = str(details.get("capability") or "unknown")
        if capability is not None and capability_name != capability:
            continue

        adapter_name = str(details.get("adapter") or "unknown")
        provider_key = (capability_name, adapter_name)
        stats = per_provider.get(provider_key)
        if stats is None:
            stats = {
                "capability": capability_name,
                "adapter": adapter_name,
                "call_count": 0,
                "success_count": 0,
                "failure_count": 0,
                "latency_total": 0.0,
                "latency_count": 0,
                "estimated_cost_usd": 0.0,
            }
            per_provider[provider_key] = stats

        stats["call_count"] += 1

        latency_ms = _as_float(details.get("latency_ms"))
        if latency_ms is not None:
            stats["latency_total"] += latency_ms
            stats["latency_count"] += 1
            latency_sum += latency_ms
            latency_samples += 1

        if event.event_type != "provider_call_succeeded":
            stats["failure_count"] += 1
            failed += 1
            reason = str(details.get("error") or "unknown_error")
            reason_counts[reason] = reason_counts.get(reason, 0) + 1
            continue

        stats["success_count"] += 1
        succeeded += 1
        call_cost = _as_float(details.get("estimated_cost_usd")) or 0.0
        stats["estimated_cost_usd"] += call_cost
        cost_sum += call_cost

    # Build the public per-provider rows; the latency accumulators are
    # replaced by a rounded average (None when no latency was recorded).
    provider_rows = []
    for stats in per_provider.values():
        samples = stats["latency_count"]
        provider_rows.append(
            {
                "capability": stats["capability"],
                "adapter": stats["adapter"],
                "call_count": stats["call_count"],
                "success_count": stats["success_count"],
                "failure_count": stats["failure_count"],
                "estimated_cost_usd": round(stats["estimated_cost_usd"], 6),
                "avg_latency_ms": (
                    round(stats["latency_total"] / samples, 2) if samples else None
                ),
            }
        )
    provider_rows.sort(key=lambda row: (str(row["capability"]), str(row["adapter"])))

    ordered_reasons = sorted(
        reason_counts.items(),
        key=lambda pair: (-pair[1], pair[0]),
    )
    overall_avg_latency = (
        round(latency_sum / latency_samples, 2) if latency_samples else None
    )

    return {
        "total_calls": succeeded + failed,
        "successful_calls": succeeded,
        "failed_calls": failed,
        "avg_latency_ms": overall_avg_latency,
        "estimated_cost_usd": round(cost_sum, 6),
        "by_provider": provider_rows,
        "failure_reasons": [
            {"reason": reason, "count": count} for reason, count in ordered_reasons
        ],
    }
|
||
|
||
|
||
def _event_matches_capability(
    event: GenerationJobEvent,
    capability: str | None = None,
) -> bool:
    """Return True when *event* belongs to *capability* (or no filter is set).

    An event without a capability in its metadata is treated as ``"unknown"``.
    """
    details = event.event_metadata or {}
    tagged_capability = str(details.get("capability") or "unknown")
    if capability is None:
        return True
    return tagged_capability == capability
|
||
|
||
|
||
def _provider_events_query(
    *,
    user_id: str | None = None,
    story_id: int | None = None,
    days: int | None = None,
):
    """Build the SELECT for provider call events joined to their jobs.

    Always restricted to ``provider_call_succeeded`` / ``provider_call_failed``
    events. Optional filters: owning user, story, and a "last *days* days"
    window on the event's creation time. Rows are ordered by event id and
    carry the event plus the job's ``user_id`` and ``story_id``.
    """
    provider_event_types = ["provider_call_succeeded", "provider_call_failed"]
    conditions = [GenerationJobEvent.event_type.in_(provider_event_types)]

    if user_id is not None:
        conditions.append(GenerationJob.user_id == user_id)
    if story_id is not None:
        conditions.append(GenerationJob.story_id == story_id)
    if days is not None:
        window_start = datetime.now(timezone.utc) - timedelta(days=days)
        conditions.append(GenerationJobEvent.created_at >= window_start)

    return (
        select(
            GenerationJobEvent,
            GenerationJob.user_id,
            GenerationJob.story_id,
        )
        .join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
        .where(*conditions)
        .order_by(GenerationJobEvent.id)
    )
|
||
|
||
|
||
async def get_story_provider_stats(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Return provider call statistics for one story owned by *user_id*.

    The time window (*days*) is applied in the query; the *capability*
    filter is applied during aggregation. The result echoes the story id,
    window and capability next to the aggregated figures.
    """
    stmt = _provider_events_query(user_id=user_id, story_id=story_id, days=days)
    # .scalars() keeps only the first selected column, i.e. the event objects.
    provider_events = (await db.execute(stmt)).scalars().all()

    stats: dict[str, Any] = {
        "story_id": story_id,
        "window_days": days,
        "capability": capability,
    }
    stats.update(_aggregate_provider_events(provider_events, capability=capability))
    return stats
|
||
|
||
|
||
async def get_story_trace_summary(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
    days: int | None = None,
) -> dict[str, Any]:
    """Return the workflow trace summary for one story owned by *user_id*.

    All events of the story's jobs are loaded in event-id order, optionally
    limited to the last *days* days, and then aggregated.
    """
    conditions = [
        GenerationJob.story_id == story_id,
        GenerationJob.user_id == user_id,
    ]
    if days is not None:
        window_start = datetime.now(timezone.utc) - timedelta(days=days)
        conditions.append(GenerationJobEvent.created_at >= window_start)

    stmt = (
        select(GenerationJobEvent)
        .join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
        .where(*conditions)
        .order_by(GenerationJobEvent.id)
    )
    trace_events = (await db.execute(stmt)).scalars().all()

    summary: dict[str, Any] = {"story_id": story_id, "window_days": days}
    summary.update(_aggregate_trace_events(trace_events))
    return summary
|
||
|
||
|
||
async def get_user_provider_analytics(
    db: AsyncSession,
    *,
    user_id: str,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Return provider call statistics across all of one user's jobs.

    Besides the aggregated figures, the result reports how many distinct
    jobs and distinct (non-null) stories the capability-matching events
    belong to.
    """
    stmt = _provider_events_query(user_id=user_id, days=days)
    provider_events = (await db.execute(stmt)).scalars().all()

    matching_events = [
        event
        for event in provider_events
        if _event_matches_capability(event, capability)
    ]
    distinct_job_ids = {event.job_id for event in matching_events}
    distinct_story_ids = {
        event.story_id for event in matching_events if event.story_id is not None
    }

    analytics: dict[str, Any] = {"window_days": days, "capability": capability}
    analytics.update(_aggregate_provider_events(provider_events, capability=capability))
    analytics["job_count"] = len(distinct_job_ids)
    analytics["story_count"] = len(distinct_story_ids)
    return analytics
|
||
|
||
|
||
async def get_user_generation_ops_summary(
    db: AsyncSession,
    *,
    user_id: str,
    hours: int = 24,
    recent_failure_limit: int = 5,
) -> dict[str, Any]:
    """Build a health overview of one user's generation jobs.

    Two queries feed the summary: all currently ``running`` jobs (used for
    the active and stale counts) and all jobs updated within the last
    *hours* hours (used for the failed / degraded / asset-job counts).
    At most *recent_failure_limit* failed jobs, most recently updated first,
    are listed in detail together with the story title.
    """
    stale_threshold = settings.generation_job_stale_minutes
    window_start = datetime.now(timezone.utc) - timedelta(hours=hours)

    active_stmt = (
        select(GenerationJob)
        .where(
            GenerationJob.user_id == user_id,
            GenerationJob.status == "running",
        )
        .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
    )
    active_jobs = (await db.execute(active_stmt)).scalars().all()

    # Outer join: jobs without a story still appear, with a None title.
    recent_stmt = (
        select(GenerationJob, Story.title)
        .outerjoin(Story, Story.id == GenerationJob.story_id)
        .where(
            GenerationJob.user_id == user_id,
            GenerationJob.updated_at >= window_start,
        )
        .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
    )
    recent_rows = (await db.execute(recent_stmt)).all()

    failure_details: list[dict[str, Any]] = []
    failed_total = 0
    degraded_total = 0
    asset_job_total = 0
    asset_modes = {"asset_retry", "asset_generation"}

    for job, story_title in recent_rows:
        # Asset jobs are counted regardless of their status.
        if job.output_mode in asset_modes:
            asset_job_total += 1

        if job.status == "degraded_completed":
            degraded_total += 1
            continue
        if job.status != "failed":
            continue

        failed_total += 1
        if len(failure_details) >= recent_failure_limit:
            continue
        failure_details.append(
            {
                "job_id": job.id,
                "story_id": job.story_id,
                "story_title": story_title,
                "output_mode": job.output_mode,
                "current_step": job.current_step,
                "error_message": job.error_message,
                "failure_label": _failure_label(job),
                "updated_at": job.updated_at,
            }
        )

    stale_total = sum(
        1
        for running_job in active_jobs
        if _is_stale_job(running_job, stale_after_minutes=stale_threshold)
    )

    return {
        "window_hours": hours,
        "stale_threshold_minutes": stale_threshold,
        "active_jobs": len(active_jobs),
        "stale_running_jobs": stale_total,
        "failed_jobs": failed_total,
        "degraded_jobs": degraded_total,
        "asset_retry_jobs": asset_job_total,
        "recent_failures": failure_details,
    }
|
||
|
||
|
||
async def mark_stale_generation_jobs(
    db: AsyncSession,
    *,
    stale_after_minutes: int | None = None,
) -> dict[str, int]:
    """Fail every running job that has gone too long without an update.

    The threshold is *stale_after_minutes* when truthy; ``None`` (and also
    ``0``) fall back to ``settings.generation_job_stale_minutes``. All running
    jobs are loaded oldest-update first; each stale one is finished as
    ``failed`` at step ``generation_stale_failed`` and logged.

    Returns:
        Counts of running jobs inspected and jobs marked stale, plus the
        threshold that was applied.
    """
    threshold = stale_after_minutes or settings.generation_job_stale_minutes

    running_stmt = (
        select(GenerationJob)
        .where(GenerationJob.status == "running")
        .order_by(GenerationJob.updated_at, GenerationJob.id)
    )
    running_jobs = (await db.execute(running_stmt)).scalars().all()

    stale_total = 0
    for job in running_jobs:
        if _is_stale_job(job, stale_after_minutes=threshold):
            linked_story = None
            if job.story_id is not None:
                story_lookup = await db.execute(
                    select(Story).where(
                        Story.id == job.story_id,
                        Story.user_id == job.user_id,
                    )
                )
                linked_story = story_lookup.scalar_one_or_none()

            await finish_generation_job(
                db,
                job=job,
                story=linked_story,
                status="failed",
                current_step="generation_stale_failed",
                error_message=f"Generation job exceeded {threshold} minutes without progress.",
                message="Generation job was marked failed after exceeding the stale threshold.",
                metadata={"stale_after_minutes": threshold},
            )
            stale_total += 1
            logger.warning(
                "generation_job_marked_stale",
                job_id=job.id,
                story_id=job.story_id,
                output_mode=job.output_mode,
                stale_after_minutes=threshold,
            )

    return {
        "running": len(running_jobs),
        "marked_stale": stale_total,
        "stale_after_minutes": threshold,
    }
|