feat: improve generation analytics and maintenance
This commit is contained in:
@@ -4,7 +4,7 @@ import json
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter, Depends, Response
|
||||
from fastapi import APIRouter, Depends, Query, Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
@@ -19,6 +19,7 @@ from app.schemas.story_schemas import (
|
||||
GenerateRequest,
|
||||
GenerationJobDetailResponse,
|
||||
GenerationJobSummaryResponse,
|
||||
GenerationOpsSummaryResponse,
|
||||
GenerationProviderAnalyticsResponse,
|
||||
GenerationProviderStatsResponse,
|
||||
GenerationRequest,
|
||||
@@ -36,6 +37,7 @@ from app.services import story_service
|
||||
from app.services.generation_jobs import (
|
||||
get_generation_job_detail,
|
||||
get_story_provider_stats,
|
||||
get_user_generation_ops_summary,
|
||||
get_user_provider_analytics,
|
||||
list_story_generation_jobs,
|
||||
)
|
||||
@@ -86,16 +88,36 @@ async def get_generation_job(
|
||||
return await get_generation_job_detail(db, job_id=job_id, user_id=user.id)
|
||||
|
||||
|
||||
@router.get(
    "/generations/ops-summary",
    response_model=GenerationOpsSummaryResponse,
)
async def get_generation_ops_summary(
    hours: int = Query(default=24, ge=1, le=168),
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Get a compact recent operations summary for generation workflows.

    Args:
        hours: Look-back window in hours; validated to 1..168, default 24.
        user: Authenticated user resolved by ``require_user``.
        db: Async database session resolved by ``get_db``.

    Returns:
        The summary built by ``get_user_generation_ops_summary`` for this
        user, serialized as ``GenerationOpsSummaryResponse``.
    """
    return await get_user_generation_ops_summary(db, user_id=user.id, hours=hours)
|
||||
|
||||
|
||||
@router.get(
    "/generations/provider-analytics",
    response_model=GenerationProviderAnalyticsResponse,
)
async def get_generation_provider_analytics(
    days: int | None = Query(default=None, ge=1, le=365),
    capability: str | None = Query(default=None),
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Get provider call stats aggregated across the user's generation history.

    Args:
        days: Optional look-back window in days (1..365); ``None`` means no
            time filter is passed on.
        capability: Optional capability name to restrict the stats to.
        user: Authenticated user resolved by ``require_user``.
        db: Async database session resolved by ``get_db``.

    Returns:
        The analytics built by ``get_user_provider_analytics``, serialized
        as ``GenerationProviderAnalyticsResponse``.
    """
    # Both filters must be forwarded; returning the unfiltered call first
    # would make the ``days`` / ``capability`` query parameters dead.
    return await get_user_provider_analytics(
        db,
        user_id=user.id,
        days=days,
        capability=capability,
    )
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -117,11 +139,19 @@ async def list_generation_jobs(
|
||||
)
|
||||
async def get_generation_provider_stats(
    story_id: int,
    days: int | None = Query(default=None, ge=1, le=365),
    capability: str | None = Query(default=None),
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Get provider call stats aggregated from generation job events.

    Args:
        story_id: Story whose generation jobs are aggregated.
        days: Optional look-back window in days (1..365); ``None`` means no
            time filter is passed on.
        capability: Optional capability name to restrict the stats to.
        user: Authenticated user resolved by ``require_user``.
        db: Async database session resolved by ``get_db``.

    Returns:
        The stats built by ``get_story_provider_stats`` for this story and
        user.
    """
    # Forward both filters; an earlier unfiltered return would leave the
    # ``days`` / ``capability`` query parameters without effect.
    return await get_story_provider_stats(
        db,
        story_id=story_id,
        user_id=user.id,
        days=days,
        capability=capability,
    )
|
||||
|
||||
|
||||
@router.get("/generations/{story_id}", response_model=StoryDetailResponse)
|
||||
|
||||
@@ -49,6 +49,14 @@ celery_app.conf.update(
|
||||
"task": "app.tasks.memory.prune_memories_task",
|
||||
"schedule": crontab(minute="0", hour="3"), # Daily at 03:00
|
||||
},
|
||||
"prune_story_audio_cache": {
|
||||
"task": "app.tasks.audio_cache.prune_story_audio_cache_task",
|
||||
"schedule": crontab(minute="30", hour="3"), # Daily at 03:30
|
||||
},
|
||||
"prune_stale_generation_jobs": {
|
||||
"task": "app.tasks.generation_maintenance.prune_stale_generation_jobs_task",
|
||||
"schedule": crontab(minute="*/30"),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -62,12 +62,20 @@ class Settings(BaseSettings):
|
||||
False,
|
||||
description="Enable local deterministic demo providers for portfolio demos",
|
||||
)
|
||||
story_audio_cache_dir: str = Field(
|
||||
"storage/audio",
|
||||
description="Directory for cached story audio files",
|
||||
)
|
||||
|
||||
# Celery (Redis)
|
||||
story_audio_cache_dir: str = Field(
|
||||
"storage/audio",
|
||||
description="Directory for cached story audio files",
|
||||
)
|
||||
story_audio_cache_ttl_days: int = Field(
|
||||
30,
|
||||
description="TTL in days before cached story audio is pruned",
|
||||
)
|
||||
generation_job_stale_minutes: int = Field(
|
||||
60,
|
||||
description="Minutes before a running generation job is considered stale",
|
||||
)
|
||||
|
||||
# Celery (Redis)
|
||||
celery_broker_url: str = Field("redis://localhost:6379/0")
|
||||
celery_result_backend: str = Field("redis://localhost:6379/0")
|
||||
|
||||
|
||||
@@ -220,21 +220,33 @@ class GenerationProviderStatResponse(BaseModel):
|
||||
estimated_cost_usd: float = 0.0
|
||||
|
||||
|
||||
class GenerationProviderFailureReasonResponse(BaseModel):
    """Aggregated failed provider call reason."""

    # Failure reason text used as the grouping key.
    reason: str
    # Number of failed provider calls that reported this reason.
    count: int
|
||||
|
||||
|
||||
class GenerationProviderStatsResponse(BaseModel):
    """Provider call stats aggregated from generation job events."""

    # Story the stats were aggregated for.
    story_id: int
    # Look-back window (days) that was requested; None when unfiltered.
    window_days: int | None = None
    # Capability filter that was requested; None when unfiltered.
    capability: str | None = None
    # Call counters over the aggregated provider call events.
    total_calls: int
    successful_calls: int
    failed_calls: int
    # Mean latency in milliseconds; None when no latency was available.
    avg_latency_ms: float | None = None
    estimated_cost_usd: float = 0.0
    # Per-provider breakdown rows.
    by_provider: list[GenerationProviderStatResponse] = Field(default_factory=list)
    # Failure reasons with their occurrence counts.
    failure_reasons: list[GenerationProviderFailureReasonResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class GenerationProviderAnalyticsResponse(BaseModel):
|
||||
"""Provider call stats aggregated across one user's generation history."""
|
||||
|
||||
window_days: int | None = None
|
||||
capability: str | None = None
|
||||
total_calls: int
|
||||
successful_calls: int
|
||||
failed_calls: int
|
||||
@@ -243,6 +255,33 @@ class GenerationProviderAnalyticsResponse(BaseModel):
|
||||
job_count: int
|
||||
story_count: int
|
||||
by_provider: list[GenerationProviderStatResponse] = Field(default_factory=list)
|
||||
failure_reasons: list[GenerationProviderFailureReasonResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class GenerationRecentFailureResponse(BaseModel):
    """One recent failed generation task for operations summary."""

    # Identifier of the failed generation job.
    job_id: str
    # Story the job belongs to, if any, and its title when it could be joined.
    story_id: int | None = None
    story_title: str | None = None
    # Job output mode and the step the job was on when it failed.
    output_mode: str
    current_step: str
    # Raw error message stored on the job, if any.
    error_message: str | None = None
    # Short display label describing the failure category.
    failure_label: str
    # Last update timestamp of the job.
    updated_at: datetime
|
||||
|
||||
|
||||
class GenerationOpsSummaryResponse(BaseModel):
    """Recent generation health summary for one user."""

    # Look-back window (hours) the recent counters were computed over.
    window_hours: int
    # Minutes without an update after which a running job counts as stale.
    stale_threshold_minutes: int
    # Number of jobs currently in the "running" status.
    active_jobs: int
    # Subset of running jobs that exceeded the stale threshold.
    stale_running_jobs: int
    # Jobs in the window that ended "failed" / "degraded_completed".
    failed_jobs: int
    degraded_jobs: int
    # Jobs in the window whose output mode is an asset retry or asset generation.
    asset_retry_jobs: int
    # Most recent failed jobs, newest first, capped by the service.
    recent_failures: list[GenerationRecentFailureResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AchievementItem(BaseModel):
|
||||
|
||||
@@ -2,14 +2,19 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import desc, distinct, func, select
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import GenerationJob, GenerationJobEvent, Story
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _story_snapshot(story: Story | None) -> dict[str, Any]:
|
||||
if story is None:
|
||||
@@ -68,6 +73,7 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
|
||||
"asset_generation_completed": (100, "资源已完成"),
|
||||
"asset_retry_completed": (100, "资源重试完成"),
|
||||
"generation_completed": (100, "生成完成"),
|
||||
"generation_stale_failed": (100, "任务超时已收敛"),
|
||||
}
|
||||
percent, label = progress_map.get(job.current_step, (10, "生成处理中"))
|
||||
return {
|
||||
@@ -77,6 +83,27 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _normalize_datetime(value: datetime) -> datetime:
    """Return *value* as a timezone-aware datetime in UTC.

    A naive datetime is treated as already being UTC and is only tagged with
    the UTC timezone; an aware datetime is converted to UTC.
    """
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def _is_stale_job(job: GenerationJob, *, stale_after_minutes: int) -> bool:
    """Return True when *job* is running and has not been updated recently.

    Args:
        job: Job to inspect; its ``status`` and ``updated_at`` are read.
        stale_after_minutes: Age in minutes at or beyond which a running job
            is considered stale.
    """
    # Only running jobs can be stale; skip the clock work for everything else.
    if job.status != "running":
        return False
    cutoff = datetime.now(timezone.utc) - timedelta(minutes=stale_after_minutes)
    # updated_at may be naive, so normalize before comparing with aware cutoff.
    return _normalize_datetime(job.updated_at) <= cutoff
|
||||
|
||||
|
||||
def _failure_label(job: GenerationJob) -> str:
    """Return a short display label describing why *job* failed.

    A stale-timeout step takes precedence over the output mode; asset retry
    and asset generation modes get their own labels; anything else falls back
    to a generic failure label.
    """
    if job.current_step == "generation_stale_failed":
        return "任务超时"
    if job.output_mode == "asset_retry":
        return "资源重试失败"
    if job.output_mode == "asset_generation":
        return "资源生成失败"
    return "生成失败"
|
||||
|
||||
|
||||
async def create_generation_job(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
@@ -266,16 +293,64 @@ async def list_story_generation_jobs(
|
||||
return [generation_job_to_summary(job) for job in jobs]
|
||||
|
||||
|
||||
async def get_active_story_generation_job(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> GenerationJob | None:
    """Return the most recent running job for a story, if any.

    Args:
        db: Async database session used for the query.
        story_id: Story whose jobs are searched.
        user_id: Owner the job must belong to.

    Returns:
        The running job with the latest ``updated_at`` (ties broken by the
        highest ``id``), or ``None`` when the story has no running job.
    """
    result = await db.execute(
        select(GenerationJob)
        .where(
            GenerationJob.story_id == story_id,
            GenerationJob.user_id == user_id,
            GenerationJob.status == "running",
        )
        .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
        # limit(1) guarantees scalar_one_or_none() never sees multiple rows.
        .limit(1)
    )
    return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def ensure_no_active_story_generation_job(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> None:
    """Prevent duplicate asset work while a story already has a running job.

    Args:
        db: Async database session used for the lookup.
        story_id: Story that is about to receive new generation work.
        user_id: Owner of the story.

    Raises:
        HTTPException: 409 when the story already has a running job; the
            detail includes that job's progress label.
    """
    active_job = await get_active_story_generation_job(db, story_id=story_id, user_id=user_id)
    if active_job is None:
        return

    progress = _job_progress(active_job)
    raise HTTPException(
        status_code=409,
        detail=(
            f"当前故事已有运行中的任务({progress['progress_label']}),"
            "请等待当前任务完成后再试。"
        ),
    )
|
||||
|
||||
|
||||
def _as_float(value: Any) -> float | None:
|
||||
if isinstance(value, int | float):
|
||||
return float(value)
|
||||
return None
|
||||
|
||||
|
||||
def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, Any]:
|
||||
def _aggregate_provider_events(
|
||||
events: list[GenerationJobEvent],
|
||||
*,
|
||||
capability: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Aggregate provider telemetry from provider call events."""
|
||||
|
||||
by_key: dict[tuple[str, str], dict[str, Any]] = {}
|
||||
failure_reasons: dict[str, int] = {}
|
||||
total_latency = 0.0
|
||||
latency_count = 0
|
||||
total_cost = 0.0
|
||||
@@ -284,13 +359,16 @@ def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, An
|
||||
|
||||
for event in events:
|
||||
metadata = event.event_metadata or {}
|
||||
capability = str(metadata.get("capability") or "unknown")
|
||||
event_capability = str(metadata.get("capability") or "unknown")
|
||||
if capability is not None and event_capability != capability:
|
||||
continue
|
||||
|
||||
adapter = str(metadata.get("adapter") or "unknown")
|
||||
key = (capability, adapter)
|
||||
key = (event_capability, adapter)
|
||||
bucket = by_key.setdefault(
|
||||
key,
|
||||
{
|
||||
"capability": capability,
|
||||
"capability": event_capability,
|
||||
"adapter": adapter,
|
||||
"call_count": 0,
|
||||
"success_count": 0,
|
||||
@@ -318,6 +396,8 @@ def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, An
|
||||
else:
|
||||
bucket["failure_count"] += 1
|
||||
failed_calls += 1
|
||||
reason = str(metadata.get("error") or "unknown_error")
|
||||
failure_reasons[reason] = failure_reasons.get(reason, 0) + 1
|
||||
|
||||
by_provider = []
|
||||
for bucket in by_key.values():
|
||||
@@ -349,67 +429,243 @@ def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, An
|
||||
"avg_latency_ms": round(total_latency / latency_count, 2) if latency_count else None,
|
||||
"estimated_cost_usd": round(total_cost, 6),
|
||||
"by_provider": by_provider,
|
||||
"failure_reasons": [
|
||||
{"reason": reason, "count": count}
|
||||
for reason, count in sorted(
|
||||
failure_reasons.items(),
|
||||
key=lambda item: (-item[1], item[0]),
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _provider_events_query(
    *,
    user_id: str,
    story_id: int | None = None,
    days: int | None = None,
):
    """Build the select for one user's provider call events.

    Args:
        user_id: Owner of the generation jobs whose events are selected.
        story_id: When given, restrict to jobs of this story.
        days: When given, restrict to events created within the last
            ``days`` days.

    Returns:
        A SQLAlchemy select over ``GenerationJobEvent`` limited to the
        ``provider_call_succeeded`` / ``provider_call_failed`` event types,
        ordered by event id.
    """
    query = (
        select(GenerationJobEvent)
        .join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
        .where(
            GenerationJob.user_id == user_id,
            GenerationJobEvent.event_type.in_(
                ["provider_call_succeeded", "provider_call_failed"]
            ),
        )
    )

    if story_id is not None:
        query = query.where(GenerationJob.story_id == story_id)

    if days is not None:
        # NOTE(review): cutoff is timezone-aware UTC; confirm created_at is
        # stored in a form that compares correctly against it.
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        query = query.where(GenerationJobEvent.created_at >= cutoff)

    return query.order_by(GenerationJobEvent.id)
|
||||
|
||||
|
||||
async def get_story_provider_stats(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider call telemetry from all user-owned jobs for one story.

    Args:
        db: Async database session used for the query.
        story_id: Story whose jobs are aggregated.
        user_id: Owner the jobs must belong to.
        days: Optional look-back window in days applied in the query.
        capability: Optional capability filter applied during aggregation.

    Returns:
        The aggregate from ``_aggregate_provider_events`` plus ``story_id``
        and the echoed ``window_days`` / ``capability`` filters.
    """
    events = (
        await db.execute(
            _provider_events_query(
                user_id=user_id,
                story_id=story_id,
                days=days,
            )
        )
    ).scalars().all()

    return {
        "story_id": story_id,
        "window_days": days,
        "capability": capability,
        **_aggregate_provider_events(events, capability=capability),
    }
|
||||
|
||||
|
||||
async def get_user_provider_analytics(
    db: AsyncSession,
    *,
    user_id: str,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider telemetry across all stories owned by one user.

    Args:
        db: Async database session used for the query.
        user_id: Owner whose provider call events are aggregated.
        days: Optional look-back window in days applied in the query.
        capability: Optional capability filter applied in Python, because
            the capability lives in each event's metadata.

    Returns:
        The aggregate from ``_aggregate_provider_events`` plus the echoed
        filters and the number of distinct jobs / stories that contributed
        at least one matching event.
    """
    events = (
        await db.execute(
            _provider_events_query(
                user_id=user_id,
                days=days,
            )
        )
    ).scalars().all()

    def _matches_capability(event: GenerationJobEvent) -> bool:
        # Same normalization the aggregator applies: a missing capability
        # is reported as "unknown".
        if capability is None:
            return True
        return str((event.event_metadata or {}).get("capability") or "unknown") == capability

    filtered_event_job_ids = {
        event.job_id for event in events if _matches_capability(event)
    }
    # NOTE(review): relies on GenerationJobEvent exposing story_id — confirm
    # against the model.
    filtered_story_ids = {
        event.story_id
        for event in events
        if event.story_id is not None and _matches_capability(event)
    }

    return {
        "window_days": days,
        "capability": capability,
        **_aggregate_provider_events(events, capability=capability),
        "job_count": len(filtered_event_job_ids),
        "story_count": len(filtered_story_ids),
    }
|
||||
|
||||
|
||||
async def get_user_generation_ops_summary(
    db: AsyncSession,
    *,
    user_id: str,
    hours: int = 24,
    recent_failure_limit: int = 5,
) -> dict[str, Any]:
    """Summarize recent generation health for one user.

    Args:
        db: Async database session used for the queries.
        user_id: Owner whose jobs are summarized.
        hours: Look-back window for the failed / degraded / asset counters.
        recent_failure_limit: Maximum number of failed jobs to list in
            ``recent_failures``.

    Returns:
        A dict with the window, the configured stale threshold, counts of
        running and stale running jobs (not limited to the window), counts
        of failed, degraded and asset jobs updated within the window, and
        the newest failed jobs.
    """
    stale_after_minutes = settings.generation_job_stale_minutes
    recent_cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)

    # Running jobs are counted regardless of the window so that long-stuck
    # jobs still show up.
    running_jobs = (
        await db.execute(
            select(GenerationJob)
            .where(
                GenerationJob.user_id == user_id,
                GenerationJob.status == "running",
            )
            .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
        )
    ).scalars().all()

    # Outer join keeps jobs whose story is missing; their title is None.
    recent_jobs = (
        await db.execute(
            select(GenerationJob, Story.title)
            .outerjoin(Story, Story.id == GenerationJob.story_id)
            .where(
                GenerationJob.user_id == user_id,
                GenerationJob.updated_at >= recent_cutoff,
            )
            .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
        )
    ).all()

    recent_failures: list[dict[str, Any]] = []
    failed_jobs = 0
    degraded_jobs = 0
    asset_retry_jobs = 0

    for job, story_title in recent_jobs:
        if job.status == "failed":
            failed_jobs += 1
            # Rows arrive newest first, so the first N failures are the
            # most recent ones.
            if len(recent_failures) < recent_failure_limit:
                recent_failures.append(
                    {
                        "job_id": job.id,
                        "story_id": job.story_id,
                        "story_title": story_title,
                        "output_mode": job.output_mode,
                        "current_step": job.current_step,
                        "error_message": job.error_message,
                        "failure_label": _failure_label(job),
                        "updated_at": job.updated_at,
                    }
                )
        elif job.status == "degraded_completed":
            degraded_jobs += 1

        # Counted independently of the job status.
        if job.output_mode in {"asset_retry", "asset_generation"}:
            asset_retry_jobs += 1

    return {
        "window_hours": hours,
        "stale_threshold_minutes": stale_after_minutes,
        "active_jobs": len(running_jobs),
        "stale_running_jobs": sum(
            1 for job in running_jobs if _is_stale_job(job, stale_after_minutes=stale_after_minutes)
        ),
        "failed_jobs": failed_jobs,
        "degraded_jobs": degraded_jobs,
        "asset_retry_jobs": asset_retry_jobs,
        "recent_failures": recent_failures,
    }
|
||||
|
||||
|
||||
async def mark_stale_generation_jobs(
    db: AsyncSession,
    *,
    stale_after_minutes: int | None = None,
) -> dict[str, int]:
    """Mark long-running generation jobs as failed so they no longer appear stuck forever.

    Args:
        db: Async database session used for the queries and updates.
        stale_after_minutes: Override for the stale threshold; when omitted
            (or falsy) ``settings.generation_job_stale_minutes`` is used.

    Returns:
        A dict with the number of running jobs inspected (``running``), the
        number that were marked failed (``marked_stale``) and the threshold
        that was applied (``stale_after_minutes``).
    """
    # `or` means an explicit 0 also falls back to the configured threshold.
    threshold = stale_after_minutes or settings.generation_job_stale_minutes
    running_jobs = (
        await db.execute(
            select(GenerationJob)
            .where(GenerationJob.status == "running")
            .order_by(GenerationJob.updated_at, GenerationJob.id)
        )
    ).scalars().all()

    marked_stale = 0

    for job in running_jobs:
        if not _is_stale_job(job, stale_after_minutes=threshold):
            continue

        # Load the owning story (if any) so it can be handed to
        # finish_generation_job together with the job.
        story = None
        if job.story_id is not None:
            story = (
                await db.execute(
                    select(Story).where(
                        Story.id == job.story_id,
                        Story.user_id == job.user_id,
                    )
                )
            ).scalar_one_or_none()

        await finish_generation_job(
            db,
            job=job,
            story=story,
            status="failed",
            current_step="generation_stale_failed",
            error_message=f"Generation job exceeded {threshold} minutes without progress.",
            message="Generation job was marked failed after exceeding the stale threshold.",
            metadata={"stale_after_minutes": threshold},
        )
        marked_stale += 1
        logger.warning(
            "generation_job_marked_stale",
            job_id=job.id,
            story_id=job.story_id,
            output_mode=job.output_mode,
            stale_after_minutes=threshold,
        )

    return {
        "running": len(running_jobs),
        "marked_stale": marked_stale,
        "stale_after_minutes": threshold,
    }
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -9,6 +10,7 @@ from sqlalchemy import desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import ChildProfile, Story, StoryUniverse
|
||||
from app.schemas.story_schemas import (
|
||||
@@ -32,6 +34,7 @@ from app.services.audio_storage import (
|
||||
)
|
||||
from app.services.generation_jobs import (
|
||||
create_generation_job,
|
||||
ensure_no_active_story_generation_job,
|
||||
finish_generation_job,
|
||||
record_generation_event,
|
||||
)
|
||||
@@ -1369,6 +1372,7 @@ async def retry_story_assets(
|
||||
db: AsyncSession,
|
||||
) -> Story:
|
||||
"""Retry selected assets through one workflow-level endpoint."""
|
||||
await ensure_no_active_story_generation_job(db, story_id=story_id, user_id=user_id)
|
||||
requested_assets = list(dict.fromkeys(assets))
|
||||
job = await create_generation_job(
|
||||
db,
|
||||
@@ -1443,6 +1447,7 @@ async def generate_story_cover(
|
||||
db: AsyncSession,
|
||||
) -> str:
|
||||
"""Generate cover image for an existing story."""
|
||||
await ensure_no_active_story_generation_job(db, story_id=story_id, user_id=user_id)
|
||||
job = await create_generation_job(
|
||||
db,
|
||||
user_id=user_id,
|
||||
@@ -1495,6 +1500,7 @@ async def generate_story_audio(
|
||||
db: AsyncSession,
|
||||
) -> bytes:
|
||||
"""Generate audio for a story."""
|
||||
await ensure_no_active_story_generation_job(db, story_id=story_id, user_id=user_id)
|
||||
job = await create_generation_job(
|
||||
db,
|
||||
user_id=user_id,
|
||||
@@ -1597,6 +1603,50 @@ async def clear_story_audio_cache(
|
||||
return await get_story_audio_status(story_id, user_id, db)
|
||||
|
||||
|
||||
async def prune_story_audio_cache(db: AsyncSession) -> dict[str, int]:
    """Prune expired audio cache files and repair story metadata.

    Args:
        db: Async database session used to load and update stories.

    Returns:
        Counts of stories inspected (``scanned``), cache files deleted
        because they were older than the TTL (``pruned``), and stories whose
        dangling ``audio_path`` was cleared (``repaired``).
    """
    # Never allow a TTL below one day, whatever the configuration says.
    ttl_days = max(1, settings.story_audio_cache_ttl_days)
    cutoff = datetime.now(timezone.utc) - timedelta(days=ttl_days)
    result = await db.execute(select(Story).where(Story.audio_path.is_not(None)))
    stories = result.scalars().all()

    scanned = 0
    pruned = 0
    repaired = 0

    for story in stories:
        scanned += 1
        metadata = get_audio_cache_metadata(story.audio_path)

        if not metadata.exists:
            # The file is gone: drop the dangling path and, if the story
            # still claims ready audio, reset its audio status.
            story.audio_path = None
            if story.audio_status == StoryAssetStatus.READY.value:
                sync_story_status(story, audio_status=StoryAssetStatus.NOT_REQUESTED)
            repaired += 1
            continue

        # NOTE(review): assumes metadata.updated_at is timezone-aware so it
        # can be compared with the UTC cutoff — confirm.
        if metadata.updated_at and metadata.updated_at < cutoff:
            delete_audio_cache(story.audio_path)
            story.audio_path = None
            sync_story_status(
                story,
                audio_status=StoryAssetStatus.NOT_REQUESTED,
                last_error=None,
            )
            pruned += 1

    await db.commit()
    logger.info(
        "story_audio_cache_pruned",
        scanned=scanned,
        pruned=pruned,
        repaired=repaired,
        ttl_days=ttl_days,
    )
    return {"scanned": scanned, "pruned": pruned, "repaired": repaired}
|
||||
|
||||
|
||||
async def get_story_achievements(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
|
||||
29
backend/app/tasks/audio_cache.py
Normal file
29
backend/app/tasks/audio_cache.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Celery tasks for story audio cache maintenance."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from app.core.celery_app import celery_app
|
||||
from app.core.logging import get_logger
|
||||
from app.db.database import _get_session_factory
|
||||
from app.services.story_service import prune_story_audio_cache
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@celery_app.task
def prune_story_audio_cache_task():
    """Daily task to prune expired story audio cache files.

    Runs ``prune_story_audio_cache`` in a fresh event loop with its own
    database session and returns its counters. Any exception is logged and
    re-raised so the failure stays visible to Celery.
    """
    logger.info("prune_story_audio_cache_task_started")

    async def _run():
        session_factory = _get_session_factory()
        async with session_factory() as session:
            return await prune_story_audio_cache(session)

    try:
        result = asyncio.run(_run())
        logger.info("prune_story_audio_cache_task_completed", **result)
        return result
    except Exception as exc:
        logger.error("prune_story_audio_cache_task_failed", error=str(exc))
        raise
|
||||
30
backend/app/tasks/generation_maintenance.py
Normal file
30
backend/app/tasks/generation_maintenance.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Generation job maintenance tasks."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from app.core.celery_app import celery_app
|
||||
from app.core.logging import get_logger
|
||||
from app.db.database import _get_session_factory
|
||||
from app.services.generation_jobs import mark_stale_generation_jobs
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@celery_app.task
def prune_stale_generation_jobs_task():
    """Periodically mark stale running generation jobs as failed.

    Runs ``mark_stale_generation_jobs`` in a fresh event loop with its own
    database session and returns its counters. Any exception is logged and
    re-raised so the failure stays visible to Celery.
    """
    logger.info("prune_stale_generation_jobs_task_started")

    async def _run():
        session_factory = _get_session_factory()
        async with session_factory() as session:
            return await mark_stale_generation_jobs(session)

    try:
        result = asyncio.run(_run())
        logger.info("prune_stale_generation_jobs_task_completed", **result)
        return result
    except Exception as exc:
        logger.error("prune_stale_generation_jobs_task_failed", error=str(exc))
        raise
|
||||
Reference in New Issue
Block a user