feat: persist story generation states and cache audio
Some checks failed
Build and Push Docker Images / changes (push) Has been cancelled
Build and Push Docker Images / build-backend (push) Has been cancelled
Build and Push Docker Images / build-frontend (push) Has been cancelled
Build and Push Docker Images / build-admin-frontend (push) Has been cancelled
Some checks failed
Build and Push Docker Images / changes (push) Has been cancelled
Build and Push Docker Images / build-backend (push) Has been cancelled
Build and Push Docker Images / build-frontend (push) Has been cancelled
Build and Push Docker Images / build-admin-frontend (push) Has been cancelled
This commit is contained in:
51
backend/.gitignore
vendored
51
backend/.gitignore
vendored
@@ -1,27 +1,28 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# 环境变量
|
||||
.env
|
||||
|
||||
# 测试
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
|
||||
# 其他
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# 环境变量
|
||||
.env
|
||||
|
||||
# 测试
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
|
||||
# 其他
|
||||
*.log
|
||||
.DS_Store
|
||||
storage/
|
||||
|
||||
151
backend/alembic/versions/0009_add_story_generation_statuses.py
Normal file
151
backend/alembic/versions/0009_add_story_generation_statuses.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""add story generation status fields
|
||||
|
||||
Revision ID: 0009_add_story_generation_statuses
|
||||
Revises: 0008_add_pages_to_stories
|
||||
Create Date: 2026-04-17
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0009_add_story_generation_statuses"
|
||||
down_revision = "0008_add_pages_to_stories"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
stories = sa.table(
|
||||
"stories",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("story_text", sa.Text),
|
||||
sa.column("pages", sa.JSON),
|
||||
sa.column("cover_prompt", sa.Text),
|
||||
sa.column("image_url", sa.String(length=500)),
|
||||
sa.column("generation_status", sa.String(length=32)),
|
||||
sa.column("image_status", sa.String(length=32)),
|
||||
sa.column("audio_status", sa.String(length=32)),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_image_status(row: dict) -> str:
|
||||
pages = row.get("pages") or []
|
||||
|
||||
expected_assets = 0
|
||||
ready_assets = 0
|
||||
|
||||
if row.get("cover_prompt") or row.get("image_url"):
|
||||
expected_assets += 1
|
||||
if row.get("image_url"):
|
||||
ready_assets += 1
|
||||
|
||||
for page in pages:
|
||||
if not isinstance(page, dict):
|
||||
continue
|
||||
if not page.get("image_prompt") and not page.get("image_url"):
|
||||
continue
|
||||
expected_assets += 1
|
||||
if page.get("image_url"):
|
||||
ready_assets += 1
|
||||
|
||||
if expected_assets == 0:
|
||||
return "not_requested"
|
||||
|
||||
if ready_assets == expected_assets:
|
||||
return "ready"
|
||||
|
||||
return "failed"
|
||||
|
||||
|
||||
def _resolve_generation_status(
|
||||
*,
|
||||
story_text: str | None,
|
||||
pages: list[dict] | None,
|
||||
image_status: str,
|
||||
audio_status: str,
|
||||
) -> str:
|
||||
has_narrative = bool(story_text) or bool(pages)
|
||||
if not has_narrative:
|
||||
return "failed"
|
||||
|
||||
if "generating" in {image_status, audio_status}:
|
||||
return "assets_generating"
|
||||
|
||||
if "failed" in {image_status, audio_status}:
|
||||
return "degraded_completed"
|
||||
|
||||
if image_status == "not_requested" and audio_status == "not_requested":
|
||||
return "narrative_ready"
|
||||
|
||||
return "completed"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"stories",
|
||||
sa.Column(
|
||||
"generation_status",
|
||||
sa.String(length=32),
|
||||
nullable=False,
|
||||
server_default="narrative_ready",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"stories",
|
||||
sa.Column(
|
||||
"image_status",
|
||||
sa.String(length=32),
|
||||
nullable=False,
|
||||
server_default="not_requested",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"stories",
|
||||
sa.Column(
|
||||
"audio_status",
|
||||
sa.String(length=32),
|
||||
nullable=False,
|
||||
server_default="not_requested",
|
||||
),
|
||||
)
|
||||
op.add_column("stories", sa.Column("last_error", sa.Text(), nullable=True))
|
||||
|
||||
connection = op.get_bind()
|
||||
rows = connection.execute(
|
||||
sa.select(
|
||||
stories.c.id,
|
||||
stories.c.story_text,
|
||||
stories.c.pages,
|
||||
stories.c.cover_prompt,
|
||||
stories.c.image_url,
|
||||
)
|
||||
).mappings()
|
||||
|
||||
for row in rows:
|
||||
image_status = _resolve_image_status(row)
|
||||
audio_status = "not_requested"
|
||||
generation_status = _resolve_generation_status(
|
||||
story_text=row.get("story_text"),
|
||||
pages=row.get("pages"),
|
||||
image_status=image_status,
|
||||
audio_status=audio_status,
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
stories.update()
|
||||
.where(stories.c.id == row["id"])
|
||||
.values(
|
||||
generation_status=generation_status,
|
||||
image_status=image_status,
|
||||
audio_status=audio_status,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("stories", "last_error")
|
||||
op.drop_column("stories", "audio_status")
|
||||
op.drop_column("stories", "image_status")
|
||||
op.drop_column("stories", "generation_status")
|
||||
25
backend/alembic/versions/0010_add_story_audio_cache_path.py
Normal file
25
backend/alembic/versions/0010_add_story_audio_cache_path.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""add audio cache path to stories
|
||||
|
||||
Revision ID: 0010_add_story_audio_cache_path
|
||||
Revises: 0009_add_story_generation_statuses
|
||||
Create Date: 2026-04-17
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0010_add_story_audio_cache_path"
|
||||
down_revision = "0009_add_story_generation_statuses"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("stories", sa.Column("audio_path", sa.String(length=500), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("stories", "audio_path")
|
||||
@@ -1,27 +1,28 @@
|
||||
"""Story related APIs."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.core.logging import get_logger
|
||||
from app.core.rate_limiter import check_rate_limit
|
||||
from app.db.database import get_db
|
||||
from app.db.models import User
|
||||
from fastapi import APIRouter, Depends, Response
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.core.logging import get_logger
|
||||
from app.core.rate_limiter import check_rate_limit
|
||||
from app.db.database import get_db
|
||||
from app.db.models import User
|
||||
from app.schemas.story_schemas import (
|
||||
GenerateRequest,
|
||||
StoryResponse,
|
||||
AchievementItem,
|
||||
FullStoryResponse,
|
||||
GenerateRequest,
|
||||
StoryDetailResponse,
|
||||
StoryImageResponse,
|
||||
StoryListItem,
|
||||
StoryResponse,
|
||||
StorybookRequest,
|
||||
StorybookResponse,
|
||||
StoryListItem,
|
||||
AchievementItem,
|
||||
)
|
||||
from app.services import story_service
|
||||
from app.services.memory_service import build_enhanced_memory_context
|
||||
@@ -29,153 +30,202 @@ from app.services.provider_router import (
|
||||
generate_story_content,
|
||||
generate_image,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
RATE_LIMIT_WINDOW = 60 # seconds
|
||||
RATE_LIMIT_REQUESTS = 10
|
||||
|
||||
|
||||
@router.post("/stories/generate", response_model=StoryResponse)
|
||||
async def generate_story(
|
||||
request: GenerateRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Generate or enhance a story."""
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
return await story_service.generate_and_save_story(request, user.id, db)
|
||||
|
||||
|
||||
@router.post("/stories/generate/full", response_model=FullStoryResponse)
|
||||
async def generate_story_full(
|
||||
request: GenerateRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Generate complete story (story + parallel image/audio generation)."""
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
return await story_service.generate_full_story_service(request, user.id, db)
|
||||
|
||||
|
||||
@router.post("/stories/generate/stream")
|
||||
from app.services.story_status import StoryAssetStatus, sync_story_status
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
RATE_LIMIT_WINDOW = 60 # seconds
|
||||
RATE_LIMIT_REQUESTS = 10
|
||||
|
||||
|
||||
@router.post("/stories/generate", response_model=StoryResponse)
|
||||
async def generate_story(
|
||||
request: GenerateRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Generate or enhance a story."""
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
return await story_service.generate_and_save_story(request, user.id, db)
|
||||
|
||||
|
||||
@router.post("/stories/generate/full", response_model=FullStoryResponse)
|
||||
async def generate_story_full(
|
||||
request: GenerateRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Generate complete story (story + parallel image/audio generation)."""
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
return await story_service.generate_full_story_service(request, user.id, db)
|
||||
|
||||
|
||||
@router.post("/stories/generate/stream")
|
||||
async def generate_story_stream(
|
||||
request: GenerateRequest,
|
||||
req: Request,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""流式生成故事(SSE)。"""
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
|
||||
# Validation
|
||||
profile_id, universe_id = await story_service.validate_profile_and_universe(
|
||||
request.child_profile_id, request.universe_id, user.id, db
|
||||
)
|
||||
|
||||
# Build Context
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[dict, None]:
|
||||
story_id = str(uuid.uuid4())
|
||||
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
|
||||
|
||||
# Step 1: Generate Content
|
||||
try:
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("sse_story_generation_failed", error=str(e))
|
||||
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
|
||||
return
|
||||
|
||||
# Save Story
|
||||
story = await story_service.create_story_from_result(
|
||||
result, user.id, profile_id, universe_id, db
|
||||
)
|
||||
|
||||
"""流式生成故事(SSE)。"""
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
|
||||
# Validation
|
||||
profile_id, universe_id = await story_service.validate_profile_and_universe(
|
||||
request.child_profile_id, request.universe_id, user.id, db
|
||||
)
|
||||
|
||||
# Build Context
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[dict, None]:
|
||||
story_id = str(uuid.uuid4())
|
||||
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
|
||||
|
||||
# Step 1: Generate Content
|
||||
try:
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("sse_story_generation_failed", error=str(e))
|
||||
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
|
||||
return
|
||||
|
||||
# Save Story
|
||||
story = await story_service.create_story_from_result(
|
||||
result, user.id, profile_id, universe_id, db
|
||||
)
|
||||
|
||||
yield {
|
||||
"event": "story_ready",
|
||||
"data": json.dumps({
|
||||
"id": story.id,
|
||||
"title": story.title,
|
||||
"content": story.story_text,
|
||||
"content": story.story_text,
|
||||
"cover_prompt": story.cover_prompt,
|
||||
"mode": story.mode,
|
||||
"child_profile_id": story.child_profile_id,
|
||||
"universe_id": story.universe_id,
|
||||
"generation_status": story.generation_status,
|
||||
"image_status": story.image_status,
|
||||
"audio_status": story.audio_status,
|
||||
"last_error": story.last_error,
|
||||
}),
|
||||
}
|
||||
|
||||
# Step 2: Generate Image
|
||||
if story.cover_prompt:
|
||||
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
try:
|
||||
# Direct call to provider router's generate_image, sharing db session
|
||||
image_url = await generate_image(story.cover_prompt, db=db)
|
||||
story.image_url = image_url
|
||||
sync_story_status(
|
||||
story,
|
||||
image_status=StoryAssetStatus.READY,
|
||||
)
|
||||
await db.commit()
|
||||
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
|
||||
yield {
|
||||
"event": "image_ready",
|
||||
"data": json.dumps(
|
||||
{
|
||||
"image_url": image_url,
|
||||
"generation_status": story.generation_status,
|
||||
"image_status": story.image_status,
|
||||
"audio_status": story.audio_status,
|
||||
"last_error": story.last_error,
|
||||
}
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
sync_story_status(
|
||||
story,
|
||||
image_status=StoryAssetStatus.FAILED,
|
||||
last_error=str(e),
|
||||
)
|
||||
await db.commit()
|
||||
logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e))
|
||||
yield {"event": "image_failed", "data": json.dumps({"error": str(e)})}
|
||||
yield {
|
||||
"event": "image_failed",
|
||||
"data": json.dumps(
|
||||
{
|
||||
"error": str(e),
|
||||
"generation_status": story.generation_status,
|
||||
"image_status": story.image_status,
|
||||
"audio_status": story.audio_status,
|
||||
"last_error": story.last_error,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
yield {"event": "complete", "data": json.dumps({"story_id": story.id})}
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
|
||||
@router.post("/storybook/generate", response_model=StorybookResponse)
|
||||
async def generate_storybook_api(
|
||||
request: StorybookRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Generate storybook."""
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
return await story_service.generate_storybook_service(request, user.id, db)
|
||||
|
||||
|
||||
# ==================== Missing Endpoints (Issue #5) ====================
|
||||
|
||||
@router.get("/stories", response_model=list[StoryListItem])
|
||||
async def list_stories(
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List stories."""
|
||||
return await story_service.list_stories(user.id, limit, offset, db)
|
||||
|
||||
|
||||
@router.get("/stories/{story_id}", response_model=StoryResponse)
|
||||
yield {
|
||||
"event": "complete",
|
||||
"data": json.dumps(
|
||||
{
|
||||
"story_id": story.id,
|
||||
"generation_status": story.generation_status,
|
||||
"image_status": story.image_status,
|
||||
"audio_status": story.audio_status,
|
||||
"last_error": story.last_error,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
|
||||
@router.post("/storybook/generate", response_model=StorybookResponse)
|
||||
async def generate_storybook_api(
|
||||
request: StorybookRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Generate storybook."""
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
return await story_service.generate_storybook_service(request, user.id, db)
|
||||
|
||||
|
||||
# ==================== Missing Endpoints (Issue #5) ====================
|
||||
|
||||
@router.get("/stories", response_model=list[StoryListItem])
|
||||
async def list_stories(
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List stories."""
|
||||
return await story_service.list_stories(user.id, limit, offset, db)
|
||||
|
||||
|
||||
@router.get("/stories/{story_id}", response_model=StoryDetailResponse)
|
||||
async def get_story(
|
||||
story_id: int,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get story detail."""
|
||||
return await story_service.get_story_detail(story_id, user.id, db)
|
||||
|
||||
|
||||
@router.delete("/stories/{story_id}")
|
||||
async def delete_story(
|
||||
story_id: int,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Delete story."""
|
||||
await story_service.delete_story(story_id, user.id, db)
|
||||
return {"message": "Deleted"}
|
||||
|
||||
|
||||
@router.post("/image/generate/{story_id}")
|
||||
"""Get story detail."""
|
||||
return await story_service.get_story_detail(story_id, user.id, db)
|
||||
|
||||
|
||||
@router.delete("/stories/{story_id}")
|
||||
async def delete_story(
|
||||
story_id: int,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Delete story."""
|
||||
await story_service.delete_story(story_id, user.id, db)
|
||||
return {"message": "Deleted"}
|
||||
|
||||
|
||||
@router.post("/image/generate/{story_id}", response_model=StoryImageResponse)
|
||||
async def generate_story_image(
|
||||
story_id: int,
|
||||
user: User = Depends(require_user),
|
||||
@@ -183,25 +233,32 @@ async def generate_story_image(
|
||||
):
|
||||
"""Generate cover image for story."""
|
||||
url = await story_service.generate_story_cover(story_id, user.id, db)
|
||||
return {"image_url": url}
|
||||
|
||||
|
||||
@router.get("/audio/{story_id}")
|
||||
async def get_story_audio(
|
||||
story_id: int,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get story audio (MP3)."""
|
||||
audio_bytes = await story_service.generate_story_audio(story_id, user.id, db)
|
||||
return Response(content=audio_bytes, media_type="audio/mpeg")
|
||||
|
||||
|
||||
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
|
||||
async def get_story_achievements(
|
||||
story_id: int,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get story achievements."""
|
||||
return await story_service.get_story_achievements(story_id, user.id, db)
|
||||
story = await story_service.get_story_detail(story_id, user.id, db)
|
||||
return {
|
||||
"image_url": url,
|
||||
"generation_status": story.generation_status,
|
||||
"image_status": story.image_status,
|
||||
"audio_status": story.audio_status,
|
||||
"last_error": story.last_error,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/audio/{story_id}")
|
||||
async def get_story_audio(
|
||||
story_id: int,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get story audio (MP3)."""
|
||||
audio_bytes = await story_service.generate_story_audio(story_id, user.id, db)
|
||||
return Response(content=audio_bytes, media_type="audio/mpeg")
|
||||
|
||||
|
||||
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
|
||||
async def get_story_achievements(
|
||||
story_id: int,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get story achievements."""
|
||||
return await story_service.get_story_achievements(story_id, user.id, db)
|
||||
|
||||
@@ -1,130 +1,134 @@
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用全局配置"""
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
|
||||
# 应用基础配置
|
||||
app_name: str = "DreamWeaver"
|
||||
debug: bool = False
|
||||
secret_key: str = Field(..., description="JWT 签名密钥")
|
||||
base_url: str = Field("http://localhost:8000", description="后端对外回调地址")
|
||||
|
||||
# 数据库
|
||||
database_url: str = Field(..., description="SQLAlchemy async URL")
|
||||
|
||||
# OAuth - GitHub
|
||||
github_client_id: str = ""
|
||||
github_client_secret: str = ""
|
||||
|
||||
# OAuth - Google
|
||||
google_client_id: str = ""
|
||||
google_client_secret: str = ""
|
||||
|
||||
# AI Capability Keys
|
||||
text_api_key: str = ""
|
||||
tts_api_base: str = ""
|
||||
tts_api_key: str = ""
|
||||
image_api_key: str = ""
|
||||
|
||||
# Additional Provider API Keys
|
||||
openai_api_key: str = ""
|
||||
elevenlabs_api_key: str = ""
|
||||
cqtai_api_key: str = ""
|
||||
minimax_api_key: str = ""
|
||||
minimax_group_id: str = ""
|
||||
antigravity_api_key: str = ""
|
||||
antigravity_api_base: str = ""
|
||||
|
||||
# AI Model Configuration
|
||||
text_model: str = "gemini-2.0-flash"
|
||||
openai_model: str = "gpt-4o-mini"
|
||||
tts_model: str = ""
|
||||
image_model: str = "nano-banana-pro"
|
||||
tts_minimax_model: str = "speech-2.6-turbo"
|
||||
tts_elevenlabs_model: str = "eleven_multilingual_v2"
|
||||
tts_edge_voice: str = "zh-CN-XiaoxiaoNeural"
|
||||
antigravity_model: str = "gemini-3-pro-image"
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用全局配置"""
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
|
||||
# 应用基础配置
|
||||
app_name: str = "DreamWeaver"
|
||||
debug: bool = False
|
||||
secret_key: str = Field(..., description="JWT 签名密钥")
|
||||
base_url: str = Field("http://localhost:8000", description="后端对外回调地址")
|
||||
|
||||
# 数据库
|
||||
database_url: str = Field(..., description="SQLAlchemy async URL")
|
||||
|
||||
# OAuth - GitHub
|
||||
github_client_id: str = ""
|
||||
github_client_secret: str = ""
|
||||
|
||||
# OAuth - Google
|
||||
google_client_id: str = ""
|
||||
google_client_secret: str = ""
|
||||
|
||||
# AI Capability Keys
|
||||
text_api_key: str = ""
|
||||
tts_api_base: str = ""
|
||||
tts_api_key: str = ""
|
||||
image_api_key: str = ""
|
||||
|
||||
# Additional Provider API Keys
|
||||
openai_api_key: str = ""
|
||||
elevenlabs_api_key: str = ""
|
||||
cqtai_api_key: str = ""
|
||||
minimax_api_key: str = ""
|
||||
minimax_group_id: str = ""
|
||||
antigravity_api_key: str = ""
|
||||
antigravity_api_base: str = ""
|
||||
|
||||
# AI Model Configuration
|
||||
text_model: str = "gemini-2.0-flash"
|
||||
openai_model: str = "gpt-4o-mini"
|
||||
tts_model: str = ""
|
||||
image_model: str = "nano-banana-pro"
|
||||
tts_minimax_model: str = "speech-2.6-turbo"
|
||||
tts_elevenlabs_model: str = "eleven_multilingual_v2"
|
||||
tts_edge_voice: str = "zh-CN-XiaoxiaoNeural"
|
||||
antigravity_model: str = "gemini-3-pro-image"
|
||||
|
||||
# Provider routing (ordered lists)
|
||||
text_providers: list[str] = Field(default_factory=lambda: ["gemini"])
|
||||
image_providers: list[str] = Field(default_factory=lambda: ["cqtai"])
|
||||
tts_providers: list[str] = Field(default_factory=lambda: ["minimax", "elevenlabs", "edge_tts"])
|
||||
story_audio_cache_dir: str = Field(
|
||||
"storage/audio",
|
||||
description="Directory for cached story audio files",
|
||||
)
|
||||
|
||||
# Celery (Redis)
|
||||
celery_broker_url: str = Field("redis://localhost:6379/0")
|
||||
celery_result_backend: str = Field("redis://localhost:6379/0")
|
||||
|
||||
# Generic Redis
|
||||
redis_url: str = Field("redis://localhost:6379/0", description="Redis connection URL")
|
||||
redis_sentinel_enabled: bool = Field(False, description="Whether to enable Redis Sentinel")
|
||||
redis_sentinel_nodes: str = Field(
|
||||
"",
|
||||
description="Comma-separated Redis Sentinel nodes, e.g. host1:26379,host2:26379",
|
||||
)
|
||||
redis_sentinel_master_name: str = Field("mymaster", description="Redis Sentinel master name")
|
||||
redis_sentinel_password: str = Field("", description="Password for Redis Sentinel (optional)")
|
||||
redis_sentinel_db: int = Field(0, description="Redis DB index when using Sentinel")
|
||||
redis_sentinel_socket_timeout: float = Field(
|
||||
0.5,
|
||||
description="Socket timeout in seconds for Sentinel clients",
|
||||
)
|
||||
|
||||
# Admin console
|
||||
enable_admin_console: bool = False
|
||||
admin_username: str = "admin"
|
||||
admin_password: str = "admin123" # 建议通过环境变量覆盖
|
||||
|
||||
# CORS
|
||||
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173"])
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _require_core_settings(self) -> "Settings": # type: ignore[override]
|
||||
missing = []
|
||||
if not self.secret_key or self.secret_key == "change-me-in-production":
|
||||
missing.append("SECRET_KEY")
|
||||
if not self.database_url:
|
||||
missing.append("DATABASE_URL")
|
||||
if self.redis_sentinel_enabled and not self.redis_sentinel_nodes.strip():
|
||||
missing.append("REDIS_SENTINEL_NODES")
|
||||
if missing:
|
||||
raise ValueError(f"Missing required settings: {', '.join(missing)}")
|
||||
return self
|
||||
|
||||
@property
|
||||
def redis_sentinel_hosts(self) -> list[tuple[str, int]]:
|
||||
"""Parse Redis Sentinel nodes into (host, port) tuples."""
|
||||
nodes = []
|
||||
raw = self.redis_sentinel_nodes.strip()
|
||||
if not raw:
|
||||
return nodes
|
||||
|
||||
for item in raw.split(","):
|
||||
value = item.strip()
|
||||
if not value:
|
||||
continue
|
||||
if ":" not in value:
|
||||
raise ValueError(f"Invalid sentinel node format: {value}")
|
||||
host, port_text = value.rsplit(":", 1)
|
||||
if not host:
|
||||
raise ValueError(f"Invalid sentinel node host: {value}")
|
||||
try:
|
||||
port = int(port_text)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Invalid sentinel node port: {value}") from exc
|
||||
nodes.append((host, port))
|
||||
return nodes
|
||||
|
||||
@property
|
||||
def redis_sentinel_urls(self) -> list[str]:
|
||||
"""Build Celery-compatible Sentinel URLs with DB index."""
|
||||
return [
|
||||
f"sentinel://{host}:{port}/{self.redis_sentinel_db}"
|
||||
for host, port in self.redis_sentinel_hosts
|
||||
]
|
||||
|
||||
|
||||
settings = Settings()
|
||||
celery_result_backend: str = Field("redis://localhost:6379/0")
|
||||
|
||||
# Generic Redis
|
||||
redis_url: str = Field("redis://localhost:6379/0", description="Redis connection URL")
|
||||
redis_sentinel_enabled: bool = Field(False, description="Whether to enable Redis Sentinel")
|
||||
redis_sentinel_nodes: str = Field(
|
||||
"",
|
||||
description="Comma-separated Redis Sentinel nodes, e.g. host1:26379,host2:26379",
|
||||
)
|
||||
redis_sentinel_master_name: str = Field("mymaster", description="Redis Sentinel master name")
|
||||
redis_sentinel_password: str = Field("", description="Password for Redis Sentinel (optional)")
|
||||
redis_sentinel_db: int = Field(0, description="Redis DB index when using Sentinel")
|
||||
redis_sentinel_socket_timeout: float = Field(
|
||||
0.5,
|
||||
description="Socket timeout in seconds for Sentinel clients",
|
||||
)
|
||||
|
||||
# Admin console
|
||||
enable_admin_console: bool = False
|
||||
admin_username: str = "admin"
|
||||
admin_password: str = "admin123" # 建议通过环境变量覆盖
|
||||
|
||||
# CORS
|
||||
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173"])
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _require_core_settings(self) -> "Settings": # type: ignore[override]
|
||||
missing = []
|
||||
if not self.secret_key or self.secret_key == "change-me-in-production":
|
||||
missing.append("SECRET_KEY")
|
||||
if not self.database_url:
|
||||
missing.append("DATABASE_URL")
|
||||
if self.redis_sentinel_enabled and not self.redis_sentinel_nodes.strip():
|
||||
missing.append("REDIS_SENTINEL_NODES")
|
||||
if missing:
|
||||
raise ValueError(f"Missing required settings: {', '.join(missing)}")
|
||||
return self
|
||||
|
||||
@property
|
||||
def redis_sentinel_hosts(self) -> list[tuple[str, int]]:
|
||||
"""Parse Redis Sentinel nodes into (host, port) tuples."""
|
||||
nodes = []
|
||||
raw = self.redis_sentinel_nodes.strip()
|
||||
if not raw:
|
||||
return nodes
|
||||
|
||||
for item in raw.split(","):
|
||||
value = item.strip()
|
||||
if not value:
|
||||
continue
|
||||
if ":" not in value:
|
||||
raise ValueError(f"Invalid sentinel node format: {value}")
|
||||
host, port_text = value.rsplit(":", 1)
|
||||
if not host:
|
||||
raise ValueError(f"Invalid sentinel node host: {value}")
|
||||
try:
|
||||
port = int(port_text)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Invalid sentinel node port: {value}") from exc
|
||||
nodes.append((host, port))
|
||||
return nodes
|
||||
|
||||
@property
|
||||
def redis_sentinel_urls(self) -> list[str]:
|
||||
"""Build Celery-compatible Sentinel URLs with DB index."""
|
||||
return [
|
||||
f"sentinel://{host}:{port}/{self.redis_sentinel_db}"
|
||||
for host, port in self.redis_sentinel_hosts
|
||||
]
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -27,10 +27,10 @@ class User(Base):
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(255), primary_key=True) # OAuth provider user ID
|
||||
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
avatar_url: Mapped[str | None] = mapped_column(String(500))
|
||||
provider: Mapped[str] = mapped_column(String(50), nullable=False) # github / google
|
||||
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
@@ -59,11 +59,22 @@ class Story(Base):
|
||||
String(36), ForeignKey("story_universes.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
story_text: Mapped[str] = mapped_column(Text, nullable=True) # 允许为空(绘本模式下)
|
||||
pages: Mapped[list[dict] | None] = mapped_column(JSON, default=list) # 绘本分页数据
|
||||
story_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
pages: Mapped[list[dict] | None] = mapped_column(JSON, default=list)
|
||||
cover_prompt: Mapped[str | None] = mapped_column(Text)
|
||||
image_url: Mapped[str | None] = mapped_column(String(500))
|
||||
mode: Mapped[str] = mapped_column(String(20), nullable=False, default="generated")
|
||||
generation_status: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, default="narrative_ready"
|
||||
)
|
||||
image_status: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, default="not_requested"
|
||||
)
|
||||
audio_status: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, default="not_requested"
|
||||
)
|
||||
audio_path: Mapped[str | None] = mapped_column(String(500))
|
||||
last_error: Mapped[str | None] = mapped_column(Text)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
@@ -123,6 +134,7 @@ class ChildProfile(Base):
|
||||
|
||||
class StoryUniverse(Base):
|
||||
"""Story universe entity."""
|
||||
|
||||
__tablename__ = "story_universes"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
@@ -142,7 +154,9 @@ class StoryUniverse(Base):
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
child_profile: Mapped["ChildProfile"] = relationship("ChildProfile", back_populates="story_universes")
|
||||
child_profile: Mapped["ChildProfile"] = relationship(
|
||||
"ChildProfile", back_populates="story_universes"
|
||||
)
|
||||
|
||||
|
||||
class ReadingEvent(Base):
|
||||
@@ -163,6 +177,7 @@ class ReadingEvent(Base):
|
||||
DateTime(timezone=True), server_default=func.now(), index=True
|
||||
)
|
||||
|
||||
|
||||
class PushConfig(Base):
|
||||
"""Push configuration entity."""
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""故事相关 Pydantic 模型。"""
|
||||
"""Story-related Pydantic schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
@@ -11,7 +11,13 @@ MAX_EDU_THEME_LENGTH = 200
|
||||
MAX_TTS_LENGTH = 4000
|
||||
|
||||
|
||||
# ==================== 故事模型 ====================
|
||||
class StoryStatusMixin(BaseModel):
|
||||
"""Shared generation status fields returned by story APIs."""
|
||||
|
||||
generation_status: str
|
||||
image_status: str
|
||||
audio_status: str
|
||||
last_error: str | None = None
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
@@ -24,8 +30,8 @@ class GenerateRequest(BaseModel):
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class StoryResponse(BaseModel):
|
||||
"""Story response."""
|
||||
class StoryResponse(StoryStatusMixin):
|
||||
"""Story generation response."""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
@@ -37,7 +43,7 @@ class StoryResponse(BaseModel):
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class StoryListItem(BaseModel):
|
||||
class StoryListItem(StoryStatusMixin):
|
||||
"""Story list item."""
|
||||
|
||||
id: int
|
||||
@@ -47,8 +53,8 @@ class StoryListItem(BaseModel):
|
||||
mode: str
|
||||
|
||||
|
||||
class FullStoryResponse(BaseModel):
|
||||
"""完整故事响应(含图片和音频状态)。"""
|
||||
class FullStoryResponse(StoryStatusMixin):
|
||||
"""Full story response with asset status."""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
@@ -62,22 +68,19 @@ class FullStoryResponse(BaseModel):
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
# ==================== 绘本模型 ====================
|
||||
|
||||
|
||||
class StorybookRequest(BaseModel):
|
||||
"""Storybook 生成请求。"""
|
||||
"""Storybook generation request."""
|
||||
|
||||
keywords: str = Field(..., min_length=1, max_length=200)
|
||||
page_count: int = Field(default=6, ge=4, le=12)
|
||||
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
|
||||
generate_images: bool = Field(default=False, description="是否同时生成插图")
|
||||
generate_images: bool = Field(default=False, description="Whether to generate images too.")
|
||||
child_profile_id: str | None = None
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class StorybookPageResponse(BaseModel):
|
||||
"""故事书单页响应。"""
|
||||
"""One storybook page."""
|
||||
|
||||
page_number: int
|
||||
text: str
|
||||
@@ -85,8 +88,8 @@ class StorybookPageResponse(BaseModel):
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
class StorybookResponse(BaseModel):
|
||||
"""故事书响应。"""
|
||||
class StorybookResponse(StoryStatusMixin):
|
||||
"""Storybook generation response."""
|
||||
|
||||
id: int | None = None
|
||||
title: str
|
||||
@@ -97,10 +100,29 @@ class StorybookResponse(BaseModel):
|
||||
cover_url: str | None = None
|
||||
|
||||
|
||||
# ==================== 成就模型 ====================
|
||||
class StoryDetailResponse(StoryStatusMixin):
|
||||
"""Story detail response for both stories and storybooks."""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
story_text: str | None = None
|
||||
pages: list[StorybookPageResponse] | None = None
|
||||
cover_prompt: str | None
|
||||
image_url: str | None
|
||||
mode: str
|
||||
child_profile_id: str | None = None
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class StoryImageResponse(StoryStatusMixin):
|
||||
"""Cover image generation response."""
|
||||
|
||||
image_url: str | None
|
||||
|
||||
|
||||
class AchievementItem(BaseModel):
|
||||
"""Achievement item returned for a story."""
|
||||
|
||||
type: str
|
||||
description: str
|
||||
obtained_at: str | None = None
|
||||
|
||||
38
backend/app/services/audio_storage.py
Normal file
38
backend/app/services/audio_storage.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Story audio cache storage helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def build_story_audio_path(story_id: int) -> str:
|
||||
"""Build the cache path for a story audio file."""
|
||||
|
||||
return str(Path(settings.story_audio_cache_dir) / f"story-{story_id}.mp3")
|
||||
|
||||
|
||||
def audio_cache_exists(audio_path: str | None) -> bool:
|
||||
"""Whether the cached audio file exists on disk."""
|
||||
|
||||
return bool(audio_path) and Path(audio_path).is_file()
|
||||
|
||||
|
||||
def read_audio_cache(audio_path: str) -> bytes:
|
||||
"""Read cached story audio bytes."""
|
||||
|
||||
return Path(audio_path).read_bytes()
|
||||
|
||||
|
||||
def write_story_audio_cache(story_id: int, audio_data: bytes) -> str:
|
||||
"""Persist story audio and return the saved file path."""
|
||||
|
||||
final_path = Path(build_story_audio_path(story_id))
|
||||
final_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
temp_path = final_path.with_suffix(".tmp")
|
||||
temp_path.write_bytes(audio_data)
|
||||
temp_path.replace(final_path)
|
||||
|
||||
return str(final_path)
|
||||
@@ -1,12 +1,9 @@
|
||||
"""Story business logic service."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select, desc
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
@@ -15,90 +12,151 @@ from app.db.models import ChildProfile, Story, StoryUniverse
|
||||
from app.schemas.story_schemas import (
|
||||
GenerateRequest,
|
||||
StorybookRequest,
|
||||
FullStoryResponse,
|
||||
StorybookResponse,
|
||||
FullStoryResponse,
|
||||
StorybookResponse,
|
||||
StorybookPageResponse,
|
||||
AchievementItem,
|
||||
)
|
||||
from app.services.audio_storage import (
|
||||
audio_cache_exists,
|
||||
read_audio_cache,
|
||||
write_story_audio_cache,
|
||||
)
|
||||
from app.services.memory_service import build_enhanced_memory_context
|
||||
from app.services.provider_router import (
|
||||
generate_story_content,
|
||||
generate_image,
|
||||
generate_storybook,
|
||||
)
|
||||
from app.services.story_status import (
|
||||
StoryAssetStatus,
|
||||
sync_story_status,
|
||||
)
|
||||
from app.tasks.achievements import extract_story_achievements
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def validate_profile_and_universe(
|
||||
profile_id: str | None,
|
||||
universe_id: str | None,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Validate child profile and universe ownership/relationship."""
|
||||
if not profile_id and not universe_id:
|
||||
return None, None
|
||||
def _build_storybook_error_message(
|
||||
*,
|
||||
cover_failed: bool,
|
||||
failed_pages: list[int],
|
||||
) -> str | None:
|
||||
"""Summarize storybook image generation errors for the latest attempt."""
|
||||
|
||||
if profile_id:
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user_id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
||||
|
||||
if universe_id:
|
||||
result = await db.execute(
|
||||
select(StoryUniverse)
|
||||
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
|
||||
.where(
|
||||
StoryUniverse.id == universe_id,
|
||||
ChildProfile.user_id == user_id,
|
||||
)
|
||||
)
|
||||
universe = result.scalar_one_or_none()
|
||||
if not universe:
|
||||
raise HTTPException(status_code=404, detail="故事宇宙不存在")
|
||||
if profile_id and universe.child_profile_id != profile_id:
|
||||
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
|
||||
if not profile_id:
|
||||
profile_id = universe.child_profile_id
|
||||
|
||||
return profile_id, universe_id
|
||||
parts: list[str] = []
|
||||
if cover_failed:
|
||||
parts.append("封面生成失败")
|
||||
if failed_pages:
|
||||
pages = "、".join(str(page) for page in sorted(failed_pages))
|
||||
parts.append(f"第 {pages} 页插图生成失败")
|
||||
return ";".join(parts) if parts else None
|
||||
|
||||
|
||||
def _resolve_storybook_image_status(
|
||||
*,
|
||||
generate_images: bool,
|
||||
cover_prompt: str | None,
|
||||
cover_url: str | None,
|
||||
pages_data: list[dict],
|
||||
) -> StoryAssetStatus:
|
||||
"""Resolve the persisted image status for a storybook."""
|
||||
|
||||
if not generate_images:
|
||||
return StoryAssetStatus.NOT_REQUESTED
|
||||
|
||||
expected_assets = 0
|
||||
ready_assets = 0
|
||||
|
||||
if cover_prompt or cover_url:
|
||||
expected_assets += 1
|
||||
if cover_url:
|
||||
ready_assets += 1
|
||||
|
||||
for page in pages_data:
|
||||
if not page.get("image_prompt") and not page.get("image_url"):
|
||||
continue
|
||||
expected_assets += 1
|
||||
if page.get("image_url"):
|
||||
ready_assets += 1
|
||||
|
||||
if expected_assets == 0:
|
||||
return StoryAssetStatus.NOT_REQUESTED
|
||||
|
||||
if ready_assets == expected_assets:
|
||||
return StoryAssetStatus.READY
|
||||
|
||||
return StoryAssetStatus.FAILED
|
||||
|
||||
|
||||
async def validate_profile_and_universe(
|
||||
profile_id: str | None,
|
||||
universe_id: str | None,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Validate child profile and universe ownership/relationship."""
|
||||
if not profile_id and not universe_id:
|
||||
return None, None
|
||||
|
||||
if profile_id:
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user_id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
||||
|
||||
if universe_id:
|
||||
result = await db.execute(
|
||||
select(StoryUniverse)
|
||||
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
|
||||
.where(
|
||||
StoryUniverse.id == universe_id,
|
||||
ChildProfile.user_id == user_id,
|
||||
)
|
||||
)
|
||||
universe = result.scalar_one_or_none()
|
||||
if not universe:
|
||||
raise HTTPException(status_code=404, detail="故事宇宙不存在")
|
||||
if profile_id and universe.child_profile_id != profile_id:
|
||||
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
|
||||
if not profile_id:
|
||||
profile_id = universe.child_profile_id
|
||||
|
||||
return profile_id, universe_id
|
||||
|
||||
|
||||
async def generate_and_save_story(
|
||||
request: GenerateRequest,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> Story:
|
||||
"""Generate generic story content and save to DB."""
|
||||
# 1. Validate
|
||||
profile_id, universe_id = await validate_profile_and_universe(
|
||||
request.child_profile_id, request.universe_id, user_id, db
|
||||
)
|
||||
|
||||
# 2. Build Context
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
# 3. Generate
|
||||
try:
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc
|
||||
|
||||
# 4. Save
|
||||
"""Generate generic story content and save to DB."""
|
||||
# 1. Validate
|
||||
profile_id, universe_id = await validate_profile_and_universe(
|
||||
request.child_profile_id, request.universe_id, user_id, db
|
||||
)
|
||||
|
||||
# 2. Build Context
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
# 3. Generate
|
||||
try:
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc
|
||||
|
||||
# 4. Save
|
||||
story = Story(
|
||||
user_id=user_id,
|
||||
child_profile_id=profile_id,
|
||||
@@ -108,170 +166,209 @@ async def generate_and_save_story(
|
||||
cover_prompt=result.cover_prompt_suggestion,
|
||||
mode=result.mode,
|
||||
)
|
||||
sync_story_status(
|
||||
story,
|
||||
image_status=StoryAssetStatus.NOT_REQUESTED,
|
||||
audio_status=StoryAssetStatus.NOT_REQUESTED,
|
||||
last_error=None,
|
||||
)
|
||||
db.add(story)
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
|
||||
# 5. Trigger Async Tasks
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
return story
|
||||
|
||||
|
||||
|
||||
# 5. Trigger Async Tasks
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
return story
|
||||
|
||||
|
||||
async def generate_full_story_service(
|
||||
request: GenerateRequest,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> FullStoryResponse:
|
||||
"""Generate story with parallel image generation."""
|
||||
# 1. Generate text part
|
||||
# We can reuse logic or call generate_story_content directly if we want finer control
|
||||
# reusing generate_and_save_story to ensure consistency (it handles validation + saving)
|
||||
story = await generate_and_save_story(request, user_id, db)
|
||||
|
||||
# 2. Generate Image (Parallel/Async step in this flow)
|
||||
"""Generate story with parallel image generation."""
|
||||
# 1. Generate text part
|
||||
# We can reuse logic or call generate_story_content directly if we want finer control
|
||||
# reusing generate_and_save_story to ensure consistency (it handles validation + saving)
|
||||
story = await generate_and_save_story(request, user_id, db)
|
||||
|
||||
# 2. Generate Image (Parallel/Async step in this flow)
|
||||
image_url: str | None = None
|
||||
errors: dict[str, str | None] = {}
|
||||
|
||||
if story.cover_prompt:
|
||||
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
try:
|
||||
image_url = await generate_image(story.cover_prompt, db=db)
|
||||
story.image_url = image_url
|
||||
sync_story_status(
|
||||
story,
|
||||
image_status=StoryAssetStatus.READY,
|
||||
)
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
errors["image"] = str(exc)
|
||||
sync_story_status(
|
||||
story,
|
||||
image_status=StoryAssetStatus.FAILED,
|
||||
last_error=str(exc),
|
||||
)
|
||||
await db.commit()
|
||||
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
|
||||
|
||||
return FullStoryResponse(
|
||||
id=story.id,
|
||||
title=story.title,
|
||||
story_text=story.story_text,
|
||||
cover_prompt=story.cover_prompt,
|
||||
image_url=image_url,
|
||||
story_text=story.story_text,
|
||||
cover_prompt=story.cover_prompt,
|
||||
image_url=image_url,
|
||||
audio_ready=False,
|
||||
mode=story.mode,
|
||||
errors=errors,
|
||||
child_profile_id=story.child_profile_id,
|
||||
universe_id=story.universe_id,
|
||||
generation_status=story.generation_status,
|
||||
image_status=story.image_status,
|
||||
audio_status=story.audio_status,
|
||||
last_error=story.last_error,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
async def generate_storybook_service(
|
||||
request: StorybookRequest,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> StorybookResponse:
|
||||
"""Generate storybook with parallel image generation for pages."""
|
||||
# 1. Validate
|
||||
profile_id, universe_id = await validate_profile_and_universe(
|
||||
request.child_profile_id, request.universe_id, user_id, db
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"storybook_request",
|
||||
user_id=user_id,
|
||||
keywords=request.keywords,
|
||||
page_count=request.page_count,
|
||||
profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
)
|
||||
|
||||
# 2. Context
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
# 3. Generate Text Structure
|
||||
try:
|
||||
storybook = await generate_storybook(
|
||||
keywords=request.keywords,
|
||||
page_count=request.page_count,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("storybook_generation_failed", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
|
||||
|
||||
"""Generate storybook with parallel image generation for pages."""
|
||||
# 1. Validate
|
||||
profile_id, universe_id = await validate_profile_and_universe(
|
||||
request.child_profile_id, request.universe_id, user_id, db
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"storybook_request",
|
||||
user_id=user_id,
|
||||
keywords=request.keywords,
|
||||
page_count=request.page_count,
|
||||
profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
)
|
||||
|
||||
# 2. Context
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
# 3. Generate Text Structure
|
||||
try:
|
||||
storybook = await generate_storybook(
|
||||
keywords=request.keywords,
|
||||
page_count=request.page_count,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("storybook_generation_failed", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
|
||||
|
||||
# 4. Parallel Image Generation
|
||||
final_cover_url = storybook.cover_url
|
||||
cover_failed = False
|
||||
failed_pages: list[int] = []
|
||||
|
||||
if request.generate_images:
|
||||
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
|
||||
|
||||
|
||||
tasks = []
|
||||
|
||||
# Cover Task
|
||||
async def _gen_cover():
|
||||
nonlocal cover_failed
|
||||
|
||||
if storybook.cover_prompt and not storybook.cover_url:
|
||||
try:
|
||||
return await generate_image(storybook.cover_prompt, db=db)
|
||||
except Exception as e:
|
||||
logger.warning("cover_gen_failed", error=str(e))
|
||||
except Exception as exc:
|
||||
cover_failed = True
|
||||
logger.warning("cover_gen_failed", error=str(exc))
|
||||
return storybook.cover_url
|
||||
|
||||
tasks.append(_gen_cover())
|
||||
|
||||
# Page Tasks
|
||||
async def _gen_page(page):
|
||||
if page.image_prompt and not page.image_url:
|
||||
try:
|
||||
url = await generate_image(page.image_prompt, db=db)
|
||||
page.image_url = url
|
||||
except Exception as e:
|
||||
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
|
||||
page.image_url = await generate_image(page.image_prompt, db=db)
|
||||
except Exception as exc:
|
||||
failed_pages.append(page.page_number)
|
||||
logger.warning("page_gen_failed", page=page.page_number, error=str(exc))
|
||||
|
||||
for page in storybook.pages:
|
||||
tasks.append(_gen_page(page))
|
||||
|
||||
# Execute
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Update cover result
|
||||
|
||||
cover_res = results[0]
|
||||
if isinstance(cover_res, str):
|
||||
final_cover_url = cover_res
|
||||
|
||||
logger.info("storybook_parallel_generation_complete")
|
||||
|
||||
# 5. Save to DB
|
||||
pages_data = [
|
||||
{
|
||||
"page_number": p.page_number,
|
||||
"text": p.text,
|
||||
"image_prompt": p.image_prompt,
|
||||
"image_url": p.image_url,
|
||||
}
|
||||
for p in storybook.pages
|
||||
]
|
||||
|
||||
|
||||
# 5. Save to DB
|
||||
pages_data = [
|
||||
{
|
||||
"page_number": p.page_number,
|
||||
"text": p.text,
|
||||
"image_prompt": p.image_prompt,
|
||||
"image_url": p.image_url,
|
||||
}
|
||||
for p in storybook.pages
|
||||
]
|
||||
|
||||
story = Story(
|
||||
user_id=user_id,
|
||||
child_profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
title=storybook.title,
|
||||
mode="storybook",
|
||||
pages=pages_data,
|
||||
mode="storybook",
|
||||
pages=pages_data,
|
||||
story_text=None,
|
||||
cover_prompt=storybook.cover_prompt,
|
||||
image_url=final_cover_url,
|
||||
)
|
||||
sync_story_status(
|
||||
story,
|
||||
image_status=_resolve_storybook_image_status(
|
||||
generate_images=request.generate_images,
|
||||
cover_prompt=storybook.cover_prompt,
|
||||
cover_url=final_cover_url,
|
||||
pages_data=pages_data,
|
||||
),
|
||||
audio_status=StoryAssetStatus.NOT_REQUESTED,
|
||||
last_error=_build_storybook_error_message(
|
||||
cover_failed=cover_failed,
|
||||
failed_pages=failed_pages,
|
||||
),
|
||||
)
|
||||
db.add(story)
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
# 6. Build Response
|
||||
response_pages = [
|
||||
StorybookPageResponse(
|
||||
page_number=p["page_number"],
|
||||
text=p["text"],
|
||||
image_prompt=p["image_prompt"],
|
||||
image_url=p.get("image_url"),
|
||||
)
|
||||
for p in pages_data
|
||||
]
|
||||
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
# 6. Build Response
|
||||
response_pages = [
|
||||
StorybookPageResponse(
|
||||
page_number=p["page_number"],
|
||||
text=p["text"],
|
||||
image_prompt=p["image_prompt"],
|
||||
image_url=p.get("image_url"),
|
||||
)
|
||||
for p in pages_data
|
||||
]
|
||||
|
||||
return StorybookResponse(
|
||||
id=story.id,
|
||||
title=storybook.title,
|
||||
@@ -280,155 +377,209 @@ async def generate_storybook_service(
|
||||
pages=response_pages,
|
||||
cover_prompt=storybook.cover_prompt,
|
||||
cover_url=final_cover_url,
|
||||
generation_status=story.generation_status,
|
||||
image_status=story.image_status,
|
||||
audio_status=story.audio_status,
|
||||
last_error=story.last_error,
|
||||
)
|
||||
|
||||
|
||||
# ==================== Missing Endpoints Logic (for Issue #5) ====================
|
||||
|
||||
async def list_stories(
|
||||
user_id: str,
|
||||
limit: int,
|
||||
offset: int,
|
||||
db: AsyncSession,
|
||||
) -> list[Story]:
|
||||
"""List stories for user."""
|
||||
result = await db.execute(
|
||||
select(Story)
|
||||
.where(Story.user_id == user_id)
|
||||
.order_by(desc(Story.created_at))
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
async def get_story_detail(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> Story:
|
||||
"""Get story detail."""
|
||||
result = await db.execute(
|
||||
select(Story).where(Story.id == story_id, Story.user_id == user_id)
|
||||
)
|
||||
story = result.scalar_one_or_none()
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
return story
|
||||
|
||||
|
||||
async def delete_story(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Delete a story."""
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
await db.delete(story)
|
||||
await db.commit()
|
||||
|
||||
|
||||
|
||||
|
||||
# ==================== Missing Endpoints Logic (for Issue #5) ====================
|
||||
|
||||
async def list_stories(
|
||||
user_id: str,
|
||||
limit: int,
|
||||
offset: int,
|
||||
db: AsyncSession,
|
||||
) -> list[Story]:
|
||||
"""List stories for user."""
|
||||
result = await db.execute(
|
||||
select(Story)
|
||||
.where(Story.user_id == user_id)
|
||||
.order_by(desc(Story.created_at))
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
async def get_story_detail(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> Story:
|
||||
"""Get story detail."""
|
||||
result = await db.execute(
|
||||
select(Story).where(Story.id == story_id, Story.user_id == user_id)
|
||||
)
|
||||
story = result.scalar_one_or_none()
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
return story
|
||||
|
||||
|
||||
async def delete_story(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Delete a story."""
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
await db.delete(story)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def create_story_from_result(
|
||||
result, # StoryOutput
|
||||
user_id: str,
|
||||
profile_id: str | None,
|
||||
universe_id: str | None,
|
||||
db: AsyncSession,
|
||||
) -> Story:
|
||||
"""Save a generated story to DB (helper for stream endpoint)."""
|
||||
story = Story(
|
||||
user_id=user_id,
|
||||
child_profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
db: AsyncSession,
|
||||
) -> Story:
|
||||
"""Save a generated story to DB (helper for stream endpoint)."""
|
||||
story = Story(
|
||||
user_id=user_id,
|
||||
child_profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
title=result.title,
|
||||
story_text=result.story_text,
|
||||
cover_prompt=result.cover_prompt_suggestion,
|
||||
mode=result.mode,
|
||||
)
|
||||
sync_story_status(
|
||||
story,
|
||||
image_status=StoryAssetStatus.NOT_REQUESTED,
|
||||
audio_status=StoryAssetStatus.NOT_REQUESTED,
|
||||
last_error=None,
|
||||
)
|
||||
db.add(story)
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
return story
|
||||
|
||||
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
return story
|
||||
|
||||
|
||||
async def generate_story_cover(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> str:
|
||||
"""Generate cover image for an existing story."""
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
|
||||
"""Generate cover image for an existing story."""
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
|
||||
if not story.cover_prompt:
|
||||
raise HTTPException(status_code=400, detail="Story has no cover prompt")
|
||||
|
||||
|
||||
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
|
||||
try:
|
||||
image_url = await generate_image(story.cover_prompt, db=db)
|
||||
story.image_url = image_url
|
||||
sync_story_status(
|
||||
story,
|
||||
image_status=StoryAssetStatus.READY,
|
||||
)
|
||||
await db.commit()
|
||||
return image_url
|
||||
except Exception as e:
|
||||
sync_story_status(
|
||||
story,
|
||||
image_status=StoryAssetStatus.FAILED,
|
||||
last_error=str(e),
|
||||
)
|
||||
await db.commit()
|
||||
logger.error("cover_generation_failed", story_id=story_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
async def generate_story_audio(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> bytes:
|
||||
"""Generate audio for a story."""
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
|
||||
"""Generate audio for a story."""
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
|
||||
if not story.story_text:
|
||||
raise HTTPException(status_code=400, detail="Story has no text")
|
||||
|
||||
# TODO: Check if audio is already cached/saved?
|
||||
# For now, generate on the fly via provider
|
||||
if story.audio_path and audio_cache_exists(story.audio_path):
|
||||
if story.audio_status != StoryAssetStatus.READY.value:
|
||||
sync_story_status(story, audio_status=StoryAssetStatus.READY)
|
||||
await db.commit()
|
||||
return read_audio_cache(story.audio_path)
|
||||
|
||||
if story.audio_path and not audio_cache_exists(story.audio_path):
|
||||
logger.warning(
|
||||
"story_audio_cache_missing",
|
||||
story_id=story_id,
|
||||
audio_path=story.audio_path,
|
||||
)
|
||||
story.audio_path = None
|
||||
if story.audio_status == StoryAssetStatus.READY.value:
|
||||
sync_story_status(story, audio_status=StoryAssetStatus.NOT_REQUESTED)
|
||||
await db.commit()
|
||||
|
||||
from app.services.provider_router import text_to_speech
|
||||
|
||||
|
||||
sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
|
||||
try:
|
||||
audio_data = await text_to_speech(story.story_text, db=db)
|
||||
story.audio_path = write_story_audio_cache(story.id, audio_data)
|
||||
sync_story_status(
|
||||
story,
|
||||
audio_status=StoryAssetStatus.READY,
|
||||
)
|
||||
await db.commit()
|
||||
return audio_data
|
||||
except Exception as e:
|
||||
story.audio_path = None
|
||||
sync_story_status(
|
||||
story,
|
||||
audio_status=StoryAssetStatus.FAILED,
|
||||
last_error=str(e),
|
||||
)
|
||||
await db.commit()
|
||||
logger.error("audio_generation_failed", story_id=story_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
|
||||
|
||||
|
||||
async def get_story_achievements(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> list[AchievementItem]:
|
||||
"""Get achievements unlocked by a specific story."""
|
||||
result = await db.execute(
|
||||
select(Story)
|
||||
.options(joinedload(Story.story_universe))
|
||||
.where(Story.id == story_id, Story.user_id == user_id)
|
||||
)
|
||||
story = result.scalar_one_or_none()
|
||||
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
|
||||
if not story.universe_id or not story.story_universe:
|
||||
return []
|
||||
|
||||
universe = story.story_universe
|
||||
if not universe.achievements:
|
||||
return []
|
||||
|
||||
results = []
|
||||
for ach in universe.achievements:
|
||||
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
|
||||
results.append(AchievementItem(
|
||||
type=ach.get("type", "Unknown"),
|
||||
description=ach.get("description", ""),
|
||||
obtained_at=ach.get("obtained_at")
|
||||
))
|
||||
return results
|
||||
|
||||
|
||||
|
||||
async def get_story_achievements(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> list[AchievementItem]:
|
||||
"""Get achievements unlocked by a specific story."""
|
||||
result = await db.execute(
|
||||
select(Story)
|
||||
.options(joinedload(Story.story_universe))
|
||||
.where(Story.id == story_id, Story.user_id == user_id)
|
||||
)
|
||||
story = result.scalar_one_or_none()
|
||||
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
|
||||
if not story.universe_id or not story.story_universe:
|
||||
return []
|
||||
|
||||
universe = story.story_universe
|
||||
if not universe.achievements:
|
||||
return []
|
||||
|
||||
results = []
|
||||
for ach in universe.achievements:
|
||||
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
|
||||
results.append(AchievementItem(
|
||||
type=ach.get("type", "Unknown"),
|
||||
description=ach.get("description", ""),
|
||||
obtained_at=ach.get("obtained_at")
|
||||
))
|
||||
return results
|
||||
|
||||
|
||||
112
backend/app/services/story_status.py
Normal file
112
backend/app/services/story_status.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Story generation status helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class StoryGenerationStatus(str, Enum):
|
||||
"""Overall story generation lifecycle."""
|
||||
|
||||
NARRATIVE_READY = "narrative_ready"
|
||||
ASSETS_GENERATING = "assets_generating"
|
||||
COMPLETED = "completed"
|
||||
DEGRADED_COMPLETED = "degraded_completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class StoryAssetStatus(str, Enum):
|
||||
"""Asset generation state for image and audio."""
|
||||
|
||||
NOT_REQUESTED = "not_requested"
|
||||
GENERATING = "generating"
|
||||
READY = "ready"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class StoryLike(Protocol):
|
||||
"""Protocol for story-like objects used by status helpers."""
|
||||
|
||||
story_text: str | None
|
||||
pages: list[dict] | None
|
||||
generation_status: str
|
||||
image_status: str
|
||||
audio_status: str
|
||||
last_error: str | None
|
||||
|
||||
|
||||
_ERROR_UNSET = object()
|
||||
|
||||
|
||||
def _normalize_asset_status(value: str | None) -> StoryAssetStatus:
|
||||
if not value:
|
||||
return StoryAssetStatus.NOT_REQUESTED
|
||||
|
||||
try:
|
||||
return StoryAssetStatus(value)
|
||||
except ValueError:
|
||||
return StoryAssetStatus.NOT_REQUESTED
|
||||
|
||||
|
||||
def has_narrative_content(story: StoryLike) -> bool:
|
||||
"""Whether the story already has readable content."""
|
||||
|
||||
return bool(story.story_text) or bool(story.pages)
|
||||
|
||||
|
||||
def resolve_story_generation_status(story: StoryLike) -> StoryGenerationStatus:
|
||||
"""Derive the overall status from narrative and asset states."""
|
||||
|
||||
if not has_narrative_content(story):
|
||||
return StoryGenerationStatus.FAILED
|
||||
|
||||
image_status = _normalize_asset_status(story.image_status)
|
||||
audio_status = _normalize_asset_status(story.audio_status)
|
||||
|
||||
if StoryAssetStatus.GENERATING in (image_status, audio_status):
|
||||
return StoryGenerationStatus.ASSETS_GENERATING
|
||||
|
||||
if StoryAssetStatus.FAILED in (image_status, audio_status):
|
||||
return StoryGenerationStatus.DEGRADED_COMPLETED
|
||||
|
||||
if (
|
||||
image_status == StoryAssetStatus.NOT_REQUESTED
|
||||
and audio_status == StoryAssetStatus.NOT_REQUESTED
|
||||
):
|
||||
return StoryGenerationStatus.NARRATIVE_READY
|
||||
|
||||
return StoryGenerationStatus.COMPLETED
|
||||
|
||||
|
||||
def has_failed_assets(story: StoryLike) -> bool:
|
||||
"""Whether any persisted asset is still in a failed state."""
|
||||
|
||||
image_status = _normalize_asset_status(story.image_status)
|
||||
audio_status = _normalize_asset_status(story.audio_status)
|
||||
return StoryAssetStatus.FAILED in (image_status, audio_status)
|
||||
|
||||
|
||||
def sync_story_status(
|
||||
story: StoryLike,
|
||||
*,
|
||||
image_status: StoryAssetStatus | None = None,
|
||||
audio_status: StoryAssetStatus | None = None,
|
||||
last_error: str | None | object = _ERROR_UNSET,
|
||||
) -> None:
|
||||
"""Update asset statuses and refresh overall generation status."""
|
||||
|
||||
if image_status is not None:
|
||||
story.image_status = image_status.value
|
||||
|
||||
if audio_status is not None:
|
||||
story.audio_status = audio_status.value
|
||||
|
||||
if last_error is not _ERROR_UNSET:
|
||||
story.last_error = last_error
|
||||
|
||||
generation_status = resolve_story_generation_status(story)
|
||||
story.generation_status = generation_status.value
|
||||
|
||||
if last_error is _ERROR_UNSET and not has_failed_assets(story):
|
||||
story.last_error = None
|
||||
@@ -1,4 +1,4 @@
|
||||
"""测试配置和 fixtures。"""
|
||||
"""Pytest fixtures for backend tests."""
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
|
||||
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing")
|
||||
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.security import create_access_token
|
||||
from app.db.database import get_db
|
||||
from app.db.models import Base, Story, User
|
||||
@@ -19,7 +20,8 @@ from app.main import app
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""创建内存数据库引擎。"""
|
||||
"""Create an in-memory database engine."""
|
||||
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
@@ -29,7 +31,8 @@ async def async_engine():
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""创建数据库会话。"""
|
||||
"""Create a database session."""
|
||||
|
||||
session_factory = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
@@ -39,7 +42,8 @@ async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
|
||||
@pytest.fixture
|
||||
async def test_user(db_session: AsyncSession) -> User:
|
||||
"""创建测试用户。"""
|
||||
"""Create a test user."""
|
||||
|
||||
user = User(
|
||||
id="github:12345",
|
||||
name="Test User",
|
||||
@@ -54,13 +58,74 @@ async def test_user(db_session: AsyncSession) -> User:
|
||||
|
||||
@pytest.fixture
|
||||
async def test_story(db_session: AsyncSession, test_user: User) -> Story:
|
||||
"""创建测试故事。"""
|
||||
"""Create a plain generated story."""
|
||||
|
||||
story = Story(
|
||||
user_id=test_user.id,
|
||||
title="测试故事",
|
||||
story_text="从前有一只小兔子...",
|
||||
story_text="从前有一只小兔子。",
|
||||
cover_prompt="A cute rabbit in a forest",
|
||||
mode="generated",
|
||||
generation_status="narrative_ready",
|
||||
image_status="not_requested",
|
||||
audio_status="not_requested",
|
||||
)
|
||||
db_session.add(story)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(story)
|
||||
return story
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def storybook_story(db_session: AsyncSession, test_user: User) -> Story:
|
||||
"""Create a storybook-mode story."""
|
||||
|
||||
story = Story(
|
||||
user_id=test_user.id,
|
||||
title="森林绘本冒险",
|
||||
story_text=None,
|
||||
pages=[
|
||||
{
|
||||
"page_number": 1,
|
||||
"text": "小兔子走进了会发光的森林。",
|
||||
"image_prompt": "A glowing forest with a curious rabbit",
|
||||
"image_url": "https://example.com/page-1.png",
|
||||
},
|
||||
{
|
||||
"page_number": 2,
|
||||
"text": "它遇见了一位会唱歌的萤火虫朋友。",
|
||||
"image_prompt": "A rabbit meeting a singing firefly",
|
||||
"image_url": None,
|
||||
},
|
||||
],
|
||||
cover_prompt="A magical forest storybook cover",
|
||||
image_url="https://example.com/storybook-cover.png",
|
||||
mode="storybook",
|
||||
generation_status="degraded_completed",
|
||||
image_status="failed",
|
||||
audio_status="not_requested",
|
||||
last_error="第 2 页插图生成失败",
|
||||
)
|
||||
db_session.add(story)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(story)
|
||||
return story
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def degraded_story_with_text(db_session: AsyncSession, test_user: User) -> Story:
|
||||
"""Create a readable story whose image generation already failed."""
|
||||
|
||||
story = Story(
|
||||
user_id=test_user.id,
|
||||
title="部分完成的测试故事",
|
||||
story_text="从前有一只小兔子继续冒险。",
|
||||
cover_prompt="A rabbit under the moon",
|
||||
mode="generated",
|
||||
generation_status="degraded_completed",
|
||||
image_status="failed",
|
||||
audio_status="not_requested",
|
||||
last_error="封面生成失败",
|
||||
)
|
||||
db_session.add(story)
|
||||
await db_session.commit()
|
||||
@@ -70,13 +135,14 @@ async def test_story(db_session: AsyncSession, test_user: User) -> Story:
|
||||
|
||||
@pytest.fixture
|
||||
def auth_token(test_user: User) -> str:
|
||||
"""生成测试用户的 JWT token。"""
|
||||
"""Create a JWT token for the test user."""
|
||||
|
||||
return create_access_token({"sub": test_user.id})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(db_session: AsyncSession) -> TestClient:
|
||||
"""创建测试客户端。"""
|
||||
"""Create a test client."""
|
||||
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
@@ -89,35 +155,45 @@ def client(db_session: AsyncSession) -> TestClient:
|
||||
|
||||
@pytest.fixture
|
||||
def auth_client(client: TestClient, auth_token: str) -> TestClient:
|
||||
"""带认证的测试客户端。"""
|
||||
"""Create an authenticated test client."""
|
||||
|
||||
client.cookies.set("access_token", auth_token)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_rate_limit():
|
||||
"""默认绕过限流,让非限流测试正常运行。"""
|
||||
"""Bypass rate limiting in most tests."""
|
||||
|
||||
with patch("app.core.rate_limiter.get_redis", new_callable=AsyncMock) as mock_redis:
|
||||
# 创建一个模拟的 Redis 客户端,所有操作返回安全默认值
|
||||
redis_instance = AsyncMock()
|
||||
redis_instance.incr.return_value = 1 # 始终返回 1 (不触发限流)
|
||||
redis_instance.incr.return_value = 1
|
||||
redis_instance.expire.return_value = True
|
||||
redis_instance.get.return_value = None # 无锁定记录
|
||||
redis_instance.get.return_value = None
|
||||
redis_instance.ttl.return_value = 0
|
||||
redis_instance.delete.return_value = 1
|
||||
mock_redis.return_value = redis_instance
|
||||
yield redis_instance
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_story_audio_cache(tmp_path, monkeypatch):
|
||||
"""Use an isolated directory for cached story audio files."""
|
||||
|
||||
monkeypatch.setattr(settings, "story_audio_cache_dir", str(tmp_path / "audio"))
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_text_provider():
|
||||
"""Mock 文本生成适配器 API 调用。"""
|
||||
"""Mock text generation."""
|
||||
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
|
||||
mock_result = StoryOutput(
|
||||
mode="generated",
|
||||
title="小兔子的冒险",
|
||||
story_text="从前有一只小兔子...",
|
||||
story_text="从前有一只小兔子。",
|
||||
cover_prompt_suggestion="A cute rabbit",
|
||||
)
|
||||
|
||||
@@ -128,7 +204,8 @@ def mock_text_provider():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_image_provider():
|
||||
"""Mock 图像生成。"""
|
||||
"""Mock image generation."""
|
||||
|
||||
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = "https://example.com/image.png"
|
||||
yield mock
|
||||
@@ -136,7 +213,8 @@ def mock_image_provider():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tts_provider():
|
||||
"""Mock TTS。"""
|
||||
"""Mock text-to-speech generation."""
|
||||
|
||||
with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = b"fake-audio-bytes"
|
||||
yield mock
|
||||
@@ -144,7 +222,8 @@ def mock_tts_provider():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_all_providers(mock_text_provider, mock_image_provider, mock_tts_provider):
|
||||
"""Mock 所有 AI 供应商。"""
|
||||
"""Group all mocked providers."""
|
||||
|
||||
return {
|
||||
"text_primary": mock_text_provider,
|
||||
"image_primary": mock_image_provider,
|
||||
|
||||
@@ -1,26 +1,41 @@
|
||||
"""故事 API 测试。"""
|
||||
"""Tests for story-related API endpoints."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.core.config import settings
|
||||
from app.services.adapters.storybook.primary import Storybook, StorybookPage
|
||||
|
||||
# ── 注意 ──────────────────────────────────────────────────────────────────────
|
||||
# 以下路由尚未实现 (stories.py 中没有对应端点),相关测试标记为 skip:
|
||||
# GET /api/stories (列表)
|
||||
# GET /api/stories/{id} (详情)
|
||||
# DELETE /api/stories/{id} (删除)
|
||||
# POST /api/image/generate/{id} (封面图片生成)
|
||||
# GET /api/audio/{id} (音频)
|
||||
# 实现后请取消 skip 标记。
|
||||
|
||||
def build_storybook_output() -> Storybook:
|
||||
"""Create a reusable mocked storybook payload."""
|
||||
|
||||
return Storybook(
|
||||
title="森林里的发光冒险",
|
||||
main_character="小兔子露露",
|
||||
art_style="温暖水彩",
|
||||
cover_prompt="A glowing forest storybook cover",
|
||||
pages=[
|
||||
StorybookPage(
|
||||
page_number=1,
|
||||
text="露露第一次走进会发光的森林。",
|
||||
image_prompt="Lulu entering a glowing forest",
|
||||
),
|
||||
StorybookPage(
|
||||
page_number=2,
|
||||
text="她遇到了一只会唱歌的萤火虫。",
|
||||
image_prompt="Lulu meeting a singing firefly",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestStoryGenerate:
|
||||
"""故事生成测试。"""
|
||||
"""Tests for basic story generation."""
|
||||
|
||||
def test_generate_without_auth(self, client: TestClient):
|
||||
"""未登录时生成故事。"""
|
||||
response = client.post(
|
||||
"/api/stories/generate",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
@@ -28,7 +43,6 @@ class TestStoryGenerate:
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_generate_with_empty_data(self, auth_client: TestClient):
|
||||
"""空数据生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate",
|
||||
json={"type": "keywords", "data": ""},
|
||||
@@ -36,7 +50,6 @@ class TestStoryGenerate:
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_generate_with_invalid_type(self, auth_client: TestClient):
|
||||
"""无效类型生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate",
|
||||
json={"type": "invalid", "data": "test"},
|
||||
@@ -44,7 +57,6 @@ class TestStoryGenerate:
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_generate_story_success(self, auth_client: TestClient, mock_text_provider):
|
||||
"""成功生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate",
|
||||
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
|
||||
@@ -55,82 +67,96 @@ class TestStoryGenerate:
|
||||
assert "title" in data
|
||||
assert "story_text" in data
|
||||
assert data["mode"] == "generated"
|
||||
assert data["generation_status"] == "narrative_ready"
|
||||
assert data["image_status"] == "not_requested"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["last_error"] is None
|
||||
|
||||
|
||||
class TestStoryList:
|
||||
"""故事列表测试。"""
|
||||
"""Tests for story listing."""
|
||||
|
||||
def test_list_without_auth(self, client: TestClient):
|
||||
"""未登录时获取列表。"""
|
||||
response = client.get("/api/stories")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_list_empty(self, auth_client: TestClient):
|
||||
"""空列表。"""
|
||||
response = auth_client.get("/api/stories")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
def test_list_with_stories(self, auth_client: TestClient, test_story):
|
||||
"""有故事时获取列表。"""
|
||||
response = auth_client.get("/api/stories")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == test_story.id
|
||||
assert data[0]["title"] == test_story.title
|
||||
assert data[0]["generation_status"] == "narrative_ready"
|
||||
assert data[0]["image_status"] == "not_requested"
|
||||
assert data[0]["audio_status"] == "not_requested"
|
||||
|
||||
def test_list_pagination(self, auth_client: TestClient, test_story):
|
||||
"""分页测试。"""
|
||||
response = auth_client.get("/api/stories?limit=1&offset=0")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
assert len(response.json()) == 1
|
||||
|
||||
response = auth_client.get("/api/stories?limit=1&offset=1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 0
|
||||
assert len(response.json()) == 0
|
||||
|
||||
|
||||
class TestStoryDetail:
|
||||
"""故事详情测试。"""
|
||||
"""Tests for story detail retrieval."""
|
||||
|
||||
def test_get_story_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时获取详情。"""
|
||||
response = client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_get_story_not_found(self, auth_client: TestClient):
|
||||
"""故事不存在。"""
|
||||
response = auth_client.get("/api/stories/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_story_success(self, auth_client: TestClient, test_story):
|
||||
"""成功获取详情。"""
|
||||
response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == test_story.id
|
||||
assert data["title"] == test_story.title
|
||||
assert data["story_text"] == test_story.story_text
|
||||
assert data["generation_status"] == "narrative_ready"
|
||||
assert data["image_status"] == "not_requested"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["last_error"] is None
|
||||
|
||||
def test_get_storybook_success(self, auth_client: TestClient, storybook_story):
|
||||
response = auth_client.get(f"/api/stories/{storybook_story.id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == storybook_story.id
|
||||
assert data["mode"] == "storybook"
|
||||
assert data["story_text"] is None
|
||||
assert len(data["pages"]) == 2
|
||||
assert data["pages"][0]["page_number"] == 1
|
||||
assert data["image_url"] == "https://example.com/storybook-cover.png"
|
||||
assert data["generation_status"] == "degraded_completed"
|
||||
assert data["image_status"] == "failed"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert "第 2 页" in data["last_error"]
|
||||
|
||||
|
||||
class TestStoryDelete:
|
||||
"""故事删除测试。"""
|
||||
"""Tests for story deletion."""
|
||||
|
||||
def test_delete_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时删除。"""
|
||||
response = client.delete(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_delete_not_found(self, auth_client: TestClient):
|
||||
"""删除不存在的故事。"""
|
||||
response = auth_client.delete("/api/stories/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_delete_success(self, auth_client: TestClient, test_story):
|
||||
"""成功删除故事。"""
|
||||
response = auth_client.delete(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Deleted"
|
||||
@@ -140,11 +166,14 @@ class TestStoryDelete:
|
||||
|
||||
|
||||
class TestRateLimit:
|
||||
"""Rate limit 测试。"""
|
||||
"""Tests for story generation rate limiting."""
|
||||
|
||||
def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, mock_text_provider, bypass_rate_limit):
|
||||
"""正常请求不触发限流。"""
|
||||
# bypass_rate_limit 默认 incr 返回 1,不触发限流
|
||||
def test_rate_limit_allows_normal_requests(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
mock_text_provider,
|
||||
bypass_rate_limit,
|
||||
):
|
||||
for _ in range(3):
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate",
|
||||
@@ -152,9 +181,11 @@ class TestRateLimit:
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, bypass_rate_limit):
|
||||
"""超限请求被阻止。"""
|
||||
# 让 incr 返回超限值 (> RATE_LIMIT_REQUESTS)
|
||||
def test_rate_limit_blocks_excess_requests(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
bypass_rate_limit,
|
||||
):
|
||||
bypass_rate_limit.incr.return_value = 11
|
||||
|
||||
response = auth_client.post(
|
||||
@@ -166,52 +197,118 @@ class TestRateLimit:
|
||||
|
||||
|
||||
class TestImageGenerate:
|
||||
"""封面图片生成测试。"""
|
||||
"""Tests for cover generation endpoint."""
|
||||
|
||||
def test_generate_image_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时生成图片。"""
|
||||
response = client.post(f"/api/image/generate/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_generate_image_not_found(self, auth_client: TestClient):
|
||||
"""故事不存在。"""
|
||||
response = auth_client.post("/api/image/generate/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestAudio:
|
||||
"""语音朗读测试。"""
|
||||
"""Tests for story audio endpoint."""
|
||||
|
||||
def test_get_audio_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时获取音频。"""
|
||||
response = client.get(f"/api/audio/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_get_audio_not_found(self, auth_client: TestClient):
|
||||
"""故事不存在。"""
|
||||
response = auth_client.get("/api/audio/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_audio_success(self, auth_client: TestClient, test_story, mock_tts_provider):
|
||||
"""成功获取音频。"""
|
||||
def test_get_audio_success(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
test_story,
|
||||
mock_tts_provider,
|
||||
):
|
||||
response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
assert response.content == b"fake-audio-bytes"
|
||||
|
||||
cached_audio_path = Path(settings.story_audio_cache_dir) / f"story-{test_story.id}.mp3"
|
||||
assert cached_audio_path.is_file()
|
||||
|
||||
second_response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert second_response.status_code == 200
|
||||
assert second_response.content == b"fake-audio-bytes"
|
||||
mock_tts_provider.assert_awaited_once()
|
||||
|
||||
detail_response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
detail = detail_response.json()
|
||||
assert detail["audio_status"] == "ready"
|
||||
assert detail["generation_status"] == "completed"
|
||||
assert detail["last_error"] is None
|
||||
|
||||
def test_get_audio_regenerates_when_cache_file_is_missing(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
test_story,
|
||||
mock_tts_provider,
|
||||
):
|
||||
first_response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert first_response.status_code == 200
|
||||
|
||||
cached_audio_path = Path(settings.story_audio_cache_dir) / f"story-{test_story.id}.mp3"
|
||||
cached_audio_path.unlink()
|
||||
mock_tts_provider.reset_mock()
|
||||
|
||||
second_response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert second_response.status_code == 200
|
||||
assert second_response.content == b"fake-audio-bytes"
|
||||
assert cached_audio_path.is_file()
|
||||
mock_tts_provider.assert_awaited_once()
|
||||
|
||||
def test_get_audio_failure_updates_status(self, auth_client: TestClient, test_story):
|
||||
with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock_tts:
|
||||
mock_tts.side_effect = Exception("TTS provider timeout")
|
||||
response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert response.status_code == 500
|
||||
|
||||
detail_response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
detail = detail_response.json()
|
||||
assert detail["audio_status"] == "failed"
|
||||
assert detail["generation_status"] == "degraded_completed"
|
||||
assert "TTS provider timeout" in detail["last_error"]
|
||||
|
||||
def test_get_audio_success_preserves_existing_image_error(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
degraded_story_with_text,
|
||||
mock_tts_provider,
|
||||
):
|
||||
response = auth_client.get(f"/api/audio/{degraded_story_with_text.id}")
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"fake-audio-bytes"
|
||||
mock_tts_provider.assert_awaited_once()
|
||||
|
||||
detail_response = auth_client.get(f"/api/stories/{degraded_story_with_text.id}")
|
||||
detail = detail_response.json()
|
||||
assert detail["audio_status"] == "ready"
|
||||
assert detail["generation_status"] == "degraded_completed"
|
||||
assert detail["last_error"] == "封面生成失败"
|
||||
|
||||
|
||||
class TestGenerateFull:
|
||||
"""完整故事生成测试(/api/stories/generate/full)。"""
|
||||
"""Tests for complete story generation."""
|
||||
|
||||
def test_generate_full_without_auth(self, client: TestClient):
|
||||
"""未登录时生成完整故事。"""
|
||||
response = client.post(
|
||||
"/api/stories/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_generate_full_success(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
|
||||
"""成功生成完整故事(含图片)。"""
|
||||
def test_generate_full_success(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
mock_text_provider,
|
||||
mock_image_provider,
|
||||
):
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
|
||||
@@ -223,11 +320,14 @@ class TestGenerateFull:
|
||||
assert "story_text" in data
|
||||
assert data["mode"] == "generated"
|
||||
assert data["image_url"] == "https://example.com/image.png"
|
||||
assert data["audio_ready"] is False # 音频按需生成
|
||||
assert data["audio_ready"] is False
|
||||
assert data["errors"] == {}
|
||||
assert data["generation_status"] == "completed"
|
||||
assert data["image_status"] == "ready"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["last_error"] is None
|
||||
|
||||
def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider):
|
||||
"""图片生成失败时返回部分成功。"""
|
||||
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_img:
|
||||
mock_img.side_effect = Exception("Image API error")
|
||||
response = auth_client.post(
|
||||
@@ -239,9 +339,17 @@ class TestGenerateFull:
|
||||
assert data["image_url"] is None
|
||||
assert "image" in data["errors"]
|
||||
assert "Image API error" in data["errors"]["image"]
|
||||
assert data["generation_status"] == "degraded_completed"
|
||||
assert data["image_status"] == "failed"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert "Image API error" in data["last_error"]
|
||||
|
||||
def test_generate_full_with_education_theme(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
|
||||
"""带教育主题生成故事。"""
|
||||
def test_generate_full_with_education_theme(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
mock_text_provider,
|
||||
mock_image_provider,
|
||||
):
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate/full",
|
||||
json={
|
||||
@@ -257,11 +365,80 @@ class TestGenerateFull:
|
||||
|
||||
|
||||
class TestImageGenerateSuccess:
|
||||
"""封面图片生成成功测试。"""
|
||||
"""Tests for successful cover generation."""
|
||||
|
||||
def test_generate_image_success(self, auth_client: TestClient, test_story, mock_image_provider):
|
||||
"""成功生成图片。"""
|
||||
def test_generate_image_success(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
test_story,
|
||||
mock_image_provider,
|
||||
):
|
||||
response = auth_client.post(f"/api/image/generate/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["image_url"] == "https://example.com/image.png"
|
||||
assert data["generation_status"] == "completed"
|
||||
assert data["image_status"] == "ready"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["last_error"] is None
|
||||
|
||||
|
||||
class TestStorybookGenerate:
|
||||
"""Tests for storybook generation status handling."""
|
||||
|
||||
def test_generate_storybook_success(self, auth_client: TestClient):
|
||||
with patch("app.services.story_service.generate_storybook", new_callable=AsyncMock) as mock_storybook:
|
||||
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_image:
|
||||
mock_storybook.return_value = build_storybook_output()
|
||||
mock_image.side_effect = [
|
||||
"https://example.com/storybook-cover.png",
|
||||
"https://example.com/storybook-page-1.png",
|
||||
"https://example.com/storybook-page-2.png",
|
||||
]
|
||||
|
||||
response = auth_client.post(
|
||||
"/api/storybook/generate",
|
||||
json={
|
||||
"keywords": "森林, 发光, 友情",
|
||||
"page_count": 6,
|
||||
"generate_images": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] is not None
|
||||
assert data["generation_status"] == "completed"
|
||||
assert data["image_status"] == "ready"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["last_error"] is None
|
||||
assert len(data["pages"]) == 2
|
||||
assert data["cover_url"] == "https://example.com/storybook-cover.png"
|
||||
|
||||
def test_generate_storybook_partial_image_failure(self, auth_client: TestClient):
|
||||
async def image_side_effect(prompt: str, **kwargs):
|
||||
if "singing firefly" in prompt:
|
||||
raise Exception("Image API error")
|
||||
slug = prompt.split()[0].lower()
|
||||
return f"https://example.com/{slug}.png"
|
||||
|
||||
with patch("app.services.story_service.generate_storybook", new_callable=AsyncMock) as mock_storybook:
|
||||
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_image:
|
||||
mock_storybook.return_value = build_storybook_output()
|
||||
mock_image.side_effect = image_side_effect
|
||||
|
||||
response = auth_client.post(
|
||||
"/api/storybook/generate",
|
||||
json={
|
||||
"keywords": "森林, 发光, 友情",
|
||||
"page_count": 6,
|
||||
"generate_images": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["generation_status"] == "degraded_completed"
|
||||
assert data["image_status"] == "failed"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert "第 2 页插图生成失败" in data["last_error"]
|
||||
|
||||
Reference in New Issue
Block a user