Files
dreamweaver/backend/app/services/generation_jobs.py

1091 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Lightweight generation job/event tracking."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any
from fastapi import HTTPException
from sqlalchemy import desc, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import (
GenerationJob,
GenerationJobEvent,
Story,
)
# Module-scoped logger obtained from the application's logging helper.
logger = get_logger(__name__)
def _is_terminal_status(status: str) -> bool:
    """Report whether ``status`` is one of the final (non-running) job states."""
    terminal_states = frozenset(("completed", "degraded_completed", "failed", "canceled"))
    return status in terminal_states
def _job_supports_queue_control(job: GenerationJob) -> bool:
    """Return True when the job's output mode supports queue cancel/retry controls."""
    controllable_modes = frozenset(("story", "storybook", "asset_generation"))
    return job.output_mode in controllable_modes
def generation_job_can_cancel(job: GenerationJob) -> bool:
    """Return True when a queue-controllable job is running and cancelable.

    A job whose current step is already ``cancel_requested`` is reported as
    not cancelable.
    """
    if not _job_supports_queue_control(job):
        return False
    if job.status != "running":
        return False
    return job.current_step != "cancel_requested"
def generation_job_can_retry(job: GenerationJob) -> bool:
    """Return True when a queue-controllable job ended as failed or canceled."""
    if not _job_supports_queue_control(job):
        return False
    return job.status in {"failed", "canceled"}
def _story_snapshot(story: Story | None) -> dict[str, Any]:
    """Capture a story's generation state for a job's ``result_snapshot``.

    Returns an empty dict when no story is attached.
    """
    if story is None:
        return {}
    snapshot: dict[str, Any] = {"story_id": story.id}
    copied_fields = (
        "mode",
        "generation_status",
        "text_status",
        "image_status",
        "audio_status",
        "retryable_assets",
        "last_error",
    )
    for field_name in copied_fields:
        snapshot[field_name] = getattr(story, field_name)
    return snapshot
def _job_status_from_story(story: Story) -> str:
    """Map a story's generation status onto the job's final status.

    ``failed`` and ``degraded_completed`` carry over unchanged; every other
    story status is reported as ``completed``.
    """
    carried_over = {
        "failed": "failed",
        "degraded_completed": "degraded_completed",
    }
    return carried_over.get(story.generation_status, "completed")
def _job_progress(job: GenerationJob) -> dict[str, Any]:
    """Resolve a compact progress summary for polling-oriented clients.

    Terminal statuses (failed/canceled/completed/degraded_completed) always
    report 100%. Otherwise percentage and label are looked up from
    ``job.current_step``; unknown steps fall back to 10% / "生成处理中".
    ``is_terminal`` is True whenever the resolved percentage reaches 100.
    """
    terminal_labels = {
        "failed": "生成失败",
        "canceled": "已取消",
        "completed": "已完成",
        "degraded_completed": "降级完成",
    }
    if job.status in terminal_labels:
        percent, label = 100, terminal_labels[job.status]
    else:
        step_progress: dict[str, tuple[int, str]] = {
            "request_accepted": (5, "已接收请求"),
            "workflow_planned": (8, "工作流已规划"),
            "retry_queued": (8, "重新排队中"),
            "worker_started": (12, "后台任务已开始"),
            "cancel_requested": (15, "已请求取消"),
            "context_prepared": (20, "上下文已准备"),
            "narrative_generated": (45, "正文已生成"),
            "evaluation_completed": (52, "内容评测已完成"),
            "story_saved": (60, "主记录已保存"),
            "provider_call_started": (65, "Provider 调用中"),
            "provider_call_succeeded": (72, "Provider 调用成功"),
            "provider_call_failed": (72, "Provider 调用失败,尝试恢复"),
            "cover_image_started": (75, "封面生成中"),
            "cover_image_succeeded": (88, "封面已生成"),
            "cover_image_failed": (88, "封面生成失败"),
            "storybook_images_started": (75, "绘本插图生成中"),
            "storybook_cover_image_succeeded": (82, "绘本封面已生成"),
            "storybook_cover_image_failed": (82, "绘本封面生成失败"),
            "storybook_page_image_succeeded": (86, "分页插图已生成"),
            "storybook_page_image_failed": (86, "分页插图生成失败"),
            "storybook_images_completed": (92, "绘本插图已完成"),
            "audio_started": (75, "音频生成中"),
            "audio_cache_hit": (88, "音频缓存已复用"),
            "audio_succeeded": (88, "音频已生成"),
            "audio_failed": (88, "音频生成失败"),
            "asset_retry_started": (25, "资源重试中"),
            "postprocessing_queued": (90, "后处理已排队"),
            "asset_generation_completed": (100, "资源已完成"),
            "asset_retry_completed": (100, "资源重试完成"),
            "generation_canceled": (100, "任务已取消"),
            "generation_completed": (100, "生成完成"),
            "generation_stale_failed": (100, "任务超时已收敛"),
        }
        percent, label = step_progress.get(job.current_step, (10, "生成处理中"))
    return {
        "progress_percent": percent,
        "progress_label": label,
        "is_terminal": percent >= 100,
    }
