Files
dreamweaver/backend/app/services/generation_jobs.py

839 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Lightweight generation job/event tracking."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any
from fastapi import HTTPException
from sqlalchemy import desc, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import GenerationJob, GenerationJobEvent, Story
# Module-level logger. mark_stale_generation_jobs reports each job it fails through it,
# passing keyword fields (job_id, story_id, output_mode, stale_after_minutes).
logger = get_logger(__name__)
def _is_terminal_status(status: str) -> bool:
return status in {"completed", "degraded_completed", "failed", "canceled"}
def _job_supports_queue_control(job: GenerationJob) -> bool:
return job.output_mode in {"story", "storybook"}
def generation_job_can_cancel(job: GenerationJob) -> bool:
    """Return True if the user may request cancellation of *job* right now.

    The job must support queue control, still be running, and not already have
    a pending cancellation request.
    """
    if not _job_supports_queue_control(job):
        return False
    if job.status != "running":
        return False
    return job.current_step != "cancel_requested"
def generation_job_can_retry(job: GenerationJob) -> bool:
    """Return True if *job* supports queue control and ended as failed or canceled."""
    if not _job_supports_queue_control(job):
        return False
    return job.status in {"failed", "canceled"}
def _story_snapshot(story: Story | None) -> dict[str, Any]:
if story is None:
return {}
return {
"story_id": story.id,
"mode": story.mode,
"generation_status": story.generation_status,
"text_status": story.text_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
"retryable_assets": story.retryable_assets,
"last_error": story.last_error,
}
def _job_status_from_story(story: Story) -> str:
if story.generation_status == "failed":
return "failed"
if story.generation_status == "degraded_completed":
return "degraded_completed"
return "completed"
def _job_progress(job: GenerationJob) -> dict[str, Any]:
"""Resolve a compact progress summary for polling-oriented clients."""
if job.status == "failed":
return {
"progress_percent": 100,
"progress_label": "生成失败",
"is_terminal": True,
}
if job.status == "canceled":
return {
"progress_percent": 100,
"progress_label": "已取消",
"is_terminal": True,
}
if job.status in {"completed", "degraded_completed"}:
return {
"progress_percent": 100,
"progress_label": "已完成" if job.status == "completed" else "降级完成",
"is_terminal": True,
}
progress_map: dict[str, tuple[int, str]] = {
"request_accepted": (5, "已接收请求"),
"retry_queued": (8, "重新排队中"),
"worker_started": (12, "后台任务已开始"),
"cancel_requested": (15, "已请求取消"),
"context_prepared": (20, "上下文已准备"),
"narrative_generated": (45, "正文已生成"),
"story_saved": (60, "主记录已保存"),
"provider_call_started": (65, "Provider 调用中"),
"provider_call_succeeded": (72, "Provider 调用成功"),
"provider_call_failed": (72, "Provider 调用失败,尝试恢复"),
"cover_image_started": (75, "封面生成中"),
"cover_image_succeeded": (88, "封面已生成"),
"cover_image_failed": (88, "封面生成失败"),
"storybook_images_started": (75, "绘本插图生成中"),
"storybook_cover_image_succeeded": (82, "绘本封面已生成"),
"storybook_cover_image_failed": (82, "绘本封面生成失败"),
"storybook_page_image_succeeded": (86, "分页插图已生成"),
"storybook_page_image_failed": (86, "分页插图生成失败"),
"storybook_images_completed": (92, "绘本插图已完成"),
"audio_started": (75, "音频生成中"),
"audio_cache_hit": (88, "音频缓存已复用"),
"audio_succeeded": (88, "音频已生成"),
"audio_failed": (88, "音频生成失败"),
"asset_retry_started": (25, "资源重试中"),
"postprocessing_queued": (90, "后处理已排队"),
"asset_generation_completed": (100, "资源已完成"),
"asset_retry_completed": (100, "资源重试完成"),
"generation_canceled": (100, "任务已取消"),
"generation_completed": (100, "生成完成"),
"generation_stale_failed": (100, "任务超时已收敛"),
}
percent, label = progress_map.get(job.current_step, (10, "生成处理中"))
return {
"progress_percent": percent,
"progress_label": label,
"is_terminal": percent >= 100,
}
def _normalize_datetime(value: datetime) -> datetime:
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
def _is_stale_job(job: GenerationJob, *, stale_after_minutes: int) -> bool:
    """Return True if *job* is running and was last updated at least *stale_after_minutes* ago."""
    cutoff = datetime.now(timezone.utc) - timedelta(minutes=stale_after_minutes)
    if job.status != "running":
        return False
    last_touched = _normalize_datetime(job.updated_at)
    return last_touched <= cutoff
