Files
dreamweaver/backend/app/api/stories.py

509 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Story related APIs."""
import json
import uuid
from typing import AsyncGenerator
from fastapi import APIRouter, Depends, Query, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sse_starlette.sse import EventSourceResponse
from app.core.deps import require_user
from app.core.logging import get_logger
from app.core.rate_limiter import check_rate_limit
from app.db.database import get_db
from app.db.models import User
from app.schemas.story_schemas import (
AchievementItem,
FullStoryResponse,
GenerateRequest,
GenerationJobDetailResponse,
GenerationJobSummaryResponse,
GenerationOpsSummaryResponse,
GenerationProviderAnalyticsResponse,
GenerationProviderStatsResponse,
GenerationRequest,
GenerationResponse,
GenerationTraceSummaryResponse,
StoryAssetRetryRequest,
StoryAudioStatusResponse,
StorybookRequest,
StorybookResponse,
StoryDetailResponse,
StoryImageResponse,
StoryListItem,
StoryResponse,
)
from app.services import story_service
from app.services.generation_jobs import (
get_generation_job_detail,
get_story_provider_stats,
get_story_trace_summary,
get_user_generation_ops_summary,
get_user_provider_analytics,
list_story_generation_jobs,
request_generation_job_cancel,
)
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
generate_image,
generate_story_content,
)
from app.services.story_status import StoryAssetStatus, sync_story_status
# Logger for this module, obtained from the project's logging helper.
logger = get_logger(__name__)
# Router on which every endpoint in this module is registered.
router = APIRouter()
# Arguments handed to check_rate_limit (keyed per user as "story:{user.id}").
# Presumably "at most RATE_LIMIT_REQUESTS calls per RATE_LIMIT_WINDOW seconds";
# TODO confirm against check_rate_limit.
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10
def _legacy_generation_headers(successor: str) -> dict[str, str]:
return {
"Deprecation": "true",
"Link": f"<{successor}>; rel=\"successor-version\"",
"X-DreamWeaver-Successor-Endpoint": successor,
}
def _mark_legacy_generation_endpoint(response: Response, successor: str) -> None:
    """Flag *response* as coming from a deprecated generation endpoint.

    The headers built by ``_legacy_generation_headers(successor)`` are merged
    into ``response.headers`` in place; nothing is returned.
    """
    deprecation_headers = _legacy_generation_headers(successor)
    response.headers.update(deprecation_headers)
@router.post("/generations", response_model=GenerationResponse, status_code=202)
async def create_generation(
    request: GenerationRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Accept a story/storybook generation request for background execution.

    The per-user rate limit is checked first; the request is then handed to
    ``story_service.generate_generation_service``. The route answers with
    HTTP 202.
    """
    rate_key = f"story:{user.id}"
    await check_rate_limit(rate_key, RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
    accepted = await story_service.generate_generation_service(request, user.id, db)
    return accepted
@router.get("/generations/jobs/{job_id}", response_model=GenerationJobDetailResponse)
async def get_generation_job(
    job_id: str,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one generation job together with its ordered workflow events.

    The lookup is delegated to ``get_generation_job_detail`` and is passed the
    authenticated user's id alongside the job id.
    """
    job_detail = await get_generation_job_detail(
        db,
        job_id=job_id,
        user_id=user.id,
    )
    return job_detail
@router.post(
    "/generations/jobs/{job_id}/cancel",
    response_model=GenerationJobSummaryResponse,
)
async def cancel_generation_job(
    job_id: str,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Ask for a queued/running generation job to be cancelled.

    Delegates to ``request_generation_job_cancel`` with the job id and the
    authenticated user's id, and returns whatever summary it produces.
    """
    job_summary = await request_generation_job_cancel(
        db,
        job_id=job_id,
        user_id=user.id,
    )
    return job_summary
@router.post(
    "/generations/jobs/{job_id}/retry",
    response_model=GenerationJobSummaryResponse,
)
async def retry_generation_job(
    job_id: str,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Queue a new generation job based on a failed/canceled terminal job.

    The work is performed by ``story_service.retry_generation_job_service``.
    """
    new_job = await story_service.retry_generation_job_service(job_id, user.id, db)
    return new_job
@router.get(
    "/generations/ops-summary",
    response_model=GenerationOpsSummaryResponse,
)
async def get_generation_ops_summary(
    hours: int = Query(default=24, ge=1, le=168),
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a compact summary of recent generation-workflow operations.

    Args:
        hours: Look-back window in hours; validated to 1-168, default 24.
    """
    ops_summary = await get_user_generation_ops_summary(
        db,
        user_id=user.id,
        hours=hours,
    )
    return ops_summary
@router.get(
    "/generations/provider-analytics",
    response_model=GenerationProviderAnalyticsResponse,
)
async def get_generation_provider_analytics(
    days: int | None = Query(default=None, ge=1, le=365),
    capability: str | None = Query(default=None),
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Return provider call stats aggregated over the user's generation history.

    Args:
        days: Optional look-back window in days (1-365); ``None`` is passed
            through unchanged when omitted.
        capability: Optional capability filter, forwarded as-is.
    """
    analytics = await get_user_provider_analytics(
        db,
        user_id=user.id,
        days=days,
        capability=capability,
    )
    return analytics
@router.get(
    "/generations/{story_id}/jobs",
    response_model=list[GenerationJobSummaryResponse],
)
async def list_generation_jobs(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the recent generation jobs recorded for one story/storybook.

    Delegates to ``list_story_generation_jobs`` with the story id and the
    authenticated user's id.
    """
    jobs = await list_story_generation_jobs(
        db,
        story_id=story_id,
        user_id=user.id,
    )
    return jobs
@router.get(
    "/generations/{story_id}/provider-stats",
    response_model=GenerationProviderStatsResponse,
)
async def get_generation_provider_stats(
    story_id: int,
    days: int | None = Query(default=None, ge=1, le=365),
    capability: str | None = Query(default=None),
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Return provider call stats aggregated from one story's job events.

    Args:
        story_id: Story whose generation job events are aggregated.
        days: Optional look-back window in days (1-365).
        capability: Optional capability filter, forwarded as-is.
    """
    provider_stats = await get_story_provider_stats(
        db,
        story_id=story_id,
        user_id=user.id,
        days=days,
        capability=capability,
    )
    return provider_stats
@router.get(
    "/generations/{story_id}/trace-summary",
    response_model=GenerationTraceSummaryResponse,
)
async def get_generation_trace_summary(
    story_id: int,
    days: int | None = Query(default=None, ge=1, le=365),
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the workflow trace summary aggregated from a story's job events.

    Args:
        story_id: Story whose generation job events are summarised.
        days: Optional look-back window in days (1-365).
    """
    trace_summary = await get_story_trace_summary(
        db,
        story_id=story_id,
        user_id=user.id,
        days=days,
    )
    return trace_summary
@router.get("/generations/{story_id}", response_model=StoryDetailResponse)
async def get_generation(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch a generated story/storybook via the unified generation API.

    Uses the same ``story_service.get_story_detail`` lookup as the
    ``/stories/{story_id}`` route.
    """
    detail = await story_service.get_story_detail(story_id, user.id, db)
    return detail
@router.post("/generations/{story_id}/retry-assets", response_model=StoryDetailResponse)
async def retry_generation_assets(
    story_id: int,
    payload: StoryAssetRetryRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Retry generated assets of a story via the unified generation API.

    The assets to retry are taken from ``payload.assets`` and forwarded to
    ``story_service.retry_story_assets``.
    """
    requested_assets = payload.assets
    detail = await story_service.retry_story_assets(
        story_id, user.id, requested_assets, db
    )
    return detail
@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
    request: GenerateRequest,
    response: Response,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate or enhance a story (legacy endpoint).

    The response is tagged with deprecation headers pointing at
    ``/api/generations``, the per-user rate limit is checked, and the request
    is handed to ``story_service.generate_and_save_story``.
    """
    _mark_legacy_generation_endpoint(response, "/api/generations")
    rate_key = f"story:{user.id}"
    await check_rate_limit(rate_key, RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
    saved_story = await story_service.generate_and_save_story(request, user.id, db)
    return saved_story
@router.post("/stories/generate/full", response_model=FullStoryResponse)
async def generate_story_full(
    request: GenerateRequest,
    response: Response,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate a complete story: text plus image/audio (legacy endpoint).

    The response is tagged with deprecation headers pointing at
    ``/api/generations``, the per-user rate limit is checked, and the request
    is handed to ``story_service.generate_full_story_service``.
    """
    _mark_legacy_generation_endpoint(response, "/api/generations")
    rate_key = f"story:{user.id}"
    await check_rate_limit(rate_key, RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
    full_story = await story_service.generate_full_story_service(request, user.id, db)
    return full_story
def _story_status_payload(story) -> dict:
    """Return the status fields shared by every SSE event that describes *story*.

    Reads ``generation_status``, ``text_status``, ``image_status``,
    ``audio_status`` and ``last_error`` from the story at call time, so call
    it only after the relevant status update has been applied.
    """
    return {
        "generation_status": story.generation_status,
        "text_status": story.text_status,
        "image_status": story.image_status,
        "audio_status": story.audio_status,
        "last_error": story.last_error,
    }


@router.post("/stories/generate/stream")
async def generate_story_stream(
    request: GenerateRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Stream story generation progress as Server-Sent Events (legacy endpoint).

    Rate limiting, profile/universe validation and memory-context building run
    before the stream starts, so failures there surface as a normal HTTP error.

    Events emitted, in order:
        * ``started`` - carries a freshly generated UUID as ``story_id``.
        * ``story_ready`` - the persisted story and its status fields, or
          ``story_failed`` (with ``error``) if generating or saving the story
          raised; the stream ends after ``story_failed``.
        * ``image_ready`` / ``image_failed`` - only when the story has a
          ``cover_prompt``.
        * ``complete`` - final status snapshot, sent whether or not the image
          step succeeded.

    Returns:
        An ``EventSourceResponse`` carrying the legacy deprecation headers that
        point at ``/api/generations``.
    """
    await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)

    # Validation
    profile_id, universe_id = await story_service.validate_profile_and_universe(
        request.child_profile_id, request.universe_id, user.id, db
    )

    # Build context
    memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)

    # NOTE(review): the generator below keeps using the request-scoped `db`
    # session while the response is streaming. Whether the session is still
    # open at that point depends on how the framework orders dependency
    # teardown relative to sending the body - confirm for the pinned version.
    async def event_generator() -> AsyncGenerator[dict, None]:
        # NOTE(review): this UUID is unrelated to the database id (`story.id`)
        # reported by the later events; clients cannot correlate on it.
        story_id = str(uuid.uuid4())
        yield {"event": "started", "data": json.dumps({"story_id": story_id})}

        # Step 1: generate content
        try:
            result = await generate_story_content(
                input_type=request.type,
                data=request.data,
                education_theme=request.education_theme,
                memory_context=memory_context,
                user_id=user.id,
                db=db,
            )
        except Exception as e:
            logger.error("sse_story_generation_failed", error=str(e))
            yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
            return

        # Save story. A failure here used to escape the generator and cut the
        # stream without any terminal event; report it like a generation
        # failure so clients always see an explicit outcome.
        try:
            story = await story_service.create_story_from_result(
                result, user.id, profile_id, universe_id, db
            )
        except Exception as e:
            logger.error("sse_story_save_failed", error=str(e))
            yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
            return

        yield {
            "event": "story_ready",
            "data": json.dumps({
                "id": story.id,
                "title": story.title,
                "content": story.story_text,
                "cover_prompt": story.cover_prompt,
                "mode": story.mode,
                "child_profile_id": story.child_profile_id,
                "universe_id": story.universe_id,
                **_story_status_payload(story),
            }),
        }

        # Step 2: generate image (skipped when there is no cover prompt)
        if story.cover_prompt:
            sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
            await db.commit()
            try:
                # Direct call to the provider router's generate_image,
                # sharing the db session.
                image_url = await generate_image(
                    story.cover_prompt,
                    db=db,
                    user_id=user.id,
                    story_id=story.id,
                )
                story.image_url = image_url
                sync_story_status(
                    story,
                    image_status=StoryAssetStatus.READY,
                )
                await db.commit()
                yield {
                    "event": "image_ready",
                    "data": json.dumps(
                        {"image_url": image_url, **_story_status_payload(story)}
                    ),
                }
            except Exception as e:
                # Image failure is non-fatal: record it on the story and
                # still finish the stream with `complete`.
                sync_story_status(
                    story,
                    image_status=StoryAssetStatus.FAILED,
                    last_error=str(e),
                )
                await db.commit()
                logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e))
                yield {
                    "event": "image_failed",
                    "data": json.dumps(
                        {"error": str(e), **_story_status_payload(story)}
                    ),
                }

        yield {
            "event": "complete",
            "data": json.dumps(
                {"story_id": story.id, **_story_status_payload(story)}
            ),
        }

    return EventSourceResponse(
        event_generator(),
        headers=_legacy_generation_headers("/api/generations"),
    )
@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
    request: StorybookRequest,
    response: Response,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate a storybook (legacy endpoint).

    The response is tagged with deprecation headers pointing at
    ``/api/generations``, the per-user rate limit is checked, and the request
    is handed to ``story_service.generate_storybook_service``.
    """
    _mark_legacy_generation_endpoint(response, "/api/generations")
    rate_key = f"story:{user.id}"
    await check_rate_limit(rate_key, RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
    storybook = await story_service.generate_storybook_service(request, user.id, db)
    return storybook
@router.get("/stories", response_model=list[StoryListItem])
async def list_stories(
    limit: int = Query(default=20, ge=0),
    offset: int = Query(default=0, ge=0),
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """List the authenticated user's stories, one page at a time.

    Args:
        limit: Maximum number of stories to return; must be >= 0, default 20.
        offset: Number of stories to skip; must be >= 0, default 0.

    Negative values are now rejected by request validation instead of being
    passed on to ``story_service.list_stories``.
    """
    # NOTE(review): `limit` has no upper bound here; consider adding `le=...`
    # once the largest page size real clients request is known.
    return await story_service.list_stories(user.id, limit, offset, db)
@router.get("/stories/{story_id}", response_model=StoryDetailResponse)
async def get_story(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the detail view of one story.

    Delegates to ``story_service.get_story_detail`` with the story id and the
    authenticated user's id.
    """
    detail = await story_service.get_story_detail(story_id, user.id, db)
    return detail
@router.delete("/stories/{story_id}")
async def delete_story(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one story and acknowledge with a fixed message.

    Deletion is performed by ``story_service.delete_story``; once it returns,
    the handler answers with ``{"message": "Deleted"}``.
    """
    await story_service.delete_story(story_id, user.id, db)
    acknowledgement = {"message": "Deleted"}
    return acknowledgement
@router.post("/image/generate/{story_id}", response_model=StoryImageResponse)
async def generate_story_image(
    story_id: int,
    response: Response,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate the cover image for a story (legacy endpoint).

    The response is tagged with deprecation headers pointing at the story's
    ``retry-assets`` route. The cover is produced by
    ``story_service.generate_story_cover``; the story detail is then fetched
    so the reply can include its status fields next to the image URL.
    """
    successor = f"/api/generations/{story_id}/retry-assets"
    _mark_legacy_generation_endpoint(response, successor)

    cover_url = await story_service.generate_story_cover(story_id, user.id, db)
    detail = await story_service.get_story_detail(story_id, user.id, db)

    # Fields copied verbatim from the story detail, in response order.
    copied_fields = (
        "generation_status",
        "text_status",
        "image_status",
        "audio_status",
        "last_error",
        "retryable_assets",
    )
    body = {"image_url": cover_url}
    for field_name in copied_fields:
        body[field_name] = getattr(detail, field_name)
    return body
@router.post("/stories/{story_id}/assets/retry", response_model=StoryDetailResponse)
async def retry_story_assets(
    story_id: int,
    payload: StoryAssetRetryRequest,
    response: Response,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Retry the selected generated assets of a story (legacy endpoint).

    The response is tagged with deprecation headers pointing at the story's
    ``retry-assets`` route, then ``payload.assets`` is forwarded to
    ``story_service.retry_story_assets``.
    """
    successor = f"/api/generations/{story_id}/retry-assets"
    _mark_legacy_generation_endpoint(response, successor)
    requested_assets = payload.assets
    detail = await story_service.retry_story_assets(
        story_id, user.id, requested_assets, db
    )
    return detail
@router.get("/audio/{story_id}")
async def get_story_audio(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a story's audio as an ``audio/mpeg`` response.

    The bytes come from ``story_service.generate_story_audio`` and are sent
    back unchanged as the response body.
    """
    mp3_bytes = await story_service.generate_story_audio(story_id, user.id, db)
    audio_response = Response(content=mp3_bytes, media_type="audio/mpeg")
    return audio_response
@router.get("/audio/{story_id}/status", response_model=StoryAudioStatusResponse)
async def get_story_audio_status(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Report a story's audio cache status without generating audio.

    Delegates to ``story_service.get_story_audio_status``.
    """
    audio_status = await story_service.get_story_audio_status(story_id, user.id, db)
    return audio_status
@router.delete("/audio/{story_id}/cache", response_model=StoryAudioStatusResponse)
async def delete_story_audio_cache(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Clear a story's cached audio so it can be regenerated.

    Delegates to ``story_service.clear_story_audio_cache`` and returns the
    status it reports.
    """
    audio_status = await story_service.clear_story_audio_cache(story_id, user.id, db)
    return audio_status
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
async def get_story_achievements(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the achievements attached to one story.

    Delegates to ``story_service.get_story_achievements`` with the story id
    and the authenticated user's id.
    """
    achievements = await story_service.get_story_achievements(story_id, user.id, db)
    return achievements