"""Lightweight generation job/event tracking."""

from __future__ import annotations

from datetime import datetime, timedelta, timezone
from typing import Any

from fastapi import HTTPException
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import GenerationJob, GenerationJobEvent, Story

logger = get_logger(__name__)

# (percent, label) reported while a job is still running, keyed by the job's
# ``current_step``. Built once at import time instead of on every poll.
_STEP_PROGRESS: dict[str, tuple[int, str]] = {
    "request_accepted": (5, "已接收请求"),
    "context_prepared": (20, "上下文已准备"),
    "narrative_generated": (45, "正文已生成"),
    "story_saved": (60, "主记录已保存"),
    "provider_call_started": (65, "Provider 调用中"),
    "provider_call_succeeded": (72, "Provider 调用成功"),
    "provider_call_failed": (72, "Provider 调用失败,尝试恢复"),
    "cover_image_started": (75, "封面生成中"),
    "storybook_images_started": (75, "绘本插图生成中"),
    "audio_started": (75, "音频生成中"),
    "asset_retry_started": (25, "资源重试中"),
    "postprocessing_queued": (90, "后处理已排队"),
    "asset_generation_completed": (100, "资源已完成"),
    "asset_retry_completed": (100, "资源重试完成"),
    "generation_completed": (100, "生成完成"),
    "generation_stale_failed": (100, "任务超时已收敛"),
}

# Fallback for steps that are not listed in ``_STEP_PROGRESS``.
_UNKNOWN_STEP_PROGRESS: tuple[int, str] = (10, "生成处理中")


def _story_snapshot(story: Story | None) -> dict[str, Any]:
    """Return the status fields of *story* as a plain dict.

    An empty dict is returned when *story* is ``None``.
    """
    if story is None:
        return {}
    return {
        "story_id": story.id,
        "mode": story.mode,
        "generation_status": story.generation_status,
        "text_status": story.text_status,
        "image_status": story.image_status,
        "audio_status": story.audio_status,
        "retryable_assets": story.retryable_assets,
        "last_error": story.last_error,
    }


def _job_status_from_story(story: Story) -> str:
    """Derive a terminal job status from the story's ``generation_status``.

    Any value other than ``"failed"`` or ``"degraded_completed"`` maps to
    ``"completed"``.
    """
    if story.generation_status == "failed":
        return "failed"
    if story.generation_status == "degraded_completed":
        return "degraded_completed"
    return "completed"


def _job_progress(job: GenerationJob) -> dict[str, Any]:
    """Resolve a compact progress summary for polling-oriented clients.

    Returns:
        A dict with ``progress_percent``, ``progress_label`` and
        ``is_terminal``. The job status decides the result for finished jobs;
        otherwise the job's ``current_step`` is looked up in the step table.
    """
    if job.status == "failed":
        return {
            "progress_percent": 100,
            "progress_label": "生成失败",
            "is_terminal": True,
        }
    if job.status in {"completed", "degraded_completed"}:
        return {
            "progress_percent": 100,
            "progress_label": "已完成" if job.status == "completed" else "降级完成",
            "is_terminal": True,
        }

    percent, label = _STEP_PROGRESS.get(job.current_step, _UNKNOWN_STEP_PROGRESS)
    return {
        "progress_percent": percent,
        "progress_label": label,
        "is_terminal": percent >= 100,
    }


def _normalize_datetime(value: datetime) -> datetime:
    """Return *value* as an aware UTC datetime; naive values are taken as UTC."""
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


def _is_stale_job(job: GenerationJob, *, stale_after_minutes: int) -> bool:
    """Tell whether a running job has gone *stale_after_minutes* without an update.

    Only jobs whose status is ``"running"`` can be stale. A job with no
    ``updated_at`` value is reported as not stale, because there is no
    timestamp to compare against.
    """
    if job.status != "running" or job.updated_at is None:
        return False
    cutoff = datetime.now(timezone.utc) - timedelta(minutes=stale_after_minutes)
    return _normalize_datetime(job.updated_at) <= cutoff


def _failure_label(job: GenerationJob) -> str:
    """Pick the user-facing failure label for a failed job."""
    if job.current_step == "generation_stale_failed":
        return "任务超时"
    if job.output_mode == "asset_retry":
        return "资源重试失败"
    if job.output_mode == "asset_generation":
        return "资源生成失败"
    return "生成失败"