def _normalize_datetime(value: datetime) -> datetime:
    """Return ``value`` as a timezone-aware UTC datetime.

    Naive values are assumed to already be UTC and are only tagged; aware
    values are converted to UTC.
    """
    return (
        value.astimezone(timezone.utc)
        if value.tzinfo is not None
        else value.replace(tzinfo=timezone.utc)
    )
def _is_stale_job(job: GenerationJob, *, stale_after_minutes: int) -> bool:
    """Return True for a running job not updated within ``stale_after_minutes``.

    ``updated_at`` is normalized to UTC first; a job updated exactly at the
    cutoff counts as stale.
    """
    freshness_cutoff = datetime.now(timezone.utc) - timedelta(minutes=stale_after_minutes)
    if job.status != "running":
        return False
    return _normalize_datetime(job.updated_at) <= freshness_cutoff
def _failure_label(job: GenerationJob) -> str:
    """Pick a short user-facing label describing why a job did not succeed.

    Cancellation takes precedence over a stale timeout, which takes
    precedence over the output-mode specific labels; anything else gets the
    generic "生成失败".
    """
    if job.status == "canceled":
        return "任务已取消"
    if job.current_step == "generation_stale_failed":
        return "任务超时"
    labels_by_mode = {
        "asset_retry": "资源重试失败",
        "asset_generation": "资源生成失败",
    }
    return labels_by_mode.get(job.output_mode, "生成失败")
async def create_generation_job(
    db: AsyncSession,
    *,
    user_id: str,
    output_mode: str,
    input_type: str,
    request_payload: dict[str, Any],
    story_id: int | None = None,
) -> GenerationJob:
    """Persist a new running job together with its ``request_accepted`` event.

    The job is flushed before the event is recorded so the event can
    reference ``job.id``. Both rows are committed together, and the job is
    refreshed before it is returned.
    """
    new_job = GenerationJob(
        user_id=user_id,
        story_id=story_id,
        output_mode=output_mode,
        input_type=input_type,
        status="running",
        current_step="request_accepted",
        request_payload=request_payload,
        result_snapshot={},
    )
    db.add(new_job)
    # Write the job row (without committing) so the event below can link to it.
    await db.flush()
    accepted_metadata = {"output_mode": output_mode, "input_type": input_type}
    await record_generation_event(
        db,
        job=new_job,
        event_type="request_accepted",
        status="succeeded",
        story_id=story_id,
        message="Generation request accepted.",
        metadata=accepted_metadata,
        commit=False,
    )
    await db.commit()
    await db.refresh(new_job)
    return new_job
async def record_generation_event(
    db: AsyncSession,
    *,
    job: GenerationJob,
    event_type: str,
    status: str,
    story_id: int | None = None,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
    commit: bool = True,
) -> GenerationJobEvent:
    """Append an event to ``job`` and advance ``job.current_step`` to ``event_type``.

    A non-None ``story_id`` is also written onto the job; otherwise the event
    inherits the job's current story id. The session is committed only when
    ``commit`` is true, so callers can batch the event with other changes.
    """
    job.current_step = event_type
    if story_id is not None:
        job.story_id = story_id
    event_story_id = job.story_id if story_id is None else story_id
    event = GenerationJobEvent(
        job_id=job.id,
        story_id=event_story_id,
        event_type=event_type,
        status=status,
        message=message,
        event_metadata=metadata if metadata else {},
    )
    db.add(event)
    if not commit:
        return event
    await db.commit()
    return event
async def claim_generation_job_for_worker(
    db: AsyncSession,
    *,
    job_id: str,
) -> GenerationJob | None:
    """Claim one queued generation job for worker execution once.

    The claim is a single conditional UPDATE. It only matches a job that is
    still ``running`` and whose ``current_step`` is ``request_accepted`` or
    ``retry_queued``, and it moves that job to ``worker_started``. A caller
    whose UPDATE matches no row gets ``None``. That happens when the job was
    already claimed, canceled, finished, or the id is unknown.

    Returns:
        The claimed job after a ``worker_started`` event has been recorded and
        committed, or ``None`` if nothing was claimed.
    """
    # Compare-and-set on (status, current_step). Presumably this lets only one
    # caller win the claim when the same job id is dispatched more than once.
    claim_result = await db.execute(
        update(GenerationJob)
        .where(
            GenerationJob.id == job_id,
            GenerationJob.status == "running",
            GenerationJob.current_step.in_(["request_accepted", "retry_queued"]),
        )
        .values(current_step="worker_started")
    )
    # Commit right away so the claim is durable and visible to other sessions.
    await db.commit()
    if not claim_result.rowcount:
        return None
    # Load the ORM object so the audit event can be attached to it.
    result = await db.execute(select(GenerationJob).where(GenerationJob.id == job_id))
    job = result.scalar_one_or_none()
    if job is None:
        return None
    # record_generation_event commits by default (commit=True). current_step is
    # already "worker_started", so this mainly appends the audit event.
    await record_generation_event(
        db,
        job=job,
        event_type="worker_started",
        status="running",
        message="Generation worker started processing the accepted request.",
    )
    return job