def _failure_label(job: GenerationJob) -> str:
if job.status == "canceled":
return "任务已取消"
if job.current_step == "generation_stale_failed":
return "任务超时"
if job.output_mode == "asset_retry":
return "资源重试失败"
if job.output_mode == "asset_generation":
return "资源生成失败"
return "生成失败"
async def create_generation_job(
    db: AsyncSession,
    *,
    user_id: str,
    output_mode: str,
    input_type: str,
    request_payload: dict[str, Any],
    story_id: int | None = None,
) -> GenerationJob:
    """Persist a new running job together with its initial ``request_accepted`` event.

    The job and its first event are committed in a single commit, and the job is
    refreshed before it is returned.
    """
    new_job = GenerationJob(
        user_id=user_id,
        story_id=story_id,
        output_mode=output_mode,
        input_type=input_type,
        status="running",
        current_step="request_accepted",
        request_payload=request_payload,
        result_snapshot={},
    )
    db.add(new_job)
    # Flush before recording the event, which reads new_job.id
    # (presumably assigned by a column default at flush time).
    await db.flush()

    accepted_metadata = {"output_mode": output_mode, "input_type": input_type}
    await record_generation_event(
        db,
        job=new_job,
        story_id=story_id,
        event_type="request_accepted",
        status="succeeded",
        message="Generation request accepted.",
        metadata=accepted_metadata,
        commit=False,
    )
    await db.commit()
    await db.refresh(new_job)
    return new_job
async def record_generation_event(
    db: AsyncSession,
    *,
    job: GenerationJob,
    event_type: str,
    status: str,
    story_id: int | None = None,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
    commit: bool = True,
) -> GenerationJobEvent:
    """Append an event row for *job* and set ``job.current_step`` to *event_type*.

    When *story_id* is given it is also written onto the job. Otherwise the event
    inherits the job's current ``story_id``. The session is committed only when
    *commit* is True, so callers can batch this with other changes.

    Returns the (possibly uncommitted) event object.
    """
    job.current_step = event_type
    if story_id is None:
        linked_story_id = job.story_id
    else:
        job.story_id = story_id
        linked_story_id = story_id

    new_event = GenerationJobEvent(
        job_id=job.id,
        story_id=linked_story_id,
        event_type=event_type,
        status=status,
        message=message,
        event_metadata=metadata or {},
    )
    db.add(new_event)
    if commit:
        await db.commit()
    return new_event
async def claim_generation_job_for_worker(
    db: AsyncSession,
    *,
    job_id: str,
) -> GenerationJob | None:
    """Claim a queued job for worker execution, at most once.

    The claim is a single conditional UPDATE that matches only running jobs still
    at ``request_accepted`` or ``retry_queued`` and moves them to
    ``worker_started``. If no row matched (already claimed, no longer running, or
    missing), None is returned. On success the job is reloaded, a
    ``worker_started`` event is recorded and committed, and the job is returned.
    """
    claimable_steps = ["request_accepted", "retry_queued"]
    claim_statement = (
        update(GenerationJob)
        .where(
            GenerationJob.id == job_id,
            GenerationJob.status == "running",
            GenerationJob.current_step.in_(claimable_steps),
        )
        .values(current_step="worker_started")
    )
    claim_outcome = await db.execute(claim_statement)
    await db.commit()
    if not claim_outcome.rowcount:
        # Another worker (or a cancellation) got there first.
        return None

    claimed_job = (
        await db.execute(select(GenerationJob).where(GenerationJob.id == job_id))
    ).scalar_one_or_none()
    if claimed_job is None:
        return None

    await record_generation_event(
        db,
        job=claimed_job,
        event_type="worker_started",
        status="running",
        message="Generation worker started processing the accepted request.",
    )
    return claimed_job
