Some checks failed
Build and Push Docker Images / changes (push) Has been cancelled
Build and Push Docker Images / build-backend (push) Has been cancelled
Build and Push Docker Images / build-frontend (push) Has been cancelled
Build and Push Docker Images / build-admin-frontend (push) Has been cancelled
265 lines
9.0 KiB
Python
265 lines
9.0 KiB
Python
"""Story related APIs."""
|
||
|
||
import json
import uuid
from typing import AsyncGenerator

from fastapi import APIRouter, Depends, Query, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sse_starlette.sse import EventSourceResponse

from app.core.deps import require_user
from app.core.logging import get_logger
from app.core.rate_limiter import check_rate_limit
from app.db.database import get_db
from app.db.models import User
from app.schemas.story_schemas import (
    AchievementItem,
    FullStoryResponse,
    GenerateRequest,
    StoryDetailResponse,
    StoryImageResponse,
    StoryListItem,
    StoryResponse,
    StorybookRequest,
    StorybookResponse,
)
from app.services import story_service
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
    generate_story_content,
    generate_image,
)
from app.services.story_status import StoryAssetStatus, sync_story_status
|
||
|
||
# Module-level structured logger; call sites pass context as keyword args.
logger = get_logger(__name__)
# Router holding every story/storybook/image/audio endpoint in this module.
router = APIRouter()

# Window length passed to check_rate_limit by the generation endpoints.
RATE_LIMIT_WINDOW = 60  # seconds
# Request cap passed to check_rate_limit; presumably the max number of
# requests per user within RATE_LIMIT_WINDOW — confirm in rate_limiter.
RATE_LIMIT_REQUESTS = 10
|
||
|
||
|
||
@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
    request: GenerateRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Create (or enhance) a story for the authenticated user.

    The user's ``story:<user id>`` rate-limit bucket is checked first
    (check_rate_limit presumably raises when the budget is exhausted);
    generation and persistence are delegated to
    ``story_service.generate_and_save_story``.
    """
    rate_key = f"story:{user.id}"
    await check_rate_limit(rate_key, RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
    saved_story = await story_service.generate_and_save_story(request, user.id, db)
    return saved_story
|
||
|
||
|
||
@router.post("/stories/generate/full", response_model=FullStoryResponse)
async def generate_story_full(
    request: GenerateRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Produce a complete story in a single request.

    Covers the story text plus its image/audio assets, which the service
    layer generates in parallel. Shares the per-user ``story:`` rate-limit
    bucket with the other generation endpoints.
    """
    owner_id = user.id
    await check_rate_limit(
        f"story:{owner_id}",
        RATE_LIMIT_REQUESTS,
        RATE_LIMIT_WINDOW,
    )
    full_story = await story_service.generate_full_story_service(request, owner_id, db)
    return full_story
|
||
|
||
|
||
def _asset_status_fields(story) -> dict:
    """Return the status fields shared by every SSE payload about *story*.

    Keys are produced in a fixed order (generation, image, audio, last error)
    so the serialized JSON matches what clients already receive.
    """
    return {
        "generation_status": story.generation_status,
        "image_status": story.image_status,
        "audio_status": story.audio_status,
        "last_error": story.last_error,
    }