async def finish_generation_job(
    db: AsyncSession,
    *,
    job: GenerationJob,
    story: Story | None,
    status: str | None = None,
    current_step: str,
    error_message: str | None = None,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> GenerationJob:
    """Close out a job with a final status and record the matching event.

    Status resolution: a truthy ``status`` wins. Otherwise the status is
    derived from ``story`` via ``_job_status_from_story``, or is ``"failed"``
    when there is no story. The job's ``result_snapshot`` is replaced with a
    snapshot of ``story`` (empty when ``story`` is None) and is embedded in the
    event metadata. The job is committed and refreshed before it is returned.
    """
    if story is not None:
        job.story_id = story.id
    if status:
        resolved_status = status
    elif story is not None:
        resolved_status = _job_status_from_story(story)
    else:
        resolved_status = "failed"
    job.status = resolved_status
    job.current_step = current_step
    job.error_message = error_message
    job.result_snapshot = _story_snapshot(story)

    final_metadata = dict(metadata or {})
    final_metadata["result_snapshot"] = job.result_snapshot
    await record_generation_event(
        db,
        job=job,
        story_id=job.story_id,
        event_type=current_step,
        status=job.status,
        message=message,
        metadata=final_metadata,
        commit=False,
    )
    await db.commit()
    await db.refresh(job)
    return job
def generation_event_to_response(event: GenerationJobEvent) -> dict[str, Any]:
    """Serialize a job event into a plain dict, with metadata left unfiltered."""
    response: dict[str, Any] = {}
    for attr in ("id", "job_id", "story_id", "event_type", "status", "message"):
        response[attr] = getattr(event, attr)
    response["event_metadata"] = event.event_metadata or {}
    response["created_at"] = event.created_at
    return response
# Allowlist of event-metadata keys that may be exposed to end users. Any key
# not listed here (e.g. raw "error" text or "plan") is stripped from public
# event responses. Kept as a frozenset so the security-relevant allowlist
# cannot be mutated at runtime by accident.
_PUBLIC_EVENT_METADATA_KEYS = frozenset(
    {
        "adapter",
        "artifact",
        "asset",
        "assets",
        "attempted_cover",
        "audio_status",
        "blocks_main_result",
        "capability",
        "completed_pages",
        "cover_prompt_present",
        "estimated_cost_usd",
        "failed_pages",
        "failure_category",
        "generation_status",
        "has_memory_context",
        "image_status",
        "input_type",
        "latency_ms",
        "mode",
        "output_mode",
        "page_count",
        "page_number",
        "recoverable",
        "requested_from_step",
        "retryable",
        "scope",
        "stale_after_minutes",
        "status",
        "step",
        "strategy",
        "text_status",
    }
)
# Allowlist of request-payload keys that may be echoed back in user-facing job
# details; everything else in the stored payload stays internal. Kept as a
# frozenset so the allowlist cannot be mutated at runtime by accident.
_PUBLIC_REQUEST_PAYLOAD_KEYS = frozenset(
    {
        "assets",
        "child_profile_id",
        "generate_images",
        "input_type",
        "output_mode",
        "page_count",
        "story_id",
        "type",
        "universe_id",
    }
)
def _public_metadata_value(value: Any) -> Any:
"""Return a JSON-safe public value or None when the value is internal."""
if isinstance(value, str | int | float | bool) or value is None:
return value
if isinstance(value, list):
public_items = [
item
for item in value
if isinstance(item, str | int | float | bool) or item is None
]
return public_items
return None
def public_generation_request_payload(job: GenerationJob) -> dict[str, Any]:
    """Project the job's request payload onto the public key allowlist.

    Keys are emitted in sorted order. Keys whose value ``_public_metadata_value``
    rejects (returns None) are dropped.
    """
    raw_payload = job.request_payload or {}
    candidates = (
        (key, _public_metadata_value(raw_payload[key]))
        for key in sorted(_PUBLIC_REQUEST_PAYLOAD_KEYS)
        if key in raw_payload
    )
    return {key: value for key, value in candidates if value is not None}
def _public_plan_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
    """Reduce a workflow ``plan`` to coarse, user-safe counters.

    Exposes only the plan mode (as ``plan_mode``), the number of planned
    tasks, and how many of those tasks have ``recoverable`` set to ``True``.
    Returns an empty dict when ``metadata["plan"]`` is missing or not a dict.
    """
    plan = metadata.get("plan")
    if not isinstance(plan, dict):
        return {}
    summary: dict[str, Any] = {}
    plan_mode = plan.get("mode")
    if isinstance(plan_mode, str):
        summary["plan_mode"] = plan_mode
    task_list = plan.get("tasks")
    if not isinstance(task_list, list):
        return summary
    recoverable_tasks = [
        task for task in task_list
        if isinstance(task, dict) and task.get("recoverable") is True
    ]
    summary["planned_task_count"] = len(task_list)
    summary["recoverable_task_count"] = len(recoverable_tasks)
    return summary
def public_generation_event_metadata(event: GenerationJobEvent) -> dict[str, Any]:
    """Filter an event's metadata down to the public allowlist.

    Allowed keys are emitted in sorted order with their sanitized value; keys
    whose value is rejected are dropped. ``workflow_planned`` events also get
    the coarse plan counters from ``_public_plan_metadata``.
    """
    raw_metadata = event.event_metadata or {}
    visible = {
        key: sanitized
        for key in sorted(_PUBLIC_EVENT_METADATA_KEYS)
        if key in raw_metadata
        and (sanitized := _public_metadata_value(raw_metadata[key])) is not None
    }
    if event.event_type == "workflow_planned":
        visible.update(_public_plan_metadata(raw_metadata))
    return visible
def public_generation_event_to_response(event: GenerationJobEvent) -> dict[str, Any] | None:
    """Serialize an event for user-facing APIs, or return None for internal events.

    ``evaluation_completed`` and ``executor_completed`` events are hidden
    entirely. Every other event has its metadata replaced by the public
    projection.
    """
    hidden_event_types = ("evaluation_completed", "executor_completed")
    if event.event_type in hidden_event_types:
        return None
    return {
        **generation_event_to_response(event),
        "event_metadata": public_generation_event_metadata(event),
    }
def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
    """Build the API summary dict for a job, including progress and control flags."""
    summary: dict[str, Any] = {
        "id": job.id,
        "story_id": job.story_id,
        "output_mode": job.output_mode,
        "input_type": job.input_type,
        "status": job.status,
        "current_step": job.current_step,
    }
    summary.update(_job_progress(job))
    summary["can_cancel"] = generation_job_can_cancel(job)
    summary["can_retry"] = generation_job_can_retry(job)
    summary["result_snapshot"] = job.result_snapshot or {}
    summary["error_message"] = job.error_message
    summary["created_at"] = job.created_at
    summary["updated_at"] = job.updated_at
    return summary
def public_generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
    """Summarize a job for user-facing APIs, masking internal workflow steps.

    Internal steps are reported as the nearest public step:
    ``evaluation_completed`` becomes ``narrative_generated`` (45%) and
    ``executor_completed`` becomes ``workflow_planned`` (8%). Both are
    reported as non-terminal.
    """
    masked_steps: dict[str, tuple[str, int, str]] = {
        "evaluation_completed": ("narrative_generated", 45, "正文已生成"),
        "executor_completed": ("workflow_planned", 8, "工作流已规划"),
    }
    summary = generation_job_to_summary(job)
    replacement = masked_steps.get(summary["current_step"])
    if replacement is not None:
        public_step, percent, label = replacement
        summary["current_step"] = public_step
        summary["progress_percent"] = percent
        summary["progress_label"] = label
        summary["is_terminal"] = False
    return summary
async def get_generation_job_for_user(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> GenerationJob:
    """Fetch a generation job by id, scoped to its owner.

    Raises:
        HTTPException: 404 when no job with ``job_id`` belongs to ``user_id``.
    """
    owner_scoped = select(GenerationJob).where(
        GenerationJob.id == job_id,
        GenerationJob.user_id == user_id,
    )
    job = (await db.execute(owner_scoped)).scalar_one_or_none()
    if job is not None:
        return job
    raise HTTPException(status_code=404, detail="Generation job not found")
async def request_generation_job_cancel(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> dict[str, Any]:
    """Request cancellation for one queued/running generation job.

    Outcomes, in order of precedence:

    * output mode without queue control: HTTP 409;
    * already ``canceled``: current public summary (no change);
    * any other terminal status: HTTP 409;
    * cancellation already requested: current public summary (no change);
    * still queued (``request_accepted``/``retry_queued``): the job is
      finished immediately with status ``canceled``;
    * otherwise: a ``cancel_requested`` event is recorded and the job stays
      ``running`` so the worker can stop at its next safe checkpoint.

    Raises:
        HTTPException: 404 if the job is not owned by ``user_id`` (raised by
            ``get_generation_job_for_user``); 409 as described above.

    Returns:
        The job's public summary after any change.
    """
    job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)
    if not _job_supports_queue_control(job):
        raise HTTPException(status_code=409, detail="当前任务不支持取消")
    if job.status == "canceled":
        return public_generation_job_to_summary(job)
    if _is_terminal_status(job.status):
        raise HTTPException(status_code=409, detail="当前任务已终止,无法取消")
    if job.current_step == "cancel_requested":
        return public_generation_job_to_summary(job)
    if job.current_step in {"request_accepted", "retry_queued"}:
        # Still queued: close the job out directly.
        # NOTE(review): unlike claim_generation_job_for_worker this is not a
        # conditional UPDATE, so a worker claiming the job concurrently may
        # still start it. Confirm workers re-check status/cancellation.
        story = None
        if job.story_id is not None:
            # Load the linked story (owner-scoped) so its state is snapshotted.
            story = (
                await db.execute(
                    select(Story).where(
                        Story.id == job.story_id,
                        Story.user_id == job.user_id,
                    )
                )
            ).scalar_one_or_none()
        await finish_generation_job(
            db,
            job=job,
            story=story,
            status="canceled",
            current_step="generation_canceled",
            error_message="Generation canceled by user before worker execution started.",
            message="Generation job was canceled before worker execution started.",
        )
        return public_generation_job_to_summary(job)
    # Past the queued steps (presumably claimed by a worker): only flag the
    # request; status stays "running" and the worker is expected to stop.
    previous_step = job.current_step
    job.error_message = "Cancellation requested by user."
    await record_generation_event(
        db,
        job=job,
        story_id=job.story_id,
        event_type="cancel_requested",
        status="running",
        message="Cancellation requested; worker will stop at the next safe checkpoint.",
        metadata={"requested_from_step": previous_step},
        commit=False,
    )
    await db.commit()
    await db.refresh(job)
    return public_generation_job_to_summary(job)
async def get_generation_job_detail(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> dict[str, Any]:
    """Return a user-owned generation job with its ordered public event stream.

    Events are ordered by id. Internal-only events are omitted, and metadata
    is filtered to the public allowlist.

    Raises:
        HTTPException: 404 when the job does not exist or is not owned by
            ``user_id``.
    """
    # Reuse the owner-scoped lookup so the ownership check and 404 contract
    # live in one place instead of being duplicated here.
    job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)
    events = (
        await db.execute(
            select(GenerationJobEvent)
            .where(GenerationJobEvent.job_id == job.id)
            .order_by(GenerationJobEvent.id)
        )
    ).scalars().all()
    public_events = [
        response
        for event in events
        if (response := public_generation_event_to_response(event)) is not None
    ]
    return {
        **public_generation_job_to_summary(job),
        "request_payload": public_generation_request_payload(job),
        "events": public_events,
    }
async def list_story_generation_jobs(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> list[dict[str, Any]]:
    """Return public summaries of a user's jobs for one story, newest first.

    Jobs are ordered by ``created_at`` descending, with ties broken by
    descending job id.
    """
    newest_first = (
        select(GenerationJob)
        .where(
            GenerationJob.story_id == story_id,
            GenerationJob.user_id == user_id,
        )
        .order_by(desc(GenerationJob.created_at), desc(GenerationJob.id))
    )
    result = await db.execute(newest_first)
    return [public_generation_job_to_summary(row_job) for row_job in result.scalars().all()]
async def get_active_story_generation_job(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> GenerationJob | None:
    """Return the most recently updated running job for a story, if any.

    Ties on ``updated_at`` are broken by the highest job id.
    """
    latest_running = (
        select(GenerationJob)
        .where(
            GenerationJob.story_id == story_id,
            GenerationJob.user_id == user_id,
            GenerationJob.status == "running",
        )
        .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
        .limit(1)
    )
    return (await db.execute(latest_running)).scalar_one_or_none()
def _active_job_conflict_detail(progress_label: str) -> str:
    """Build the 409 message shown when a story already has a running job."""
    return f"当前故事已有运行中的任务({progress_label}),请等待当前任务完成后再试。"


async def ensure_no_active_story_generation_job(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> None:
    """Prevent duplicate asset work while a story already has a running job.

    Raises:
        HTTPException: 409 whose detail names the active job's progress
            label, when a running job exists for this user's story.
    """
    active_job = await get_active_story_generation_job(db, story_id=story_id, user_id=user_id)
    if active_job is None:
        return
    progress = _job_progress(active_job)
    raise HTTPException(
        status_code=409,
        # The previous message opened "(" without closing it and ran the
        # label straight into the next sentence.
        detail=_active_job_conflict_detail(progress["progress_label"]),
    )
def _as_float(value: Any) -> float | None:
if isinstance(value, int | float):
return float(value)
return None
def _sorted_buckets(counts: dict[str, int]) -> list[dict[str, Any]]:
    """Turn a name-to-count mapping into bucket dicts, most frequent first.

    Ties are ordered alphabetically by name.
    """
    ranked_names = sorted(counts, key=lambda name: (-counts[name], name))
    return [{"name": name, "count": counts[name]} for name in ranked_names]
def _aggregate_trace_events(events: list[GenerationJobEvent]) -> dict[str, Any]:
    """Tally workflow trace metadata over a sequence of job events.

    Internal ``evaluation_completed``/``executor_completed`` events are
    skipped. For the remaining events this counts occurrences per non-empty
    ``step`` and per non-empty ``artifact`` (ignoring ``"none"``). For events
    whose status is ``failed`` it also counts occurrences per
    ``failure_category``, using ``unknown_error`` when that is missing or
    blank.
    """
    internal_types = ("evaluation_completed", "executor_completed")
    step_counts: dict[str, int] = {}
    artifact_counts: dict[str, int] = {}
    category_counts: dict[str, int] = {}
    considered = 0
    failures = 0

    def bump(counter: dict[str, int], name: str) -> None:
        counter[name] = counter.get(name, 0) + 1

    for event in events:
        if event.event_type in internal_types:
            continue
        considered += 1
        meta = event.event_metadata or {}
        step_name = meta.get("step")
        artifact_name = meta.get("artifact")
        category = meta.get("failure_category")
        if isinstance(step_name, str) and step_name:
            bump(step_counts, step_name)
        if isinstance(artifact_name, str) and artifact_name and artifact_name != "none":
            bump(artifact_counts, artifact_name)
        if event.status != "failed":
            continue
        failures += 1
        if not (isinstance(category, str) and category):
            category = "unknown_error"
        bump(category_counts, category)

    return {
        "total_events": considered,
        "failed_events": failures,
        "by_step": _sorted_buckets(step_counts),
        "by_artifact": _sorted_buckets(artifact_counts),
        "failure_categories": _sorted_buckets(category_counts),
    }
def _aggregate_provider_events(
    events: list[GenerationJobEvent],
    *,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider telemetry from provider call events.

    Events are grouped per ``(capability, adapter)`` read from their metadata,
    with missing values becoming ``"unknown"``. When ``capability`` is given,
    events for other capabilities are skipped. An event counts as a success
    only if its type is ``provider_call_succeeded``. Every other event counts
    as a failure, and its metadata ``error`` (or ``"unknown_error"``) is
    tallied as a failure reason. Latency is averaged over events whose
    ``latency_ms`` is numeric (per ``_as_float``). Estimated cost is summed
    from successful calls only.

    Returns:
        Overall totals, per-provider buckets sorted by capability then
        adapter, and failure reasons sorted by descending count then reason.
    """
    by_key: dict[tuple[str, str], dict[str, Any]] = {}
    failure_reasons: dict[str, int] = {}
    total_latency = 0.0
    latency_count = 0
    total_cost = 0.0
    successful_calls = 0
    failed_calls = 0
    for event in events:
        metadata = event.event_metadata or {}
        event_capability = str(metadata.get("capability") or "unknown")
        if capability is not None and event_capability != capability:
            continue
        adapter = str(metadata.get("adapter") or "unknown")
        key = (event_capability, adapter)
        # One bucket per (capability, adapter). latency_total/latency_count are
        # private accumulators that are removed again before returning.
        bucket = by_key.setdefault(
            key,
            {
                "capability": event_capability,
                "adapter": adapter,
                "call_count": 0,
                "success_count": 0,
                "failure_count": 0,
                "latency_total": 0.0,
                "latency_count": 0,
                "estimated_cost_usd": 0.0,
            },
        )
        bucket["call_count"] += 1
        latency = _as_float(metadata.get("latency_ms"))
        if latency is not None:
            bucket["latency_total"] += latency
            bucket["latency_count"] += 1
            total_latency += latency
            latency_count += 1
        if event.event_type == "provider_call_succeeded":
            bucket["success_count"] += 1
            successful_calls += 1
            cost = _as_float(metadata.get("estimated_cost_usd")) or 0.0
            bucket["estimated_cost_usd"] += cost
            total_cost += cost
        else:
            # Any non-success event type is treated as a failed call.
            bucket["failure_count"] += 1
            failed_calls += 1
            reason = str(metadata.get("error") or "unknown_error")
            failure_reasons[reason] = failure_reasons.get(reason, 0) + 1
    by_provider = []
    for bucket in by_key.values():
        # Drop the private latency accumulators and expose an average instead
        # (None when no numeric latency was reported for this bucket).
        bucket_latency_count = bucket.pop("latency_count")
        bucket_latency_total = bucket.pop("latency_total")
        by_provider.append(
            {
                **bucket,
                "avg_latency_ms": (
                    round(bucket_latency_total / bucket_latency_count, 2)
                    if bucket_latency_count
                    else None
                ),
                "estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
            }
        )
    by_provider.sort(
        key=lambda item: (
            str(item["capability"]),
            str(item["adapter"]),
        )
    )
    return {
        "total_calls": successful_calls + failed_calls,
        "successful_calls": successful_calls,
        "failed_calls": failed_calls,
        "avg_latency_ms": round(total_latency / latency_count, 2) if latency_count else None,
        "estimated_cost_usd": round(total_cost, 6),
        "by_provider": by_provider,
        "failure_reasons": [
            {"reason": reason, "count": count}
            for reason, count in sorted(
                failure_reasons.items(),
                key=lambda item: (-item[1], item[0]),
            )
        ],
    }
def _event_matches_capability(
    event: GenerationJobEvent,
    capability: str | None = None,
) -> bool:
    """Return True when ``capability`` is None or equals the event's capability.

    Events without a capability in their metadata are treated as ``"unknown"``.
    """
    tagged_capability = str((event.event_metadata or {}).get("capability") or "unknown")
    if capability is None:
        return True
    return tagged_capability == capability
def _provider_events_query(
    *,
    user_id: str | None = None,
    story_id: int | None = None,
    days: int | None = None,
):
    """Build a SELECT of provider-call events joined to their owning job.

    Each row is ``(GenerationJobEvent, GenerationJob.user_id,
    GenerationJob.story_id)``. Rows are limited to ``provider_call_succeeded``
    and ``provider_call_failed`` events and ordered by event id. The optional
    filters restrict by job owner, by job story, and to events created within
    the last ``days`` days.
    """
    provider_event_types = ["provider_call_succeeded", "provider_call_failed"]
    conditions = [GenerationJobEvent.event_type.in_(provider_event_types)]
    if user_id is not None:
        conditions.append(GenerationJob.user_id == user_id)
    if story_id is not None:
        conditions.append(GenerationJob.story_id == story_id)
    if days is not None:
        window_start = datetime.now(timezone.utc) - timedelta(days=days)
        conditions.append(GenerationJobEvent.created_at >= window_start)
    return (
        select(
            GenerationJobEvent,
            GenerationJob.user_id,
            GenerationJob.story_id,
        )
        .join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
        .where(*conditions)
        .order_by(GenerationJobEvent.id)
    )
async def get_story_provider_stats(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider call telemetry from all user-owned jobs for one story.

    ``days`` limits the window to recent events, and ``capability`` restricts
    aggregation to one provider capability. Both echo back in the result.
    """
    stats_query = _provider_events_query(
        user_id=user_id,
        story_id=story_id,
        days=days,
    )
    provider_events = (await db.execute(stats_query)).scalars().all()
    aggregated = _aggregate_provider_events(provider_events, capability=capability)
    return {
        "story_id": story_id,
        "window_days": days,
        "capability": capability,
        **aggregated,
    }
async def get_story_trace_summary(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
    days: int | None = None,
) -> dict[str, Any]:
    """Aggregate workflow trace metadata from all user-owned jobs for one story.

    Events are read in id order, optionally restricted to the last ``days``
    days.
    """
    conditions = [
        GenerationJob.story_id == story_id,
        GenerationJob.user_id == user_id,
    ]
    if days is not None:
        window_start = datetime.now(timezone.utc) - timedelta(days=days)
        conditions.append(GenerationJobEvent.created_at >= window_start)
    trace_query = (
        select(GenerationJobEvent)
        .join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
        .where(*conditions)
        .order_by(GenerationJobEvent.id)
    )
    trace_events = (await db.execute(trace_query)).scalars().all()
    return {
        "story_id": story_id,
        "window_days": days,
        **_aggregate_trace_events(trace_events),
    }
async def get_user_provider_analytics(
    db: AsyncSession,
    *,
    user_id: str,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider telemetry across all stories owned by one user.

    ``job_count`` and ``story_count`` count the distinct jobs and stories that
    produced at least one provider event matching ``capability``. The story
    for each event comes from its owning job, which the query selects
    alongside the event. It falls back to the event's own ``story_id``. An
    event only copies the job's story id at record time, so events recorded
    before the job was linked to a story carry ``story_id`` None and would
    otherwise be missed.
    """
    rows = (
        await db.execute(
            _provider_events_query(
                user_id=user_id,
                days=days,
            )
        )
    ).all()
    # Each row is (GenerationJobEvent, GenerationJob.user_id, GenerationJob.story_id).
    events = [row[0] for row in rows]
    matching_job_ids: set[Any] = set()
    matching_story_ids: set[Any] = set()
    for event, _job_user_id, job_story_id in rows:
        if not _event_matches_capability(event, capability):
            continue
        matching_job_ids.add(event.job_id)
        resolved_story_id = job_story_id if job_story_id is not None else event.story_id
        if resolved_story_id is not None:
            matching_story_ids.add(resolved_story_id)
    return {
        "window_days": days,
        "capability": capability,
        **_aggregate_provider_events(events, capability=capability),
        "job_count": len(matching_job_ids),
        "story_count": len(matching_story_ids),
    }
async def get_user_generation_ops_summary(
    db: AsyncSession,
    *,
    user_id: str,
    hours: int = 24,
    recent_failure_limit: int = 5,
) -> dict[str, Any]:
    """Summarize recent generation health for one user.

    Reports:

    * every currently running job, and how many of them are stale per
      ``settings.generation_job_stale_minutes``;
    * for jobs updated in the last ``hours`` hours, the failed and degraded
      counts and the asset-job count (``asset_retry`` or
      ``asset_generation``);
    * up to ``recent_failure_limit`` of the most recently updated failed jobs,
      with their story title.
    """
    stale_after_minutes = settings.generation_job_stale_minutes
    window_start = datetime.now(timezone.utc) - timedelta(hours=hours)

    running_query = (
        select(GenerationJob)
        .where(
            GenerationJob.user_id == user_id,
            GenerationJob.status == "running",
        )
        .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
    )
    running_jobs = (await db.execute(running_query)).scalars().all()

    recent_query = (
        select(GenerationJob, Story.title)
        .outerjoin(Story, Story.id == GenerationJob.story_id)
        .where(
            GenerationJob.user_id == user_id,
            GenerationJob.updated_at >= window_start,
        )
        .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
    )
    recent_rows = (await db.execute(recent_query)).all()

    asset_modes = ("asset_retry", "asset_generation")
    failure_samples: list[dict[str, Any]] = []
    failed_total = 0
    degraded_total = 0
    asset_total = 0
    for recent_job, title in recent_rows:
        if recent_job.output_mode in asset_modes:
            asset_total += 1
        if recent_job.status == "degraded_completed":
            degraded_total += 1
            continue
        if recent_job.status != "failed":
            continue
        failed_total += 1
        if len(failure_samples) >= recent_failure_limit:
            continue
        failure_samples.append(
            {
                "job_id": recent_job.id,
                "story_id": recent_job.story_id,
                "story_title": title,
                "output_mode": recent_job.output_mode,
                "current_step": recent_job.current_step,
                "error_message": recent_job.error_message,
                "failure_label": _failure_label(recent_job),
                "updated_at": recent_job.updated_at,
            }
        )

    return {
        "window_hours": hours,
        "stale_threshold_minutes": stale_after_minutes,
        "active_jobs": len(running_jobs),
        "stale_running_jobs": len(
            [
                running
                for running in running_jobs
                if _is_stale_job(running, stale_after_minutes=stale_after_minutes)
            ]
        ),
        "failed_jobs": failed_total,
        "degraded_jobs": degraded_total,
        "asset_retry_jobs": asset_total,
        "recent_failures": failure_samples,
    }
async def mark_stale_generation_jobs(
    db: AsyncSession,
    *,
    stale_after_minutes: int | None = None,
) -> dict[str, int]:
    """Mark long-running generation jobs as failed so they no longer appear stuck forever.

    Args:
        stale_after_minutes: Threshold in minutes. ``None`` falls back to
            ``settings.generation_job_stale_minutes``. An explicit ``0`` is
            honored rather than being silently replaced by the default, as the
            previous ``or`` fallback did. With ``0``, every running job whose
            ``updated_at`` is not in the future counts as stale.

    Returns:
        ``running`` (running jobs inspected), ``marked_stale`` (jobs failed by
        this call) and the ``stale_after_minutes`` threshold that was applied.
    """
    threshold = (
        settings.generation_job_stale_minutes
        if stale_after_minutes is None
        else stale_after_minutes
    )
    running_jobs = (
        await db.execute(
            select(GenerationJob)
            .where(GenerationJob.status == "running")
            .order_by(GenerationJob.updated_at, GenerationJob.id)
        )
    ).scalars().all()
    marked_stale = 0
    for job in running_jobs:
        if not _is_stale_job(job, stale_after_minutes=threshold):
            continue
        story = None
        if job.story_id is not None:
            # Owner-scoped story lookup so the final snapshot reflects its state.
            story = (
                await db.execute(
                    select(Story).where(
                        Story.id == job.story_id,
                        Story.user_id == job.user_id,
                    )
                )
            ).scalar_one_or_none()
        # finish_generation_job commits once per stale job.
        await finish_generation_job(
            db,
            job=job,
            story=story,
            status="failed",
            current_step="generation_stale_failed",
            error_message=f"Generation job exceeded {threshold} minutes without progress.",
            message="Generation job was marked failed after exceeding the stale threshold.",
            metadata={"stale_after_minutes": threshold},
        )
        marked_stale += 1
        logger.warning(
            "generation_job_marked_stale",
            job_id=job.id,
            story_id=job.story_id,
            output_mode=job.output_mode,
            stale_after_minutes=threshold,
        )
    return {
        "running": len(running_jobs),
        "marked_stale": marked_stale,
        "stale_after_minutes": threshold,
    }