async def finish_generation_job(
    db: AsyncSession,
    *,
    job: GenerationJob,
    story: Story | None,
    status: str | None = None,
    current_step: str,
    error_message: str | None = None,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> GenerationJob:
    """Close out *job* with a final status and step, and append a matching event.

    If *status* is falsy it is derived from *story*, or set to "failed" when there
    is no story. The job's result snapshot is rebuilt from *story* and embedded
    in the event metadata. Commits, refreshes and returns the job.
    """
    if story is not None:
        job.story_id = story.id
        final_status = status or _job_status_from_story(story)
    else:
        final_status = status or "failed"

    job.status = final_status
    job.current_step = current_step
    job.error_message = error_message
    job.result_snapshot = _story_snapshot(story)

    final_metadata = {**(metadata or {}), "result_snapshot": job.result_snapshot}
    await record_generation_event(
        db,
        job=job,
        story_id=job.story_id,
        event_type=current_step,
        status=job.status,
        message=message,
        metadata=final_metadata,
        commit=False,
    )
    await db.commit()
    await db.refresh(job)
    return job
def generation_event_to_response(event: GenerationJobEvent) -> dict[str, Any]:
    """Serialize a generation event row into the API response dict.

    A missing ``event_metadata`` is returned as an empty dict.
    """
    response: dict[str, Any] = {}
    for attribute in ("id", "job_id", "story_id", "event_type", "status", "message"):
        response[attribute] = getattr(event, attribute)
    response["event_metadata"] = event.event_metadata or {}
    response["created_at"] = event.created_at
    return response
def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
    """Serialize a job row into the API summary dict.

    Includes identity/mode/status fields, the progress fields from
    ``_job_progress``, the cancel/retry capability flags, the result snapshot
    (empty dict if missing), the error message and timestamps.
    """
    progress = _job_progress(job)
    summary: dict[str, Any] = {
        field_name: getattr(job, field_name)
        for field_name in ("id", "story_id", "output_mode", "input_type", "status", "current_step")
    }
    summary.update(progress)
    summary["can_cancel"] = generation_job_can_cancel(job)
    summary["can_retry"] = generation_job_can_retry(job)
    summary["result_snapshot"] = job.result_snapshot or {}
    summary["error_message"] = job.error_message
    summary["created_at"] = job.created_at
    summary["updated_at"] = job.updated_at
    return summary
async def get_generation_job_for_user(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> GenerationJob:
    """Fetch one generation job by id, restricted to jobs owned by *user_id*.

    Raises:
        HTTPException: 404 when no job with that id belongs to the user.
    """
    owned_job_query = select(GenerationJob).where(
        GenerationJob.id == job_id,
        GenerationJob.user_id == user_id,
    )
    owned_job = (await db.execute(owned_job_query)).scalar_one_or_none()
    if owned_job is None:
        raise HTTPException(status_code=404, detail="Generation job not found")
    return owned_job
async def request_generation_job_cancel(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> dict[str, Any]:
    """Cancel a queued job immediately, or flag a running job for cancellation.

    Decision order:
    - output mode without queue control -> 409;
    - already canceled -> summary returned unchanged;
    - any other terminal status -> 409;
    - cancellation already requested -> summary returned unchanged;
    - not yet claimed by a worker (``request_accepted``/``retry_queued``) ->
      finished right away with status ``canceled``;
    - otherwise a ``cancel_requested`` event is recorded and the job stays
      running (the worker-side checkpoint handling is not visible here).

    Returns the job summary dict. Raises HTTPException (404 via the ownership
    lookup, or 409 as above).
    """
    job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)
    if not _job_supports_queue_control(job):
        raise HTTPException(status_code=409, detail="当前任务不支持取消")
    if job.status == "canceled":
        return generation_job_to_summary(job)
    if _is_terminal_status(job.status):
        raise HTTPException(status_code=409, detail="当前任务已终止,无法取消")
    if job.current_step == "cancel_requested":
        return generation_job_to_summary(job)

    if job.current_step not in {"request_accepted", "retry_queued"}:
        # A worker already picked the job up: record the request and let the
        # worker stop cooperatively.
        requested_from_step = job.current_step
        job.error_message = "Cancellation requested by user."
        await record_generation_event(
            db,
            job=job,
            story_id=job.story_id,
            event_type="cancel_requested",
            status="running",
            message="Cancellation requested; worker will stop at the next safe checkpoint.",
            metadata={"requested_from_step": requested_from_step},
            commit=False,
        )
        await db.commit()
        await db.refresh(job)
        return generation_job_to_summary(job)

    # No worker has claimed the job yet, so it can be finished as canceled now.
    linked_story = None
    if job.story_id is not None:
        story_query = select(Story).where(
            Story.id == job.story_id,
            Story.user_id == job.user_id,
        )
        linked_story = (await db.execute(story_query)).scalar_one_or_none()
    await finish_generation_job(
        db,
        job=job,
        story=linked_story,
        status="canceled",
        current_step="generation_canceled",
        error_message="Generation canceled by user before worker execution started.",
        message="Generation job was canceled before worker execution started.",
    )
    return generation_job_to_summary(job)
async def get_generation_job_detail(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> dict[str, Any]:
    """Return a user-owned job summary plus its request payload and event stream.

    Events are ordered by ascending event id.

    Raises:
        HTTPException: 404 when no job with that id belongs to *user_id*.
    """
    # Reuse the shared ownership lookup so the 404 contract lives in one place
    # instead of being duplicated here.
    job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)
    events = (
        await db.execute(
            select(GenerationJobEvent)
            .where(GenerationJobEvent.job_id == job.id)
            .order_by(GenerationJobEvent.id)
        )
    ).scalars().all()
    return {
        **generation_job_to_summary(job),
        "request_payload": job.request_payload or {},
        "events": [generation_event_to_response(event) for event in events],
    }
async def list_story_generation_jobs(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> list[dict[str, Any]]:
    """Return summaries of the user's jobs for *story_id*, newest first.

    Ordered by ``created_at`` descending with ``id`` as tie-breaker. No row limit
    is applied.
    """
    newest_first = (desc(GenerationJob.created_at), desc(GenerationJob.id))
    story_jobs_query = (
        select(GenerationJob)
        .where(
            GenerationJob.story_id == story_id,
            GenerationJob.user_id == user_id,
        )
        .order_by(*newest_first)
    )
    story_jobs = (await db.execute(story_jobs_query)).scalars().all()
    return list(map(generation_job_to_summary, story_jobs))
async def get_active_story_generation_job(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> GenerationJob | None:
    """Return the user's most recently updated running job for *story_id*, or None."""
    latest_running_query = (
        select(GenerationJob)
        .where(
            GenerationJob.story_id == story_id,
            GenerationJob.user_id == user_id,
            GenerationJob.status == "running",
        )
        .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
        .limit(1)
    )
    lookup = await db.execute(latest_running_query)
    return lookup.scalar_one_or_none()
async def ensure_no_active_story_generation_job(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> None:
    """Prevent duplicate asset work while the story already has a running job.

    Returns normally when no running job exists for the story.

    Raises:
        HTTPException: 409 whose detail names the active job's progress label.
    """
    active_job = await get_active_story_generation_job(db, story_id=story_id, user_id=user_id)
    if active_job is None:
        return
    progress_label = _job_progress(active_job)["progress_label"]
    raise HTTPException(
        status_code=409,
        # Close the parenthesised progress label and separate it from the
        # follow-up sentence (both were missing from the message).
        detail=f"当前故事已有运行中的任务({progress_label}),请等待当前任务完成后再试。",
    )
def _as_float(value: Any) -> float | None:
if isinstance(value, int | float):
return float(value)
return None
def _aggregate_provider_events(
    events: list[GenerationJobEvent],
    *,
    capability: str | None = None,
) -> dict[str, Any]:
    """Summarise provider-call events into totals, per-provider rows and failure reasons.

    Each event's ``event_metadata`` is read for ``capability``, ``adapter``,
    ``latency_ms``, ``estimated_cost_usd`` and ``error``. Missing capability or
    adapter values are reported as "unknown". When *capability* is given, events
    for other capabilities are skipped.

    Events of type ``provider_call_succeeded`` count as successes and contribute
    cost. Every other event counts as a failure, and its ``error`` (default
    "unknown_error") is tallied. Latency is averaged only over events that carry
    a numeric value.

    Returns:
        ``total_calls``, ``successful_calls``, ``failed_calls``,
        ``avg_latency_ms`` (None without samples), ``estimated_cost_usd``,
        ``by_provider`` (sorted by capability, then adapter) and
        ``failure_reasons`` (sorted by descending count, then reason).
    """
    per_provider: dict[tuple[str, str], dict[str, Any]] = {}
    reason_tally: dict[str, int] = {}
    latency_sum = 0.0
    latency_samples = 0
    cost_sum = 0.0
    ok_calls = 0
    bad_calls = 0

    for event in events:
        meta = event.event_metadata or {}
        event_cap = str(meta.get("capability") or "unknown")
        if capability is not None and event_cap != capability:
            continue
        adapter_name = str(meta.get("adapter") or "unknown")
        slot_key = (event_cap, adapter_name)
        slot = per_provider.get(slot_key)
        if slot is None:
            slot = {
                "capability": event_cap,
                "adapter": adapter_name,
                "call_count": 0,
                "success_count": 0,
                "failure_count": 0,
                "latency_total": 0.0,
                "latency_count": 0,
                "estimated_cost_usd": 0.0,
            }
            per_provider[slot_key] = slot
        slot["call_count"] += 1

        latency_ms = _as_float(meta.get("latency_ms"))
        if latency_ms is not None:
            slot["latency_total"] += latency_ms
            slot["latency_count"] += 1
            latency_sum += latency_ms
            latency_samples += 1

        if event.event_type != "provider_call_succeeded":
            slot["failure_count"] += 1
            bad_calls += 1
            reason = str(meta.get("error") or "unknown_error")
            reason_tally[reason] = reason_tally.get(reason, 0) + 1
            continue

        slot["success_count"] += 1
        ok_calls += 1
        call_cost = _as_float(meta.get("estimated_cost_usd")) or 0.0
        slot["estimated_cost_usd"] += call_cost
        cost_sum += call_cost

    provider_rows = []
    for slot in per_provider.values():
        samples = slot["latency_count"]
        provider_rows.append(
            {
                "capability": slot["capability"],
                "adapter": slot["adapter"],
                "call_count": slot["call_count"],
                "success_count": slot["success_count"],
                "failure_count": slot["failure_count"],
                "estimated_cost_usd": round(slot["estimated_cost_usd"], 6),
                "avg_latency_ms": (
                    round(slot["latency_total"] / samples, 2) if samples else None
                ),
            }
        )
    provider_rows.sort(key=lambda row: (str(row["capability"]), str(row["adapter"])))

    ranked_reasons = sorted(reason_tally.items(), key=lambda pair: (-pair[1], pair[0]))
    return {
        "total_calls": ok_calls + bad_calls,
        "successful_calls": ok_calls,
        "failed_calls": bad_calls,
        "avg_latency_ms": round(latency_sum / latency_samples, 2) if latency_samples else None,
        "estimated_cost_usd": round(cost_sum, 6),
        "by_provider": provider_rows,
        "failure_reasons": [
            {"reason": reason, "count": count} for reason, count in ranked_reasons
        ],
    }
def _provider_events_query(
    *,
    user_id: str,
    story_id: int | None = None,
    days: int | None = None,
):
    """Build a SELECT of the user's provider-call events, ordered by event id.

    Only ``provider_call_succeeded`` / ``provider_call_failed`` events are
    selected. Results are optionally narrowed to jobs for *story_id* and to
    events created within the last *days* days.
    """
    criteria = [
        GenerationJob.user_id == user_id,
        GenerationJobEvent.event_type.in_(
            ["provider_call_succeeded", "provider_call_failed"]
        ),
    ]
    if story_id is not None:
        criteria.append(GenerationJob.story_id == story_id)
    if days is not None:
        window_start = datetime.now(timezone.utc) - timedelta(days=days)
        criteria.append(GenerationJobEvent.created_at >= window_start)
    return (
        select(GenerationJobEvent)
        .join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
        .where(*criteria)
        .order_by(GenerationJobEvent.id)
    )
async def get_story_provider_stats(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider-call telemetry across the user's jobs for one story.

    The result echoes the filters (``story_id``, ``window_days``, ``capability``)
    followed by the fields from ``_aggregate_provider_events``.
    """
    statement = _provider_events_query(user_id=user_id, story_id=story_id, days=days)
    story_events = (await db.execute(statement)).scalars().all()
    aggregate = _aggregate_provider_events(story_events, capability=capability)
    return {
        "story_id": story_id,
        "window_days": days,
        "capability": capability,
        **aggregate,
    }
async def get_user_provider_analytics(
    db: AsyncSession,
    *,
    user_id: str,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider-call telemetry across all of the user's jobs.

    In addition to the ``_aggregate_provider_events`` fields, reports how many
    distinct jobs (``job_count``) and stories (``story_count``, ignoring events
    without a story) contributed events matching *capability*.
    """
    statement = _provider_events_query(user_id=user_id, days=days)
    user_events = (await db.execute(statement)).scalars().all()

    matching_events = [
        event
        for event in user_events
        if capability is None
        or str((event.event_metadata or {}).get("capability") or "unknown") == capability
    ]
    contributing_jobs = {event.job_id for event in matching_events}
    contributing_stories = {
        event.story_id for event in matching_events if event.story_id is not None
    }

    return {
        "window_days": days,
        "capability": capability,
        **_aggregate_provider_events(user_events, capability=capability),
        "job_count": len(contributing_jobs),
        "story_count": len(contributing_stories),
    }
async def get_user_generation_ops_summary(
    db: AsyncSession,
    *,
    user_id: str,
    hours: int = 24,
    recent_failure_limit: int = 5,
) -> dict[str, Any]:
    """Summarize recent generation health for one user.

    ``active_jobs`` / ``stale_running_jobs`` cover all of the user's running jobs,
    with staleness judged against ``settings.generation_job_stale_minutes``. The
    failed / degraded / asset counts cover jobs updated within the last *hours*
    hours; ``asset_retry_jobs`` counts both ``asset_retry`` and
    ``asset_generation`` output modes. ``recent_failures`` lists up to
    *recent_failure_limit* failed jobs from that window, newest first.
    """
    stale_minutes = settings.generation_job_stale_minutes
    window_start = datetime.now(timezone.utc) - timedelta(hours=hours)
    newest_first = (desc(GenerationJob.updated_at), desc(GenerationJob.id))

    running_query = (
        select(GenerationJob)
        .where(
            GenerationJob.user_id == user_id,
            GenerationJob.status == "running",
        )
        .order_by(*newest_first)
    )
    running = (await db.execute(running_query)).scalars().all()

    window_query = (
        select(GenerationJob, Story.title)
        .outerjoin(Story, Story.id == GenerationJob.story_id)
        .where(
            GenerationJob.user_id == user_id,
            GenerationJob.updated_at >= window_start,
        )
        .order_by(*newest_first)
    )
    window_rows = (await db.execute(window_query)).all()

    failure_rows: list[dict[str, Any]] = []
    failed_count = 0
    degraded_count = 0
    asset_count = 0
    for job, story_title in window_rows:
        if job.output_mode in {"asset_retry", "asset_generation"}:
            asset_count += 1
        if job.status == "degraded_completed":
            degraded_count += 1
        elif job.status == "failed":
            failed_count += 1
            if len(failure_rows) < recent_failure_limit:
                failure_rows.append(
                    {
                        "job_id": job.id,
                        "story_id": job.story_id,
                        "story_title": story_title,
                        "output_mode": job.output_mode,
                        "current_step": job.current_step,
                        "error_message": job.error_message,
                        "failure_label": _failure_label(job),
                        "updated_at": job.updated_at,
                    }
                )

    stale_count = sum(
        1 for job in running if _is_stale_job(job, stale_after_minutes=stale_minutes)
    )
    return {
        "window_hours": hours,
        "stale_threshold_minutes": stale_minutes,
        "active_jobs": len(running),
        "stale_running_jobs": stale_count,
        "failed_jobs": failed_count,
        "degraded_jobs": degraded_count,
        "asset_retry_jobs": asset_count,
        "recent_failures": failure_rows,
    }
async def mark_stale_generation_jobs(
    db: AsyncSession,
    *,
    stale_after_minutes: int | None = None,
) -> dict[str, int]:
    """Mark long-running generation jobs as failed so they no longer appear stuck forever.

    Args:
        db: Session used to load and finish jobs. Each stale job is committed
            individually by ``finish_generation_job``.
        stale_after_minutes: Staleness threshold in minutes. ``None`` falls back
            to ``settings.generation_job_stale_minutes``. An explicit value,
            including 0, is honoured as given.

    Returns:
        ``running`` (running jobs inspected), ``marked_stale`` (jobs failed by
        this call) and the ``stale_after_minutes`` threshold actually used.
    """
    # Fall back only when no value was supplied; `or` would also discard an
    # explicit 0.
    threshold = (
        settings.generation_job_stale_minutes
        if stale_after_minutes is None
        else stale_after_minutes
    )
    running_jobs = (
        await db.execute(
            select(GenerationJob)
            .where(GenerationJob.status == "running")
            .order_by(GenerationJob.updated_at, GenerationJob.id)
        )
    ).scalars().all()
    marked_stale = 0
    for job in running_jobs:
        if not _is_stale_job(job, stale_after_minutes=threshold):
            continue
        story = None
        if job.story_id is not None:
            story = (
                await db.execute(
                    select(Story).where(
                        Story.id == job.story_id,
                        Story.user_id == job.user_id,
                    )
                )
            ).scalar_one_or_none()
        await finish_generation_job(
            db,
            job=job,
            story=story,
            status="failed",
            current_step="generation_stale_failed",
            error_message=f"Generation job exceeded {threshold} minutes without progress.",
            message="Generation job was marked failed after exceeding the stale threshold.",
            metadata={"stale_after_minutes": threshold},
        )
        marked_stale += 1
        logger.warning(
            "generation_job_marked_stale",
            job_id=job.id,
            story_id=job.story_id,
            output_mode=job.output_mode,
            stale_after_minutes=threshold,
        )
    return {
        "running": len(running_jobs),
        "marked_stale": marked_stale,
        "stale_after_minutes": threshold,
    }