async def create_generation_job(
    db: AsyncSession,
    *,
    user_id: str,
    output_mode: str,
    input_type: str,
    request_payload: dict[str, Any],
    story_id: int | None = None,
) -> GenerationJob:
    """Create a generation job and record its first event.

    The job starts with status ``"running"`` at step ``"request_accepted"``.
    The job and its ``request_accepted`` event are committed together, and the
    job is refreshed before it is returned.
    """
    job = GenerationJob(
        user_id=user_id,
        story_id=story_id,
        output_mode=output_mode,
        input_type=input_type,
        status="running",
        current_step="request_accepted",
        request_payload=request_payload,
        result_snapshot={},
    )
    db.add(job)
    # Flush before creating the event, which reads ``job.id``.
    await db.flush()
    await record_generation_event(
        db,
        job=job,
        story_id=story_id,
        event_type="request_accepted",
        status="succeeded",
        message="Generation request accepted.",
        metadata={"output_mode": output_mode, "input_type": input_type},
        commit=False,
    )
    await db.commit()
    await db.refresh(job)
    return job


async def record_generation_event(
    db: AsyncSession,
    *,
    job: GenerationJob,
    event_type: str,
    status: str,
    story_id: int | None = None,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
    commit: bool = True,
) -> GenerationJobEvent:
    """Append one event to an existing generation job.

    Args:
        story_id: Story to attach the event to; falls back to ``job.story_id``
            when ``None``.
        metadata: Stored as the event's ``event_metadata``; ``None`` or an
            empty dict is stored as ``{}``.
        commit: Commit the session after adding the event. Pass ``False`` when
            the caller commits as part of a larger unit of work.
    """
    event = GenerationJobEvent(
        job_id=job.id,
        story_id=story_id if story_id is not None else job.story_id,
        event_type=event_type,
        status=status,
        message=message,
        event_metadata=metadata or {},
    )
    db.add(event)
    if commit:
        await db.commit()
    return event