@router.post("/stories/generate/stream")
async def generate_story_stream(
    request: GenerateRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Stream story generation to the client as Server-Sent Events (SSE).

    Rate limiting, profile/universe validation and memory-context building
    run before the response is created, so any exception they raise
    propagates before a single SSE event is sent. After that the client
    receives, in order:

    * ``started``      - a stream correlation ``story_id`` (random UUID).
    * ``story_failed`` - text generation or saving failed; the stream ends.
    * ``story_ready``  - the persisted story and its asset statuses.
    * ``image_ready`` / ``image_failed`` - only if the story has a cover prompt.
    * ``complete``     - the story's database id and final statuses.
    """
    await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)

    # Validation
    profile_id, universe_id = await story_service.validate_profile_and_universe(
        request.child_profile_id, request.universe_id, user.id, db
    )

    # Build Context
    memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)

    # NOTE(review): event_generator runs after this handler returns and keeps
    # using the `db` session from Depends(get_db). Whether that session is
    # still open while streaming depends on how the installed FastAPI version
    # tears down yield-dependencies around streaming responses — verify.
    async def event_generator() -> AsyncGenerator[dict, None]:
        # NOTE(review): this UUID only correlates the stream; it is NOT the
        # database id later sent as "id" in story_ready / "story_id" in complete.
        story_id = str(uuid.uuid4())
        yield {"event": "started", "data": json.dumps({"story_id": story_id})}

        # Step 1: generate the story content.
        try:
            result = await generate_story_content(
                input_type=request.type,
                data=request.data,
                education_theme=request.education_theme,
                memory_context=memory_context,
                db=db,
            )
        except Exception as e:
            logger.error("sse_story_generation_failed", error=str(e))
            yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
            return

        # Step 2: persist the story. Without this handler a failure here
        # (e.g. a DB error) would abort the stream with no terminal event,
        # leaving the client waiting. A generic message is sent so raw
        # database error text is not forwarded to the client; details are logged.
        try:
            story = await story_service.create_story_from_result(
                result, user.id, profile_id, universe_id, db
            )
        except Exception as e:
            logger.error("sse_story_save_failed", user_id=user.id, error=str(e))
            yield {
                "event": "story_failed",
                "data": json.dumps({"error": "Failed to save story"}),
            }
            return

        yield {
            "event": "story_ready",
            "data": json.dumps(
                {
                    "id": story.id,
                    "title": story.title,
                    "content": story.story_text,
                    "cover_prompt": story.cover_prompt,
                    "mode": story.mode,
                    "child_profile_id": story.child_profile_id,
                    "universe_id": story.universe_id,
                    **_asset_status_fields(story),
                }
            ),
        }

        # Step 3: generate the cover image when the story produced a prompt.
        if story.cover_prompt:
            sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
            await db.commit()
            try:
                # Direct call to the provider router, sharing this db session.
                image_url = await generate_image(story.cover_prompt, db=db)
                story.image_url = image_url
                sync_story_status(story, image_status=StoryAssetStatus.READY)
                await db.commit()
            except Exception as e:
                sync_story_status(
                    story,
                    image_status=StoryAssetStatus.FAILED,
                    last_error=str(e),
                )
                await db.commit()
                logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e))
                yield {
                    "event": "image_failed",
                    "data": json.dumps({"error": str(e), **_asset_status_fields(story)}),
                }
            else:
                yield {
                    "event": "image_ready",
                    "data": json.dumps({"image_url": image_url, **_asset_status_fields(story)}),
                }

        yield {
            "event": "complete",
            "data": json.dumps({"story_id": story.id, **_asset_status_fields(story)}),
        }

    return EventSourceResponse(event_generator())
|
||
|
||
|
||
@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
    request: StorybookRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Build a storybook for the authenticated user.

    Throttled through the same per-user ``story:`` rate-limit bucket as story
    generation, then handed off to ``story_service.generate_storybook_service``.
    """
    bucket = f"story:{user.id}"
    await check_rate_limit(bucket, RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
    storybook = await story_service.generate_storybook_service(request, user.id, db)
    return storybook
|
||
|
||
|
||
# ==================== Missing Endpoints (Issue #5) ====================
|
||
|
||
@router.get("/stories", response_model=list[StoryListItem])
async def list_stories(
    # Bounded so a client cannot request an unbounded page or a negative
    # offset (which most SQL backends reject with a server error).
    # NOTE(review): the 100-item cap is a chosen default — adjust if a client
    # legitimately needs larger pages.
    limit: int = Query(20, ge=1, le=100),
    offset: int = Query(0, ge=0),
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's stories, one page at a time.

    Args:
        limit: Page size (1-100, default 20). Out-of-range values are
            rejected by FastAPI validation with a 422 response.
        offset: Number of stories to skip (>= 0, default 0).

    Returns:
        The page of stories produced by ``story_service.list_stories``.
    """
    return await story_service.list_stories(user.id, limit, offset, db)
|
||
|
||
|
||
@router.get("/stories/{story_id}", response_model=StoryDetailResponse)
async def get_story(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch one story's details; the lookup is scoped by the requester's id."""
    detail = await story_service.get_story_detail(story_id, user.id, db)
    return detail
|
||
|
||
|
||
@router.delete("/stories/{story_id}")
async def delete_story(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a story on behalf of the current user and acknowledge it."""
    owner_id = user.id
    await story_service.delete_story(story_id, owner_id, db)
    acknowledgement = {"message": "Deleted"}
    return acknowledgement
|
||
|
||
|
||
@router.post("/image/generate/{story_id}", response_model=StoryImageResponse)
async def generate_story_image(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate the cover image for one of the user's stories.

    Every other generation endpoint in this module is rate limited; this one
    presumably also calls the image provider (via story_service), so it is
    throttled too. It uses its own ``image:<user id>`` bucket so regenerating
    covers does not eat into the story-generation quota.

    Returns:
        The new image URL plus the story's generation/image/audio statuses
        and last error, as re-read after the cover generation.
    """
    await check_rate_limit(f"image:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
    url = await story_service.generate_story_cover(story_id, user.id, db)
    # Re-fetch so the returned statuses reflect any changes made by
    # generate_story_cover.
    story = await story_service.get_story_detail(story_id, user.id, db)
    return {
        "image_url": url,
        "generation_status": story.generation_status,
        "image_status": story.image_status,
        "audio_status": story.audio_status,
        "last_error": story.last_error,
    }
|
||
|
||
|
||
@router.get("/audio/{story_id}")
async def get_story_audio(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a story's narration as an ``audio/mpeg`` (MP3) response body.

    The bytes come from ``story_service.generate_story_audio``.
    NOTE(review): unlike the generation endpoints this route is not rate
    limited; confirm whether the service caches audio or synthesizes it on
    every request.
    """
    mp3_payload = await story_service.generate_story_audio(story_id, user.id, db)
    return Response(media_type="audio/mpeg", content=mp3_payload)
|
||
|
||
|
||
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
async def get_story_achievements(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """List the achievements tied to a story, scoped by the requester's id."""
    achievements = await story_service.get_story_achievements(story_id, user.id, db)
    return achievements
|