"""Lightweight generation job/event tracking.""" from __future__ import annotations from typing import Any from fastapi import HTTPException from sqlalchemy import desc, distinct, func, select from sqlalchemy.ext.asyncio import AsyncSession from app.db.models import GenerationJob, GenerationJobEvent, Story def _story_snapshot(story: Story | None) -> dict[str, Any]: if story is None: return {} return { "story_id": story.id, "mode": story.mode, "generation_status": story.generation_status, "text_status": story.text_status, "image_status": story.image_status, "audio_status": story.audio_status, "retryable_assets": story.retryable_assets, "last_error": story.last_error, } def _job_status_from_story(story: Story) -> str: if story.generation_status == "failed": return "failed" if story.generation_status == "degraded_completed": return "degraded_completed" return "completed" def _job_progress(job: GenerationJob) -> dict[str, Any]: """Resolve a compact progress summary for polling-oriented clients.""" if job.status == "failed": return { "progress_percent": 100, "progress_label": "生成失败", "is_terminal": True, } if job.status in {"completed", "degraded_completed"}: return { "progress_percent": 100, "progress_label": "已完成" if job.status == "completed" else "降级完成", "is_terminal": True, } progress_map: dict[str, tuple[int, str]] = { "request_accepted": (5, "已接收请求"), "context_prepared": (20, "上下文已准备"), "narrative_generated": (45, "正文已生成"), "story_saved": (60, "主记录已保存"), "provider_call_started": (65, "Provider 调用中"), "provider_call_succeeded": (72, "Provider 调用成功"), "provider_call_failed": (72, "Provider 调用失败,尝试恢复"), "cover_image_started": (75, "封面生成中"), "storybook_images_started": (75, "绘本插图生成中"), "audio_started": (75, "音频生成中"), "asset_retry_started": (25, "资源重试中"), "postprocessing_queued": (90, "后处理已排队"), "asset_generation_completed": (100, "资源已完成"), "asset_retry_completed": (100, "资源重试完成"), "generation_completed": (100, "生成完成"), } percent, label = progress_map.get(job.current_step, (10, "生成处理中")) return { "progress_percent": percent, "progress_label": label, "is_terminal": percent >= 100, } async def create_generation_job( db: AsyncSession, *, user_id: str, output_mode: str, input_type: str, request_payload: dict[str, Any], story_id: int | None = None, ) -> GenerationJob: """Create a generation job and record its first event.""" job = GenerationJob( user_id=user_id, story_id=story_id, output_mode=output_mode, input_type=input_type, status="running", current_step="request_accepted", request_payload=request_payload, result_snapshot={}, ) db.add(job) await db.flush() await record_generation_event( db, job=job, story_id=story_id, event_type="request_accepted", status="succeeded", message="Generation request accepted.", metadata={"output_mode": output_mode, "input_type": input_type}, commit=False, ) await db.commit() await db.refresh(job) return job async def record_generation_event( db: AsyncSession, *, job: GenerationJob, event_type: str, status: str, story_id: int | None = None, message: str | None = None, metadata: dict[str, Any] | None = None, commit: bool = True, ) -> GenerationJobEvent: """Append one event to an existing generation job.""" event = GenerationJobEvent( job_id=job.id, story_id=story_id if story_id is not None else job.story_id, event_type=event_type, status=status, message=message, event_metadata=metadata or {}, ) db.add(event) if commit: await db.commit() return event async def finish_generation_job( db: AsyncSession, *, job: GenerationJob, story: Story | None, status: str | None = None, current_step: str, error_message: str | None = None, message: str | None = None, metadata: dict[str, Any] | None = None, ) -> GenerationJob: """Mark a generation job as completed/degraded/failed and append a final event.""" job.story_id = story.id if story is not None else job.story_id job.status = status or (_job_status_from_story(story) if story is not None else "failed") job.current_step = current_step job.error_message = error_message job.result_snapshot = _story_snapshot(story) await record_generation_event( db, job=job, story_id=job.story_id, event_type=current_step, status=job.status, message=message, metadata={ **(metadata or {}), "result_snapshot": job.result_snapshot, }, commit=False, ) await db.commit() await db.refresh(job) return job def generation_event_to_response(event: GenerationJobEvent) -> dict[str, Any]: """Convert a generation event ORM object to an API response dict.""" return { "id": event.id, "job_id": event.job_id, "story_id": event.story_id, "event_type": event.event_type, "status": event.status, "message": event.message, "event_metadata": event.event_metadata or {}, "created_at": event.created_at, } def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]: """Convert a generation job ORM object to an API summary dict.""" progress = _job_progress(job) return { "id": job.id, "story_id": job.story_id, "output_mode": job.output_mode, "input_type": job.input_type, "status": job.status, "current_step": job.current_step, **progress, "result_snapshot": job.result_snapshot or {}, "error_message": job.error_message, "created_at": job.created_at, "updated_at": job.updated_at, } async def get_generation_job_detail( db: AsyncSession, *, job_id: str, user_id: str, ) -> dict[str, Any]: """Return a user-owned generation job with its ordered event stream.""" result = await db.execute( select(GenerationJob).where( GenerationJob.id == job_id, GenerationJob.user_id == user_id, ) ) job = result.scalar_one_or_none() if job is None: raise HTTPException(status_code=404, detail="Generation job not found") events = ( await db.execute( select(GenerationJobEvent) .where(GenerationJobEvent.job_id == job.id) .order_by(GenerationJobEvent.id) ) ).scalars().all() return { **generation_job_to_summary(job), "request_payload": job.request_payload or {}, "events": [generation_event_to_response(event) for event in events], } async def list_story_generation_jobs( db: AsyncSession, *, story_id: int, user_id: str, ) -> list[dict[str, Any]]: """Return recent generation jobs for a user-owned story.""" jobs = ( await db.execute( select(GenerationJob) .where( GenerationJob.story_id == story_id, GenerationJob.user_id == user_id, ) .order_by(desc(GenerationJob.created_at), desc(GenerationJob.id)) ) ).scalars().all() return [generation_job_to_summary(job) for job in jobs] def _as_float(value: Any) -> float | None: if isinstance(value, int | float): return float(value) return None def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, Any]: """Aggregate provider telemetry from provider call events.""" by_key: dict[tuple[str, str], dict[str, Any]] = {} total_latency = 0.0 latency_count = 0 total_cost = 0.0 successful_calls = 0 failed_calls = 0 for event in events: metadata = event.event_metadata or {} capability = str(metadata.get("capability") or "unknown") adapter = str(metadata.get("adapter") or "unknown") key = (capability, adapter) bucket = by_key.setdefault( key, { "capability": capability, "adapter": adapter, "call_count": 0, "success_count": 0, "failure_count": 0, "latency_total": 0.0, "latency_count": 0, "estimated_cost_usd": 0.0, }, ) bucket["call_count"] += 1 latency = _as_float(metadata.get("latency_ms")) if latency is not None: bucket["latency_total"] += latency bucket["latency_count"] += 1 total_latency += latency latency_count += 1 if event.event_type == "provider_call_succeeded": bucket["success_count"] += 1 successful_calls += 1 cost = _as_float(metadata.get("estimated_cost_usd")) or 0.0 bucket["estimated_cost_usd"] += cost total_cost += cost else: bucket["failure_count"] += 1 failed_calls += 1 by_provider = [] for bucket in by_key.values(): bucket_latency_count = bucket.pop("latency_count") bucket_latency_total = bucket.pop("latency_total") by_provider.append( { **bucket, "avg_latency_ms": ( round(bucket_latency_total / bucket_latency_count, 2) if bucket_latency_count else None ), "estimated_cost_usd": round(bucket["estimated_cost_usd"], 6), } ) by_provider.sort( key=lambda item: ( str(item["capability"]), str(item["adapter"]), ) ) return { "total_calls": successful_calls + failed_calls, "successful_calls": successful_calls, "failed_calls": failed_calls, "avg_latency_ms": round(total_latency / latency_count, 2) if latency_count else None, "estimated_cost_usd": round(total_cost, 6), "by_provider": by_provider, } async def get_story_provider_stats( db: AsyncSession, *, story_id: int, user_id: str, ) -> dict[str, Any]: """Aggregate provider call telemetry from all user-owned jobs for one story.""" events = ( await db.execute( select(GenerationJobEvent) .join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id) .where( GenerationJob.story_id == story_id, GenerationJob.user_id == user_id, GenerationJobEvent.event_type.in_( ["provider_call_succeeded", "provider_call_failed"] ), ) .order_by(GenerationJobEvent.id) ) ).scalars().all() return {"story_id": story_id, **_aggregate_provider_events(events)} async def get_user_provider_analytics( db: AsyncSession, *, user_id: str, ) -> dict[str, Any]: """Aggregate provider telemetry across all stories owned by one user.""" events = ( await db.execute( select(GenerationJobEvent) .join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id) .where( GenerationJob.user_id == user_id, GenerationJobEvent.event_type.in_( ["provider_call_succeeded", "provider_call_failed"] ), ) .order_by(GenerationJobEvent.id) ) ).scalars().all() job_count, story_count = ( await db.execute( select( func.count(GenerationJob.id), func.count(distinct(GenerationJob.story_id)), ).where(GenerationJob.user_id == user_id) ) ).one() return { **_aggregate_provider_events(events), "job_count": job_count, "story_count": story_count, }