async def finish_generation_job(
    db: AsyncSession,
    *,
    job: GenerationJob,
    story: Story | None,
    status: str | None = None,
    current_step: str,
    error_message: str | None = None,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> GenerationJob:
    """Mark a generation job as completed/degraded/failed and append a final event.

    Args:
        story: Story produced by the job, if any. When given, its id is stored
            on the job and its status fields become the result snapshot.
        status: Explicit final status. When falsy, the status is derived from
            *story*, or is ``"failed"`` when there is no story.
        current_step: Final step name; also used as the final event's type.
        metadata: Extra event metadata; ``result_snapshot`` is always added.
    """
    if story is not None:
        job.story_id = story.id
    job.status = status or (_job_status_from_story(story) if story is not None else "failed")
    job.current_step = current_step
    job.error_message = error_message
    job.result_snapshot = _story_snapshot(story)
    await record_generation_event(
        db,
        job=job,
        story_id=job.story_id,
        event_type=current_step,
        status=job.status,
        message=message,
        metadata={
            **(metadata or {}),
            "result_snapshot": job.result_snapshot,
        },
        commit=False,
    )
    await db.commit()
    await db.refresh(job)
    return job


def generation_event_to_response(event: GenerationJobEvent) -> dict[str, Any]:
    """Convert a generation event ORM object to an API response dict."""
    return {
        "id": event.id,
        "job_id": event.job_id,
        "story_id": event.story_id,
        "event_type": event.event_type,
        "status": event.status,
        "message": event.message,
        "event_metadata": event.event_metadata or {},
        "created_at": event.created_at,
    }


def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
    """Convert a generation job ORM object to an API summary dict.

    The summary includes the progress fields computed by ``_job_progress``.
    """
    progress = _job_progress(job)
    return {
        "id": job.id,
        "story_id": job.story_id,
        "output_mode": job.output_mode,
        "input_type": job.input_type,
        "status": job.status,
        "current_step": job.current_step,
        **progress,
        "result_snapshot": job.result_snapshot or {},
        "error_message": job.error_message,
        "created_at": job.created_at,
        "updated_at": job.updated_at,
    }


async def get_generation_job_detail(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> dict[str, Any]:
    """Return a user-owned generation job with its ordered event stream.

    Raises:
        HTTPException: 404 when no job with *job_id* belongs to *user_id*.
    """
    result = await db.execute(
        select(GenerationJob).where(
            GenerationJob.id == job_id,
            GenerationJob.user_id == user_id,
        )
    )
    job = result.scalar_one_or_none()
    if job is None:
        raise HTTPException(status_code=404, detail="Generation job not found")

    events = (
        await db.execute(
            select(GenerationJobEvent)
            .where(GenerationJobEvent.job_id == job.id)
            .order_by(GenerationJobEvent.id)
        )
    ).scalars().all()
    return {
        **generation_job_to_summary(job),
        "request_payload": job.request_payload or {},
        "events": [generation_event_to_response(event) for event in events],
    }


async def list_story_generation_jobs(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> list[dict[str, Any]]:
    """Return summaries of all of a user's generation jobs for a story, newest first."""
    jobs = (
        await db.execute(
            select(GenerationJob)
            .where(
                GenerationJob.story_id == story_id,
                GenerationJob.user_id == user_id,
            )
            .order_by(desc(GenerationJob.created_at), desc(GenerationJob.id))
        )
    ).scalars().all()
    return [generation_job_to_summary(job) for job in jobs]


async def get_active_story_generation_job(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> GenerationJob | None:
    """Return the most recently updated running job for a story, if any."""
    result = await db.execute(
        select(GenerationJob)
        .where(
            GenerationJob.story_id == story_id,
            GenerationJob.user_id == user_id,
            GenerationJob.status == "running",
        )
        .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
        .limit(1)
    )
    return result.scalar_one_or_none()


async def ensure_no_active_story_generation_job(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> None:
    """Prevent duplicate asset work while a story already has a running job.

    Raises:
        HTTPException: 409 when a running job exists for the story; the detail
            message includes that job's current progress label.
    """
    active_job = await get_active_story_generation_job(db, story_id=story_id, user_id=user_id)
    if active_job is None:
        return

    progress = _job_progress(active_job)
    raise HTTPException(
        status_code=409,
        detail=(
            f"当前故事已有运行中的任务({progress['progress_label']}),"
            "请等待当前任务完成后再试。"
        ),
    )


def _as_float(value: Any) -> float | None:
    """Return *value* as a float when it is a real number, else ``None``.

    ``bool`` is rejected explicitly: it is a subclass of ``int``, so a stray
    ``True`` in telemetry would otherwise be counted as ``1.0``.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        return float(value)
    return None


def _event_capability(event: GenerationJobEvent) -> str:
    """Return the capability recorded in the event metadata, or ``"unknown"``."""
    return str((event.event_metadata or {}).get("capability") or "unknown")


def _aggregate_provider_events(
    events: list[GenerationJobEvent],
    *,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider telemetry from provider call events.

    Args:
        events: Provider call events. Every event whose type is not
            ``"provider_call_succeeded"`` is counted as a failure.
        capability: When given, only events with this capability are counted.

    Returns:
        Overall totals, a ``by_provider`` breakdown per (capability, adapter)
        sorted by those two fields, and ``failure_reasons`` sorted by
        descending count and then by reason.
    """
    by_key: dict[tuple[str, str], dict[str, Any]] = {}
    failure_reasons: dict[str, int] = {}
    total_latency = 0.0
    latency_count = 0
    total_cost = 0.0
    successful_calls = 0
    failed_calls = 0

    for event in events:
        event_capability = _event_capability(event)
        if capability is not None and event_capability != capability:
            continue

        metadata = event.event_metadata or {}
        adapter = str(metadata.get("adapter") or "unknown")
        bucket = by_key.setdefault(
            (event_capability, adapter),
            {
                "capability": event_capability,
                "adapter": adapter,
                "call_count": 0,
                "success_count": 0,
                "failure_count": 0,
                "latency_total": 0.0,
                "latency_count": 0,
                "estimated_cost_usd": 0.0,
            },
        )
        bucket["call_count"] += 1

        # Latency is tracked for successes and failures alike, but only when
        # the event actually reported a numeric value.
        latency = _as_float(metadata.get("latency_ms"))
        if latency is not None:
            bucket["latency_total"] += latency
            bucket["latency_count"] += 1
            total_latency += latency
            latency_count += 1

        if event.event_type == "provider_call_succeeded":
            bucket["success_count"] += 1
            successful_calls += 1
            # Cost is only accumulated for successful calls.
            cost = _as_float(metadata.get("estimated_cost_usd")) or 0.0
            bucket["estimated_cost_usd"] += cost
            total_cost += cost
        else:
            bucket["failure_count"] += 1
            failed_calls += 1
            reason = str(metadata.get("error") or "unknown_error")
            failure_reasons[reason] = failure_reasons.get(reason, 0) + 1

    by_provider = []
    for bucket in by_key.values():
        # The running latency tallies are internal; expose only the average.
        bucket_latency_count = bucket.pop("latency_count")
        bucket_latency_total = bucket.pop("latency_total")
        by_provider.append(
            {
                **bucket,
                "avg_latency_ms": (
                    round(bucket_latency_total / bucket_latency_count, 2)
                    if bucket_latency_count
                    else None
                ),
                "estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
            }
        )
    by_provider.sort(
        key=lambda item: (
            str(item["capability"]),
            str(item["adapter"]),
        )
    )

    return {
        "total_calls": successful_calls + failed_calls,
        "successful_calls": successful_calls,
        "failed_calls": failed_calls,
        "avg_latency_ms": round(total_latency / latency_count, 2) if latency_count else None,
        "estimated_cost_usd": round(total_cost, 6),
        "by_provider": by_provider,
        "failure_reasons": [
            {"reason": reason, "count": count}
            for reason, count in sorted(
                failure_reasons.items(),
                key=lambda item: (-item[1], item[0]),
            )
        ],
    }


def _provider_events_query(
    *,
    user_id: str,
    story_id: int | None = None,
    days: int | None = None,
):
    """Build the SELECT for a user's provider call events, ordered by event id.

    Args:
        story_id: Restrict to jobs attached to this story when given.
        days: Restrict to events created within the last *days* days when given.
    """
    query = (
        select(GenerationJobEvent)
        .join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
        .where(
            GenerationJob.user_id == user_id,
            GenerationJobEvent.event_type.in_(
                ["provider_call_succeeded", "provider_call_failed"]
            ),
        )
    )
    if story_id is not None:
        query = query.where(GenerationJob.story_id == story_id)
    if days is not None:
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        query = query.where(GenerationJobEvent.created_at >= cutoff)
    return query.order_by(GenerationJobEvent.id)


async def get_story_provider_stats(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider call telemetry from all user-owned jobs for one story."""
    events = (
        await db.execute(
            _provider_events_query(
                user_id=user_id,
                story_id=story_id,
                days=days,
            )
        )
    ).scalars().all()
    return {
        "story_id": story_id,
        "window_days": days,
        "capability": capability,
        **_aggregate_provider_events(events, capability=capability),
    }


async def get_user_provider_analytics(
    db: AsyncSession,
    *,
    user_id: str,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider telemetry across all stories owned by one user.

    In addition to the aggregated telemetry, the result reports how many
    distinct jobs (``job_count``) and stories (``story_count``) the matching
    events belong to.
    """
    events = (
        await db.execute(
            _provider_events_query(
                user_id=user_id,
                days=days,
            )
        )
    ).scalars().all()

    # Collect both id sets in one pass, applying the capability filter once.
    job_ids = set()
    story_ids = set()
    for event in events:
        if capability is not None and _event_capability(event) != capability:
            continue
        job_ids.add(event.job_id)
        if event.story_id is not None:
            story_ids.add(event.story_id)

    return {
        "window_days": days,
        "capability": capability,
        **_aggregate_provider_events(events, capability=capability),
        "job_count": len(job_ids),
        "story_count": len(story_ids),
    }


async def get_user_generation_ops_summary(
    db: AsyncSession,
    *,
    user_id: str,
    hours: int = 24,
    recent_failure_limit: int = 5,
) -> dict[str, Any]:
    """Summarize recent generation health for one user.

    Args:
        hours: Size of the look-back window, applied to ``updated_at``.
        recent_failure_limit: Maximum number of entries in ``recent_failures``.

    Returns:
        Counts of active, stale, failed, degraded and asset jobs plus the most
        recently updated failures. ``active_jobs`` and ``stale_running_jobs``
        cover every running job of the user; the other counters cover only
        jobs updated inside the window.
    """
    stale_after_minutes = settings.generation_job_stale_minutes
    recent_cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)

    running_jobs = (
        await db.execute(
            select(GenerationJob)
            .where(
                GenerationJob.user_id == user_id,
                GenerationJob.status == "running",
            )
            .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
        )
    ).scalars().all()

    # Outer join so jobs without a story still appear, with a null title.
    recent_jobs = (
        await db.execute(
            select(GenerationJob, Story.title)
            .outerjoin(Story, Story.id == GenerationJob.story_id)
            .where(
                GenerationJob.user_id == user_id,
                GenerationJob.updated_at >= recent_cutoff,
            )
            .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
        )
    ).all()

    recent_failures: list[dict[str, Any]] = []
    failed_jobs = 0
    degraded_jobs = 0
    asset_retry_jobs = 0
    for job, story_title in recent_jobs:
        if job.status == "failed":
            failed_jobs += 1
            # Rows arrive newest first, so the first N failures are the latest.
            if len(recent_failures) < recent_failure_limit:
                recent_failures.append(
                    {
                        "job_id": job.id,
                        "story_id": job.story_id,
                        "story_title": story_title,
                        "output_mode": job.output_mode,
                        "current_step": job.current_step,
                        "error_message": job.error_message,
                        "failure_label": _failure_label(job),
                        "updated_at": job.updated_at,
                    }
                )
        elif job.status == "degraded_completed":
            degraded_jobs += 1
        if job.output_mode in {"asset_retry", "asset_generation"}:
            asset_retry_jobs += 1

    return {
        "window_hours": hours,
        "stale_threshold_minutes": stale_after_minutes,
        "active_jobs": len(running_jobs),
        "stale_running_jobs": sum(
            1
            for job in running_jobs
            if _is_stale_job(job, stale_after_minutes=stale_after_minutes)
        ),
        "failed_jobs": failed_jobs,
        "degraded_jobs": degraded_jobs,
        "asset_retry_jobs": asset_retry_jobs,
        "recent_failures": recent_failures,
    }


async def mark_stale_generation_jobs(
    db: AsyncSession,
    *,
    stale_after_minutes: int | None = None,
) -> dict[str, int]:
    """Mark long-running generation jobs as failed so they no longer appear stuck forever.

    Args:
        stale_after_minutes: Staleness threshold. ``None`` and ``0`` both fall
            back to ``settings.generation_job_stale_minutes``.

    Returns:
        ``running`` (running jobs inspected), ``marked_stale`` (jobs failed by
        this call) and ``stale_after_minutes`` (the threshold that was used).
    """
    # Deliberately ``or`` rather than an ``is None`` check: a threshold of 0
    # would fail every running job, so it is treated as "use the default".
    threshold = stale_after_minutes or settings.generation_job_stale_minutes
    running_jobs = (
        await db.execute(
            select(GenerationJob)
            .where(GenerationJob.status == "running")
            .order_by(GenerationJob.updated_at, GenerationJob.id)
        )
    ).scalars().all()

    marked_stale = 0
    # NOTE(review): finish_generation_job commits once per job. If the session
    # is configured with expire_on_commit=True, the remaining jobs in this list
    # are expired after the first commit - confirm the session factory settings.
    for job in running_jobs:
        if not _is_stale_job(job, stale_after_minutes=threshold):
            continue

        story = None
        if job.story_id is not None:
            story = (
                await db.execute(
                    select(Story).where(
                        Story.id == job.story_id,
                        Story.user_id == job.user_id,
                    )
                )
            ).scalar_one_or_none()

        await finish_generation_job(
            db,
            job=job,
            story=story,
            status="failed",
            current_step="generation_stale_failed",
            error_message=f"Generation job exceeded {threshold} minutes without progress.",
            message="Generation job was marked failed after exceeding the stale threshold.",
            metadata={"stale_after_minutes": threshold},
        )
        marked_stale += 1
        logger.warning(
            "generation_job_marked_stale",
            job_id=job.id,
            story_id=job.story_id,
            output_mode=job.output_mode,
            stale_after_minutes=threshold,
        )

    return {
        "running": len(running_jobs),
        "marked_stale": marked_stale,
        "stale_after_minutes": threshold,
    }