feat: persist story generation states and cache audio
Some checks failed
Build and Push Docker Images / changes (push) Has been cancelled
Build and Push Docker Images / build-backend (push) Has been cancelled
Build and Push Docker Images / build-frontend (push) Has been cancelled
Build and Push Docker Images / build-admin-frontend (push) Has been cancelled

This commit is contained in:
2026-04-17 17:14:09 +08:00
parent 145be0e67b
commit a97a2fe005
17 changed files with 2045 additions and 849 deletions

51
backend/.gitignore vendored
View File

@@ -1,27 +1,28 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
.venv/
venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
# 环境变量
.env
# 测试
.pytest_cache/
.coverage
htmlcov/
# 其他
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
.venv/
venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
# 环境变量
.env
# 测试
.pytest_cache/
.coverage
htmlcov/
# 其他
*.log
.DS_Store
storage/

View File

@@ -0,0 +1,151 @@
"""add story generation status fields
Revision ID: 0009_add_story_generation_statuses
Revises: 0008_add_pages_to_stories
Create Date: 2026-04-17
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "0009_add_story_generation_statuses"
down_revision = "0008_add_pages_to_stories"
branch_labels = None
depends_on = None
stories = sa.table(
"stories",
sa.column("id", sa.Integer),
sa.column("story_text", sa.Text),
sa.column("pages", sa.JSON),
sa.column("cover_prompt", sa.Text),
sa.column("image_url", sa.String(length=500)),
sa.column("generation_status", sa.String(length=32)),
sa.column("image_status", sa.String(length=32)),
sa.column("audio_status", sa.String(length=32)),
)
def _resolve_image_status(row: dict) -> str:
pages = row.get("pages") or []
expected_assets = 0
ready_assets = 0
if row.get("cover_prompt") or row.get("image_url"):
expected_assets += 1
if row.get("image_url"):
ready_assets += 1
for page in pages:
if not isinstance(page, dict):
continue
if not page.get("image_prompt") and not page.get("image_url"):
continue
expected_assets += 1
if page.get("image_url"):
ready_assets += 1
if expected_assets == 0:
return "not_requested"
if ready_assets == expected_assets:
return "ready"
return "failed"
def _resolve_generation_status(
*,
story_text: str | None,
pages: list[dict] | None,
image_status: str,
audio_status: str,
) -> str:
has_narrative = bool(story_text) or bool(pages)
if not has_narrative:
return "failed"
if "generating" in {image_status, audio_status}:
return "assets_generating"
if "failed" in {image_status, audio_status}:
return "degraded_completed"
if image_status == "not_requested" and audio_status == "not_requested":
return "narrative_ready"
return "completed"
def upgrade() -> None:
op.add_column(
"stories",
sa.Column(
"generation_status",
sa.String(length=32),
nullable=False,
server_default="narrative_ready",
),
)
op.add_column(
"stories",
sa.Column(
"image_status",
sa.String(length=32),
nullable=False,
server_default="not_requested",
),
)
op.add_column(
"stories",
sa.Column(
"audio_status",
sa.String(length=32),
nullable=False,
server_default="not_requested",
),
)
op.add_column("stories", sa.Column("last_error", sa.Text(), nullable=True))
connection = op.get_bind()
rows = connection.execute(
sa.select(
stories.c.id,
stories.c.story_text,
stories.c.pages,
stories.c.cover_prompt,
stories.c.image_url,
)
).mappings()
for row in rows:
image_status = _resolve_image_status(row)
audio_status = "not_requested"
generation_status = _resolve_generation_status(
story_text=row.get("story_text"),
pages=row.get("pages"),
image_status=image_status,
audio_status=audio_status,
)
connection.execute(
stories.update()
.where(stories.c.id == row["id"])
.values(
generation_status=generation_status,
image_status=image_status,
audio_status=audio_status,
)
)
def downgrade() -> None:
op.drop_column("stories", "last_error")
op.drop_column("stories", "audio_status")
op.drop_column("stories", "image_status")
op.drop_column("stories", "generation_status")

View File

@@ -0,0 +1,25 @@
"""add audio cache path to stories
Revision ID: 0010_add_story_audio_cache_path
Revises: 0009_add_story_generation_statuses
Create Date: 2026-04-17
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "0010_add_story_audio_cache_path"
down_revision = "0009_add_story_generation_statuses"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("stories", sa.Column("audio_path", sa.String(length=500), nullable=True))
def downgrade() -> None:
op.drop_column("stories", "audio_path")

View File

@@ -1,27 +1,28 @@
"""Story related APIs."""
import asyncio
import json
import uuid
from typing import AsyncGenerator
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sse_starlette.sse import EventSourceResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.core.logging import get_logger
from app.core.rate_limiter import check_rate_limit
from app.db.database import get_db
from app.db.models import User
from fastapi import APIRouter, Depends, Response
from sse_starlette.sse import EventSourceResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.core.logging import get_logger
from app.core.rate_limiter import check_rate_limit
from app.db.database import get_db
from app.db.models import User
from app.schemas.story_schemas import (
GenerateRequest,
StoryResponse,
AchievementItem,
FullStoryResponse,
GenerateRequest,
StoryDetailResponse,
StoryImageResponse,
StoryListItem,
StoryResponse,
StorybookRequest,
StorybookResponse,
StoryListItem,
AchievementItem,
)
from app.services import story_service
from app.services.memory_service import build_enhanced_memory_context
@@ -29,153 +30,202 @@ from app.services.provider_router import (
generate_story_content,
generate_image,
)
logger = get_logger(__name__)
router = APIRouter()
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10
@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
request: GenerateRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate or enhance a story."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_and_save_story(request, user.id, db)
@router.post("/stories/generate/full", response_model=FullStoryResponse)
async def generate_story_full(
request: GenerateRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate complete story (story + parallel image/audio generation)."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_full_story_service(request, user.id, db)
@router.post("/stories/generate/stream")
from app.services.story_status import StoryAssetStatus, sync_story_status
logger = get_logger(__name__)
router = APIRouter()
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10
@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
request: GenerateRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate or enhance a story."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_and_save_story(request, user.id, db)
@router.post("/stories/generate/full", response_model=FullStoryResponse)
async def generate_story_full(
request: GenerateRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate complete story (story + parallel image/audio generation)."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_full_story_service(request, user.id, db)
@router.post("/stories/generate/stream")
async def generate_story_stream(
request: GenerateRequest,
req: Request,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""流式生成故事SSE"""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
# Validation
profile_id, universe_id = await story_service.validate_profile_and_universe(
request.child_profile_id, request.universe_id, user.id, db
)
# Build Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
async def event_generator() -> AsyncGenerator[dict, None]:
story_id = str(uuid.uuid4())
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
# Step 1: Generate Content
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("sse_story_generation_failed", error=str(e))
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
return
# Save Story
story = await story_service.create_story_from_result(
result, user.id, profile_id, universe_id, db
)
"""流式生成故事SSE"""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
# Validation
profile_id, universe_id = await story_service.validate_profile_and_universe(
request.child_profile_id, request.universe_id, user.id, db
)
# Build Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
async def event_generator() -> AsyncGenerator[dict, None]:
story_id = str(uuid.uuid4())
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
# Step 1: Generate Content
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("sse_story_generation_failed", error=str(e))
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
return
# Save Story
story = await story_service.create_story_from_result(
result, user.id, profile_id, universe_id, db
)
yield {
"event": "story_ready",
"data": json.dumps({
"id": story.id,
"title": story.title,
"content": story.story_text,
"content": story.story_text,
"cover_prompt": story.cover_prompt,
"mode": story.mode,
"child_profile_id": story.child_profile_id,
"universe_id": story.universe_id,
"generation_status": story.generation_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
"last_error": story.last_error,
}),
}
# Step 2: Generate Image
if story.cover_prompt:
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
try:
# Direct call to provider router's generate_image, sharing db session
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
sync_story_status(
story,
image_status=StoryAssetStatus.READY,
)
await db.commit()
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
yield {
"event": "image_ready",
"data": json.dumps(
{
"image_url": image_url,
"generation_status": story.generation_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
"last_error": story.last_error,
}
),
}
except Exception as e:
sync_story_status(
story,
image_status=StoryAssetStatus.FAILED,
last_error=str(e),
)
await db.commit()
logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e))
yield {"event": "image_failed", "data": json.dumps({"error": str(e)})}
yield {
"event": "image_failed",
"data": json.dumps(
{
"error": str(e),
"generation_status": story.generation_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
"last_error": story.last_error,
}
),
}
yield {"event": "complete", "data": json.dumps({"story_id": story.id})}
return EventSourceResponse(event_generator())
@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
request: StorybookRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate storybook."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_storybook_service(request, user.id, db)
# ==================== Missing Endpoints (Issue #5) ====================
@router.get("/stories", response_model=list[StoryListItem])
async def list_stories(
limit: int = 20,
offset: int = 0,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List stories."""
return await story_service.list_stories(user.id, limit, offset, db)
@router.get("/stories/{story_id}", response_model=StoryResponse)
yield {
"event": "complete",
"data": json.dumps(
{
"story_id": story.id,
"generation_status": story.generation_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
"last_error": story.last_error,
}
),
}
return EventSourceResponse(event_generator())
@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
request: StorybookRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate storybook."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_storybook_service(request, user.id, db)
# ==================== Missing Endpoints (Issue #5) ====================
@router.get("/stories", response_model=list[StoryListItem])
async def list_stories(
limit: int = 20,
offset: int = 0,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List stories."""
return await story_service.list_stories(user.id, limit, offset, db)
@router.get("/stories/{story_id}", response_model=StoryDetailResponse)
async def get_story(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get story detail."""
return await story_service.get_story_detail(story_id, user.id, db)
@router.delete("/stories/{story_id}")
async def delete_story(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Delete story."""
await story_service.delete_story(story_id, user.id, db)
return {"message": "Deleted"}
@router.post("/image/generate/{story_id}")
"""Get story detail."""
return await story_service.get_story_detail(story_id, user.id, db)
@router.delete("/stories/{story_id}")
async def delete_story(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Delete story."""
await story_service.delete_story(story_id, user.id, db)
return {"message": "Deleted"}
@router.post("/image/generate/{story_id}", response_model=StoryImageResponse)
async def generate_story_image(
story_id: int,
user: User = Depends(require_user),
@@ -183,25 +233,32 @@ async def generate_story_image(
):
"""Generate cover image for story."""
url = await story_service.generate_story_cover(story_id, user.id, db)
return {"image_url": url}
@router.get("/audio/{story_id}")
async def get_story_audio(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get story audio (MP3)."""
audio_bytes = await story_service.generate_story_audio(story_id, user.id, db)
return Response(content=audio_bytes, media_type="audio/mpeg")
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
async def get_story_achievements(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get story achievements."""
return await story_service.get_story_achievements(story_id, user.id, db)
story = await story_service.get_story_detail(story_id, user.id, db)
return {
"image_url": url,
"generation_status": story.generation_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
"last_error": story.last_error,
}
@router.get("/audio/{story_id}")
async def get_story_audio(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get story audio (MP3)."""
audio_bytes = await story_service.generate_story_audio(story_id, user.id, db)
return Response(content=audio_bytes, media_type="audio/mpeg")
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
async def get_story_achievements(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get story achievements."""
return await story_service.get_story_achievements(story_id, user.id, db)

View File

@@ -1,130 +1,134 @@
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""应用全局配置"""
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
# 应用基础配置
app_name: str = "DreamWeaver"
debug: bool = False
secret_key: str = Field(..., description="JWT 签名密钥")
base_url: str = Field("http://localhost:8000", description="后端对外回调地址")
# 数据库
database_url: str = Field(..., description="SQLAlchemy async URL")
# OAuth - GitHub
github_client_id: str = ""
github_client_secret: str = ""
# OAuth - Google
google_client_id: str = ""
google_client_secret: str = ""
# AI Capability Keys
text_api_key: str = ""
tts_api_base: str = ""
tts_api_key: str = ""
image_api_key: str = ""
# Additional Provider API Keys
openai_api_key: str = ""
elevenlabs_api_key: str = ""
cqtai_api_key: str = ""
minimax_api_key: str = ""
minimax_group_id: str = ""
antigravity_api_key: str = ""
antigravity_api_base: str = ""
# AI Model Configuration
text_model: str = "gemini-2.0-flash"
openai_model: str = "gpt-4o-mini"
tts_model: str = ""
image_model: str = "nano-banana-pro"
tts_minimax_model: str = "speech-2.6-turbo"
tts_elevenlabs_model: str = "eleven_multilingual_v2"
tts_edge_voice: str = "zh-CN-XiaoxiaoNeural"
antigravity_model: str = "gemini-3-pro-image"
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""应用全局配置"""
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
# 应用基础配置
app_name: str = "DreamWeaver"
debug: bool = False
secret_key: str = Field(..., description="JWT 签名密钥")
base_url: str = Field("http://localhost:8000", description="后端对外回调地址")
# 数据库
database_url: str = Field(..., description="SQLAlchemy async URL")
# OAuth - GitHub
github_client_id: str = ""
github_client_secret: str = ""
# OAuth - Google
google_client_id: str = ""
google_client_secret: str = ""
# AI Capability Keys
text_api_key: str = ""
tts_api_base: str = ""
tts_api_key: str = ""
image_api_key: str = ""
# Additional Provider API Keys
openai_api_key: str = ""
elevenlabs_api_key: str = ""
cqtai_api_key: str = ""
minimax_api_key: str = ""
minimax_group_id: str = ""
antigravity_api_key: str = ""
antigravity_api_base: str = ""
# AI Model Configuration
text_model: str = "gemini-2.0-flash"
openai_model: str = "gpt-4o-mini"
tts_model: str = ""
image_model: str = "nano-banana-pro"
tts_minimax_model: str = "speech-2.6-turbo"
tts_elevenlabs_model: str = "eleven_multilingual_v2"
tts_edge_voice: str = "zh-CN-XiaoxiaoNeural"
antigravity_model: str = "gemini-3-pro-image"
# Provider routing (ordered lists)
text_providers: list[str] = Field(default_factory=lambda: ["gemini"])
image_providers: list[str] = Field(default_factory=lambda: ["cqtai"])
tts_providers: list[str] = Field(default_factory=lambda: ["minimax", "elevenlabs", "edge_tts"])
story_audio_cache_dir: str = Field(
"storage/audio",
description="Directory for cached story audio files",
)
# Celery (Redis)
celery_broker_url: str = Field("redis://localhost:6379/0")
celery_result_backend: str = Field("redis://localhost:6379/0")
# Generic Redis
redis_url: str = Field("redis://localhost:6379/0", description="Redis connection URL")
redis_sentinel_enabled: bool = Field(False, description="Whether to enable Redis Sentinel")
redis_sentinel_nodes: str = Field(
"",
description="Comma-separated Redis Sentinel nodes, e.g. host1:26379,host2:26379",
)
redis_sentinel_master_name: str = Field("mymaster", description="Redis Sentinel master name")
redis_sentinel_password: str = Field("", description="Password for Redis Sentinel (optional)")
redis_sentinel_db: int = Field(0, description="Redis DB index when using Sentinel")
redis_sentinel_socket_timeout: float = Field(
0.5,
description="Socket timeout in seconds for Sentinel clients",
)
# Admin console
enable_admin_console: bool = False
admin_username: str = "admin"
admin_password: str = "admin123" # 建议通过环境变量覆盖
# CORS
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173"])
@model_validator(mode="after")
def _require_core_settings(self) -> "Settings": # type: ignore[override]
missing = []
if not self.secret_key or self.secret_key == "change-me-in-production":
missing.append("SECRET_KEY")
if not self.database_url:
missing.append("DATABASE_URL")
if self.redis_sentinel_enabled and not self.redis_sentinel_nodes.strip():
missing.append("REDIS_SENTINEL_NODES")
if missing:
raise ValueError(f"Missing required settings: {', '.join(missing)}")
return self
@property
def redis_sentinel_hosts(self) -> list[tuple[str, int]]:
"""Parse Redis Sentinel nodes into (host, port) tuples."""
nodes = []
raw = self.redis_sentinel_nodes.strip()
if not raw:
return nodes
for item in raw.split(","):
value = item.strip()
if not value:
continue
if ":" not in value:
raise ValueError(f"Invalid sentinel node format: {value}")
host, port_text = value.rsplit(":", 1)
if not host:
raise ValueError(f"Invalid sentinel node host: {value}")
try:
port = int(port_text)
except ValueError as exc:
raise ValueError(f"Invalid sentinel node port: {value}") from exc
nodes.append((host, port))
return nodes
@property
def redis_sentinel_urls(self) -> list[str]:
"""Build Celery-compatible Sentinel URLs with DB index."""
return [
f"sentinel://{host}:{port}/{self.redis_sentinel_db}"
for host, port in self.redis_sentinel_hosts
]
settings = Settings()
celery_result_backend: str = Field("redis://localhost:6379/0")
# Generic Redis
redis_url: str = Field("redis://localhost:6379/0", description="Redis connection URL")
redis_sentinel_enabled: bool = Field(False, description="Whether to enable Redis Sentinel")
redis_sentinel_nodes: str = Field(
"",
description="Comma-separated Redis Sentinel nodes, e.g. host1:26379,host2:26379",
)
redis_sentinel_master_name: str = Field("mymaster", description="Redis Sentinel master name")
redis_sentinel_password: str = Field("", description="Password for Redis Sentinel (optional)")
redis_sentinel_db: int = Field(0, description="Redis DB index when using Sentinel")
redis_sentinel_socket_timeout: float = Field(
0.5,
description="Socket timeout in seconds for Sentinel clients",
)
# Admin console
enable_admin_console: bool = False
admin_username: str = "admin"
admin_password: str = "admin123" # 建议通过环境变量覆盖
# CORS
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173"])
@model_validator(mode="after")
def _require_core_settings(self) -> "Settings": # type: ignore[override]
missing = []
if not self.secret_key or self.secret_key == "change-me-in-production":
missing.append("SECRET_KEY")
if not self.database_url:
missing.append("DATABASE_URL")
if self.redis_sentinel_enabled and not self.redis_sentinel_nodes.strip():
missing.append("REDIS_SENTINEL_NODES")
if missing:
raise ValueError(f"Missing required settings: {', '.join(missing)}")
return self
@property
def redis_sentinel_hosts(self) -> list[tuple[str, int]]:
"""Parse Redis Sentinel nodes into (host, port) tuples."""
nodes = []
raw = self.redis_sentinel_nodes.strip()
if not raw:
return nodes
for item in raw.split(","):
value = item.strip()
if not value:
continue
if ":" not in value:
raise ValueError(f"Invalid sentinel node format: {value}")
host, port_text = value.rsplit(":", 1)
if not host:
raise ValueError(f"Invalid sentinel node host: {value}")
try:
port = int(port_text)
except ValueError as exc:
raise ValueError(f"Invalid sentinel node port: {value}") from exc
nodes.append((host, port))
return nodes
@property
def redis_sentinel_urls(self) -> list[str]:
"""Build Celery-compatible Sentinel URLs with DB index."""
return [
f"sentinel://{host}:{port}/{self.redis_sentinel_db}"
for host, port in self.redis_sentinel_hosts
]
settings = Settings()

View File

@@ -27,10 +27,10 @@ class User(Base):
__tablename__ = "users"
id: Mapped[str] = mapped_column(String(255), primary_key=True) # OAuth provider user ID
id: Mapped[str] = mapped_column(String(255), primary_key=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
avatar_url: Mapped[str | None] = mapped_column(String(500))
provider: Mapped[str] = mapped_column(String(50), nullable=False) # github / google
provider: Mapped[str] = mapped_column(String(50), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
@@ -59,11 +59,22 @@ class Story(Base):
String(36), ForeignKey("story_universes.id", ondelete="SET NULL"), nullable=True
)
title: Mapped[str] = mapped_column(String(255), nullable=False)
story_text: Mapped[str] = mapped_column(Text, nullable=True) # 允许为空(绘本模式下)
pages: Mapped[list[dict] | None] = mapped_column(JSON, default=list) # 绘本分页数据
story_text: Mapped[str | None] = mapped_column(Text, nullable=True)
pages: Mapped[list[dict] | None] = mapped_column(JSON, default=list)
cover_prompt: Mapped[str | None] = mapped_column(Text)
image_url: Mapped[str | None] = mapped_column(String(500))
mode: Mapped[str] = mapped_column(String(20), nullable=False, default="generated")
generation_status: Mapped[str] = mapped_column(
String(32), nullable=False, default="narrative_ready"
)
image_status: Mapped[str] = mapped_column(
String(32), nullable=False, default="not_requested"
)
audio_status: Mapped[str] = mapped_column(
String(32), nullable=False, default="not_requested"
)
audio_path: Mapped[str | None] = mapped_column(String(500))
last_error: Mapped[str | None] = mapped_column(Text)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
@@ -123,6 +134,7 @@ class ChildProfile(Base):
class StoryUniverse(Base):
"""Story universe entity."""
__tablename__ = "story_universes"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
@@ -142,7 +154,9 @@ class StoryUniverse(Base):
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
child_profile: Mapped["ChildProfile"] = relationship("ChildProfile", back_populates="story_universes")
child_profile: Mapped["ChildProfile"] = relationship(
"ChildProfile", back_populates="story_universes"
)
class ReadingEvent(Base):
@@ -163,6 +177,7 @@ class ReadingEvent(Base):
DateTime(timezone=True), server_default=func.now(), index=True
)
class PushConfig(Base):
"""Push configuration entity."""

View File

@@ -1,4 +1,4 @@
"""故事相关 Pydantic 模型。"""
"""Story-related Pydantic schemas."""
from datetime import datetime
from typing import Literal
@@ -11,7 +11,13 @@ MAX_EDU_THEME_LENGTH = 200
MAX_TTS_LENGTH = 4000
# ==================== 故事模型 ====================
class StoryStatusMixin(BaseModel):
"""Shared generation status fields returned by story APIs."""
generation_status: str
image_status: str
audio_status: str
last_error: str | None = None
class GenerateRequest(BaseModel):
@@ -24,8 +30,8 @@ class GenerateRequest(BaseModel):
universe_id: str | None = None
class StoryResponse(BaseModel):
"""Story response."""
class StoryResponse(StoryStatusMixin):
"""Story generation response."""
id: int
title: str
@@ -37,7 +43,7 @@ class StoryResponse(BaseModel):
universe_id: str | None = None
class StoryListItem(BaseModel):
class StoryListItem(StoryStatusMixin):
"""Story list item."""
id: int
@@ -47,8 +53,8 @@ class StoryListItem(BaseModel):
mode: str
class FullStoryResponse(BaseModel):
"""完整故事响应(含图片和音频状态)。"""
class FullStoryResponse(StoryStatusMixin):
"""Full story response with asset status."""
id: int
title: str
@@ -62,22 +68,19 @@ class FullStoryResponse(BaseModel):
universe_id: str | None = None
# ==================== 绘本模型 ====================
class StorybookRequest(BaseModel):
"""Storybook 生成请求。"""
"""Storybook generation request."""
keywords: str = Field(..., min_length=1, max_length=200)
page_count: int = Field(default=6, ge=4, le=12)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
generate_images: bool = Field(default=False, description="是否同时生成插图")
generate_images: bool = Field(default=False, description="Whether to generate images too.")
child_profile_id: str | None = None
universe_id: str | None = None
class StorybookPageResponse(BaseModel):
"""故事书单页响应。"""
"""One storybook page."""
page_number: int
text: str
@@ -85,8 +88,8 @@ class StorybookPageResponse(BaseModel):
image_url: str | None = None
class StorybookResponse(BaseModel):
"""故事书响应。"""
class StorybookResponse(StoryStatusMixin):
"""Storybook generation response."""
id: int | None = None
title: str
@@ -97,10 +100,29 @@ class StorybookResponse(BaseModel):
cover_url: str | None = None
# ==================== 成就模型 ====================
class StoryDetailResponse(StoryStatusMixin):
"""Story detail response for both stories and storybooks."""
id: int
title: str
story_text: str | None = None
pages: list[StorybookPageResponse] | None = None
cover_prompt: str | None
image_url: str | None
mode: str
child_profile_id: str | None = None
universe_id: str | None = None
class StoryImageResponse(StoryStatusMixin):
"""Cover image generation response."""
image_url: str | None
class AchievementItem(BaseModel):
"""Achievement item returned for a story."""
type: str
description: str
obtained_at: str | None = None

View File

@@ -0,0 +1,38 @@
"""Story audio cache storage helpers."""
from __future__ import annotations
from pathlib import Path
from app.core.config import settings
def build_story_audio_path(story_id: int) -> str:
"""Build the cache path for a story audio file."""
return str(Path(settings.story_audio_cache_dir) / f"story-{story_id}.mp3")
def audio_cache_exists(audio_path: str | None) -> bool:
"""Whether the cached audio file exists on disk."""
return bool(audio_path) and Path(audio_path).is_file()
def read_audio_cache(audio_path: str) -> bytes:
"""Read cached story audio bytes."""
return Path(audio_path).read_bytes()
def write_story_audio_cache(story_id: int, audio_data: bytes) -> str:
"""Persist story audio and return the saved file path."""
final_path = Path(build_story_audio_path(story_id))
final_path.parent.mkdir(parents=True, exist_ok=True)
temp_path = final_path.with_suffix(".tmp")
temp_path.write_bytes(audio_data)
temp_path.replace(final_path)
return str(final_path)

View File

@@ -1,12 +1,9 @@
"""Story business logic service."""
import asyncio
import json
import uuid
from typing import Literal
from fastapi import HTTPException
from sqlalchemy import select, desc
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
@@ -15,90 +12,151 @@ from app.db.models import ChildProfile, Story, StoryUniverse
from app.schemas.story_schemas import (
GenerateRequest,
StorybookRequest,
FullStoryResponse,
StorybookResponse,
FullStoryResponse,
StorybookResponse,
StorybookPageResponse,
AchievementItem,
)
from app.services.audio_storage import (
audio_cache_exists,
read_audio_cache,
write_story_audio_cache,
)
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
generate_story_content,
generate_image,
generate_storybook,
)
from app.services.story_status import (
StoryAssetStatus,
sync_story_status,
)
from app.tasks.achievements import extract_story_achievements
logger = get_logger(__name__)
async def validate_profile_and_universe(
profile_id: str | None,
universe_id: str | None,
user_id: str,
db: AsyncSession,
) -> tuple[str | None, str | None]:
"""Validate child profile and universe ownership/relationship."""
if not profile_id and not universe_id:
return None, None
def _build_storybook_error_message(
*,
cover_failed: bool,
failed_pages: list[int],
) -> str | None:
"""Summarize storybook image generation errors for the latest attempt."""
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user_id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user_id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
return profile_id, universe_id
parts: list[str] = []
if cover_failed:
parts.append("封面生成失败")
if failed_pages:
pages = "".join(str(page) for page in sorted(failed_pages))
parts.append(f"{pages} 页插图生成失败")
return "".join(parts) if parts else None
def _resolve_storybook_image_status(
*,
generate_images: bool,
cover_prompt: str | None,
cover_url: str | None,
pages_data: list[dict],
) -> StoryAssetStatus:
"""Resolve the persisted image status for a storybook."""
if not generate_images:
return StoryAssetStatus.NOT_REQUESTED
expected_assets = 0
ready_assets = 0
if cover_prompt or cover_url:
expected_assets += 1
if cover_url:
ready_assets += 1
for page in pages_data:
if not page.get("image_prompt") and not page.get("image_url"):
continue
expected_assets += 1
if page.get("image_url"):
ready_assets += 1
if expected_assets == 0:
return StoryAssetStatus.NOT_REQUESTED
if ready_assets == expected_assets:
return StoryAssetStatus.READY
return StoryAssetStatus.FAILED
async def validate_profile_and_universe(
profile_id: str | None,
universe_id: str | None,
user_id: str,
db: AsyncSession,
) -> tuple[str | None, str | None]:
"""Validate child profile and universe ownership/relationship."""
if not profile_id and not universe_id:
return None, None
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user_id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user_id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
return profile_id, universe_id
async def generate_and_save_story(
request: GenerateRequest,
user_id: str,
db: AsyncSession,
) -> Story:
"""Generate generic story content and save to DB."""
# 1. Validate
profile_id, universe_id = await validate_profile_and_universe(
request.child_profile_id, request.universe_id, user_id, db
)
# 2. Build Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# 3. Generate
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as exc:
raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc
# 4. Save
"""Generate generic story content and save to DB."""
# 1. Validate
profile_id, universe_id = await validate_profile_and_universe(
request.child_profile_id, request.universe_id, user_id, db
)
# 2. Build Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# 3. Generate
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as exc:
raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc
# 4. Save
story = Story(
user_id=user_id,
child_profile_id=profile_id,
@@ -108,170 +166,209 @@ async def generate_and_save_story(
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
sync_story_status(
story,
image_status=StoryAssetStatus.NOT_REQUESTED,
audio_status=StoryAssetStatus.NOT_REQUESTED,
last_error=None,
)
db.add(story)
await db.commit()
await db.refresh(story)
# 5. Trigger Async Tasks
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return story
# 5. Trigger Async Tasks
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return story
async def generate_full_story_service(
request: GenerateRequest,
user_id: str,
db: AsyncSession,
) -> FullStoryResponse:
"""Generate story with parallel image generation."""
# 1. Generate text part
# We can reuse logic or call generate_story_content directly if we want finer control
# reusing generate_and_save_story to ensure consistency (it handles validation + saving)
story = await generate_and_save_story(request, user_id, db)
# 2. Generate Image (Parallel/Async step in this flow)
"""Generate story with parallel image generation."""
# 1. Generate text part
# We can reuse logic or call generate_story_content directly if we want finer control
# reusing generate_and_save_story to ensure consistency (it handles validation + saving)
story = await generate_and_save_story(request, user_id, db)
# 2. Generate Image (Parallel/Async step in this flow)
image_url: str | None = None
errors: dict[str, str | None] = {}
if story.cover_prompt:
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
try:
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
sync_story_status(
story,
image_status=StoryAssetStatus.READY,
)
await db.commit()
except Exception as exc:
errors["image"] = str(exc)
sync_story_status(
story,
image_status=StoryAssetStatus.FAILED,
last_error=str(exc),
)
await db.commit()
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
return FullStoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=image_url,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=image_url,
audio_ready=False,
mode=story.mode,
errors=errors,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
generation_status=story.generation_status,
image_status=story.image_status,
audio_status=story.audio_status,
last_error=story.last_error,
)
async def generate_storybook_service(
request: StorybookRequest,
user_id: str,
db: AsyncSession,
) -> StorybookResponse:
"""Generate storybook with parallel image generation for pages."""
# 1. Validate
profile_id, universe_id = await validate_profile_and_universe(
request.child_profile_id, request.universe_id, user_id, db
)
logger.info(
"storybook_request",
user_id=user_id,
keywords=request.keywords,
page_count=request.page_count,
profile_id=profile_id,
universe_id=universe_id,
)
# 2. Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# 3. Generate Text Structure
try:
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("storybook_generation_failed", error=str(e))
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
"""Generate storybook with parallel image generation for pages."""
# 1. Validate
profile_id, universe_id = await validate_profile_and_universe(
request.child_profile_id, request.universe_id, user_id, db
)
logger.info(
"storybook_request",
user_id=user_id,
keywords=request.keywords,
page_count=request.page_count,
profile_id=profile_id,
universe_id=universe_id,
)
# 2. Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# 3. Generate Text Structure
try:
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("storybook_generation_failed", error=str(e))
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
# 4. Parallel Image Generation
final_cover_url = storybook.cover_url
cover_failed = False
failed_pages: list[int] = []
if request.generate_images:
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
tasks = []
# Cover Task
async def _gen_cover():
nonlocal cover_failed
if storybook.cover_prompt and not storybook.cover_url:
try:
return await generate_image(storybook.cover_prompt, db=db)
except Exception as e:
logger.warning("cover_gen_failed", error=str(e))
except Exception as exc:
cover_failed = True
logger.warning("cover_gen_failed", error=str(exc))
return storybook.cover_url
tasks.append(_gen_cover())
# Page Tasks
async def _gen_page(page):
if page.image_prompt and not page.image_url:
try:
url = await generate_image(page.image_prompt, db=db)
page.image_url = url
except Exception as e:
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
page.image_url = await generate_image(page.image_prompt, db=db)
except Exception as exc:
failed_pages.append(page.page_number)
logger.warning("page_gen_failed", page=page.page_number, error=str(exc))
for page in storybook.pages:
tasks.append(_gen_page(page))
# Execute
results = await asyncio.gather(*tasks, return_exceptions=True)
# Update cover result
cover_res = results[0]
if isinstance(cover_res, str):
final_cover_url = cover_res
logger.info("storybook_parallel_generation_complete")
# 5. Save to DB
pages_data = [
{
"page_number": p.page_number,
"text": p.text,
"image_prompt": p.image_prompt,
"image_url": p.image_url,
}
for p in storybook.pages
]
# 5. Save to DB
pages_data = [
{
"page_number": p.page_number,
"text": p.text,
"image_prompt": p.image_prompt,
"image_url": p.image_url,
}
for p in storybook.pages
]
story = Story(
user_id=user_id,
child_profile_id=profile_id,
universe_id=universe_id,
title=storybook.title,
mode="storybook",
pages=pages_data,
mode="storybook",
pages=pages_data,
story_text=None,
cover_prompt=storybook.cover_prompt,
image_url=final_cover_url,
)
sync_story_status(
story,
image_status=_resolve_storybook_image_status(
generate_images=request.generate_images,
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
pages_data=pages_data,
),
audio_status=StoryAssetStatus.NOT_REQUESTED,
last_error=_build_storybook_error_message(
cover_failed=cover_failed,
failed_pages=failed_pages,
),
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# 6. Build Response
response_pages = [
StorybookPageResponse(
page_number=p["page_number"],
text=p["text"],
image_prompt=p["image_prompt"],
image_url=p.get("image_url"),
)
for p in pages_data
]
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# 6. Build Response
response_pages = [
StorybookPageResponse(
page_number=p["page_number"],
text=p["text"],
image_prompt=p["image_prompt"],
image_url=p.get("image_url"),
)
for p in pages_data
]
return StorybookResponse(
id=story.id,
title=storybook.title,
@@ -280,155 +377,209 @@ async def generate_storybook_service(
pages=response_pages,
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
generation_status=story.generation_status,
image_status=story.image_status,
audio_status=story.audio_status,
last_error=story.last_error,
)
# ==================== Missing Endpoints Logic (for Issue #5) ====================
async def list_stories(
user_id: str,
limit: int,
offset: int,
db: AsyncSession,
) -> list[Story]:
"""List stories for user."""
result = await db.execute(
select(Story)
.where(Story.user_id == user_id)
.order_by(desc(Story.created_at))
.offset(offset)
.limit(limit)
)
return result.scalars().all()
async def get_story_detail(
story_id: int,
user_id: str,
db: AsyncSession,
) -> Story:
"""Get story detail."""
result = await db.execute(
select(Story).where(Story.id == story_id, Story.user_id == user_id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
return story
async def delete_story(
story_id: int,
user_id: str,
db: AsyncSession,
) -> None:
"""Delete a story."""
story = await get_story_detail(story_id, user_id, db)
await db.delete(story)
await db.commit()
# ==================== Missing Endpoints Logic (for Issue #5) ====================
async def list_stories(
user_id: str,
limit: int,
offset: int,
db: AsyncSession,
) -> list[Story]:
"""List stories for user."""
result = await db.execute(
select(Story)
.where(Story.user_id == user_id)
.order_by(desc(Story.created_at))
.offset(offset)
.limit(limit)
)
return result.scalars().all()
async def get_story_detail(
story_id: int,
user_id: str,
db: AsyncSession,
) -> Story:
"""Get story detail."""
result = await db.execute(
select(Story).where(Story.id == story_id, Story.user_id == user_id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
return story
async def delete_story(
story_id: int,
user_id: str,
db: AsyncSession,
) -> None:
"""Delete a story."""
story = await get_story_detail(story_id, user_id, db)
await db.delete(story)
await db.commit()
async def create_story_from_result(
result, # StoryOutput
user_id: str,
profile_id: str | None,
universe_id: str | None,
db: AsyncSession,
) -> Story:
"""Save a generated story to DB (helper for stream endpoint)."""
story = Story(
user_id=user_id,
child_profile_id=profile_id,
universe_id=universe_id,
db: AsyncSession,
) -> Story:
"""Save a generated story to DB (helper for stream endpoint)."""
story = Story(
user_id=user_id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
sync_story_status(
story,
image_status=StoryAssetStatus.NOT_REQUESTED,
audio_status=StoryAssetStatus.NOT_REQUESTED,
last_error=None,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return story
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return story
async def generate_story_cover(
story_id: int,
user_id: str,
db: AsyncSession,
) -> str:
"""Generate cover image for an existing story."""
story = await get_story_detail(story_id, user_id, db)
"""Generate cover image for an existing story."""
story = await get_story_detail(story_id, user_id, db)
if not story.cover_prompt:
raise HTTPException(status_code=400, detail="Story has no cover prompt")
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
try:
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
sync_story_status(
story,
image_status=StoryAssetStatus.READY,
)
await db.commit()
return image_url
except Exception as e:
sync_story_status(
story,
image_status=StoryAssetStatus.FAILED,
last_error=str(e),
)
await db.commit()
logger.error("cover_generation_failed", story_id=story_id, error=str(e))
raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")
async def generate_story_audio(
story_id: int,
user_id: str,
db: AsyncSession,
) -> bytes:
"""Generate audio for a story."""
story = await get_story_detail(story_id, user_id, db)
"""Generate audio for a story."""
story = await get_story_detail(story_id, user_id, db)
if not story.story_text:
raise HTTPException(status_code=400, detail="Story has no text")
# TODO: Check if audio is already cached/saved?
# For now, generate on the fly via provider
if story.audio_path and audio_cache_exists(story.audio_path):
if story.audio_status != StoryAssetStatus.READY.value:
sync_story_status(story, audio_status=StoryAssetStatus.READY)
await db.commit()
return read_audio_cache(story.audio_path)
if story.audio_path and not audio_cache_exists(story.audio_path):
logger.warning(
"story_audio_cache_missing",
story_id=story_id,
audio_path=story.audio_path,
)
story.audio_path = None
if story.audio_status == StoryAssetStatus.READY.value:
sync_story_status(story, audio_status=StoryAssetStatus.NOT_REQUESTED)
await db.commit()
from app.services.provider_router import text_to_speech
sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
await db.commit()
try:
audio_data = await text_to_speech(story.story_text, db=db)
story.audio_path = write_story_audio_cache(story.id, audio_data)
sync_story_status(
story,
audio_status=StoryAssetStatus.READY,
)
await db.commit()
return audio_data
except Exception as e:
story.audio_path = None
sync_story_status(
story,
audio_status=StoryAssetStatus.FAILED,
last_error=str(e),
)
await db.commit()
logger.error("audio_generation_failed", story_id=story_id, error=str(e))
raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
async def get_story_achievements(
story_id: int,
user_id: str,
db: AsyncSession,
) -> list[AchievementItem]:
"""Get achievements unlocked by a specific story."""
result = await db.execute(
select(Story)
.options(joinedload(Story.story_universe))
.where(Story.id == story_id, Story.user_id == user_id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
if not story.universe_id or not story.story_universe:
return []
universe = story.story_universe
if not universe.achievements:
return []
results = []
for ach in universe.achievements:
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
results.append(AchievementItem(
type=ach.get("type", "Unknown"),
description=ach.get("description", ""),
obtained_at=ach.get("obtained_at")
))
return results
async def get_story_achievements(
story_id: int,
user_id: str,
db: AsyncSession,
) -> list[AchievementItem]:
"""Get achievements unlocked by a specific story."""
result = await db.execute(
select(Story)
.options(joinedload(Story.story_universe))
.where(Story.id == story_id, Story.user_id == user_id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
if not story.universe_id or not story.story_universe:
return []
universe = story.story_universe
if not universe.achievements:
return []
results = []
for ach in universe.achievements:
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
results.append(AchievementItem(
type=ach.get("type", "Unknown"),
description=ach.get("description", ""),
obtained_at=ach.get("obtained_at")
))
return results

View File

@@ -0,0 +1,112 @@
"""Story generation status helpers."""
from __future__ import annotations
from enum import Enum
from typing import Protocol
class StoryGenerationStatus(str, Enum):
"""Overall story generation lifecycle."""
NARRATIVE_READY = "narrative_ready"
ASSETS_GENERATING = "assets_generating"
COMPLETED = "completed"
DEGRADED_COMPLETED = "degraded_completed"
FAILED = "failed"
class StoryAssetStatus(str, Enum):
"""Asset generation state for image and audio."""
NOT_REQUESTED = "not_requested"
GENERATING = "generating"
READY = "ready"
FAILED = "failed"
class StoryLike(Protocol):
"""Protocol for story-like objects used by status helpers."""
story_text: str | None
pages: list[dict] | None
generation_status: str
image_status: str
audio_status: str
last_error: str | None
_ERROR_UNSET = object()
def _normalize_asset_status(value: str | None) -> StoryAssetStatus:
if not value:
return StoryAssetStatus.NOT_REQUESTED
try:
return StoryAssetStatus(value)
except ValueError:
return StoryAssetStatus.NOT_REQUESTED
def has_narrative_content(story: StoryLike) -> bool:
"""Whether the story already has readable content."""
return bool(story.story_text) or bool(story.pages)
def resolve_story_generation_status(story: StoryLike) -> StoryGenerationStatus:
"""Derive the overall status from narrative and asset states."""
if not has_narrative_content(story):
return StoryGenerationStatus.FAILED
image_status = _normalize_asset_status(story.image_status)
audio_status = _normalize_asset_status(story.audio_status)
if StoryAssetStatus.GENERATING in (image_status, audio_status):
return StoryGenerationStatus.ASSETS_GENERATING
if StoryAssetStatus.FAILED in (image_status, audio_status):
return StoryGenerationStatus.DEGRADED_COMPLETED
if (
image_status == StoryAssetStatus.NOT_REQUESTED
and audio_status == StoryAssetStatus.NOT_REQUESTED
):
return StoryGenerationStatus.NARRATIVE_READY
return StoryGenerationStatus.COMPLETED
def has_failed_assets(story: StoryLike) -> bool:
"""Whether any persisted asset is still in a failed state."""
image_status = _normalize_asset_status(story.image_status)
audio_status = _normalize_asset_status(story.audio_status)
return StoryAssetStatus.FAILED in (image_status, audio_status)
def sync_story_status(
story: StoryLike,
*,
image_status: StoryAssetStatus | None = None,
audio_status: StoryAssetStatus | None = None,
last_error: str | None | object = _ERROR_UNSET,
) -> None:
"""Update asset statuses and refresh overall generation status."""
if image_status is not None:
story.image_status = image_status.value
if audio_status is not None:
story.audio_status = audio_status.value
if last_error is not _ERROR_UNSET:
story.last_error = last_error
generation_status = resolve_story_generation_status(story)
story.generation_status = generation_status.value
if last_error is _ERROR_UNSET and not has_failed_assets(story):
story.last_error = None

View File

@@ -1,4 +1,4 @@
"""测试配置和 fixtures"""
"""Pytest fixtures for backend tests."""
import os
from collections.abc import AsyncGenerator
@@ -11,6 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing")
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
from app.core.config import settings
from app.core.security import create_access_token
from app.db.database import get_db
from app.db.models import Base, Story, User
@@ -19,7 +20,8 @@ from app.main import app
@pytest.fixture
async def async_engine():
"""创建内存数据库引擎。"""
"""Create an in-memory database engine."""
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
@@ -29,7 +31,8 @@ async def async_engine():
@pytest.fixture
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
"""创建数据库会话。"""
"""Create a database session."""
session_factory = async_sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
@@ -39,7 +42,8 @@ async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
@pytest.fixture
async def test_user(db_session: AsyncSession) -> User:
"""创建测试用户。"""
"""Create a test user."""
user = User(
id="github:12345",
name="Test User",
@@ -54,13 +58,74 @@ async def test_user(db_session: AsyncSession) -> User:
@pytest.fixture
async def test_story(db_session: AsyncSession, test_user: User) -> Story:
"""创建测试故事。"""
"""Create a plain generated story."""
story = Story(
user_id=test_user.id,
title="测试故事",
story_text="从前有一只小兔子...",
story_text="从前有一只小兔子",
cover_prompt="A cute rabbit in a forest",
mode="generated",
generation_status="narrative_ready",
image_status="not_requested",
audio_status="not_requested",
)
db_session.add(story)
await db_session.commit()
await db_session.refresh(story)
return story
@pytest.fixture
async def storybook_story(db_session: AsyncSession, test_user: User) -> Story:
"""Create a storybook-mode story."""
story = Story(
user_id=test_user.id,
title="森林绘本冒险",
story_text=None,
pages=[
{
"page_number": 1,
"text": "小兔子走进了会发光的森林。",
"image_prompt": "A glowing forest with a curious rabbit",
"image_url": "https://example.com/page-1.png",
},
{
"page_number": 2,
"text": "它遇见了一位会唱歌的萤火虫朋友。",
"image_prompt": "A rabbit meeting a singing firefly",
"image_url": None,
},
],
cover_prompt="A magical forest storybook cover",
image_url="https://example.com/storybook-cover.png",
mode="storybook",
generation_status="degraded_completed",
image_status="failed",
audio_status="not_requested",
last_error="第 2 页插图生成失败",
)
db_session.add(story)
await db_session.commit()
await db_session.refresh(story)
return story
@pytest.fixture
async def degraded_story_with_text(db_session: AsyncSession, test_user: User) -> Story:
"""Create a readable story whose image generation already failed."""
story = Story(
user_id=test_user.id,
title="部分完成的测试故事",
story_text="从前有一只小兔子继续冒险。",
cover_prompt="A rabbit under the moon",
mode="generated",
generation_status="degraded_completed",
image_status="failed",
audio_status="not_requested",
last_error="封面生成失败",
)
db_session.add(story)
await db_session.commit()
@@ -70,13 +135,14 @@ async def test_story(db_session: AsyncSession, test_user: User) -> Story:
@pytest.fixture
def auth_token(test_user: User) -> str:
"""生成测试用户的 JWT token"""
"""Create a JWT token for the test user."""
return create_access_token({"sub": test_user.id})
@pytest.fixture
def client(db_session: AsyncSession) -> TestClient:
"""创建测试客户端。"""
"""Create a test client."""
async def override_get_db():
yield db_session
@@ -89,35 +155,45 @@ def client(db_session: AsyncSession) -> TestClient:
@pytest.fixture
def auth_client(client: TestClient, auth_token: str) -> TestClient:
"""带认证的测试客户端。"""
"""Create an authenticated test client."""
client.cookies.set("access_token", auth_token)
return client
@pytest.fixture(autouse=True)
def bypass_rate_limit():
"""默认绕过限流,让非限流测试正常运行。"""
"""Bypass rate limiting in most tests."""
with patch("app.core.rate_limiter.get_redis", new_callable=AsyncMock) as mock_redis:
# 创建一个模拟的 Redis 客户端,所有操作返回安全默认值
redis_instance = AsyncMock()
redis_instance.incr.return_value = 1 # 始终返回 1 (不触发限流)
redis_instance.incr.return_value = 1
redis_instance.expire.return_value = True
redis_instance.get.return_value = None # 无锁定记录
redis_instance.get.return_value = None
redis_instance.ttl.return_value = 0
redis_instance.delete.return_value = 1
mock_redis.return_value = redis_instance
yield redis_instance
@pytest.fixture(autouse=True)
def isolated_story_audio_cache(tmp_path, monkeypatch):
"""Use an isolated directory for cached story audio files."""
monkeypatch.setattr(settings, "story_audio_cache_dir", str(tmp_path / "audio"))
yield
@pytest.fixture
def mock_text_provider():
"""Mock 文本生成适配器 API 调用。"""
"""Mock text generation."""
from app.services.adapters.text.models import StoryOutput
mock_result = StoryOutput(
mode="generated",
title="小兔子的冒险",
story_text="从前有一只小兔子...",
story_text="从前有一只小兔子",
cover_prompt_suggestion="A cute rabbit",
)
@@ -128,7 +204,8 @@ def mock_text_provider():
@pytest.fixture
def mock_image_provider():
"""Mock 图像生成。"""
"""Mock image generation."""
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock:
mock.return_value = "https://example.com/image.png"
yield mock
@@ -136,7 +213,8 @@ def mock_image_provider():
@pytest.fixture
def mock_tts_provider():
"""Mock TTS。"""
"""Mock text-to-speech generation."""
with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock:
mock.return_value = b"fake-audio-bytes"
yield mock
@@ -144,7 +222,8 @@ def mock_tts_provider():
@pytest.fixture
def mock_all_providers(mock_text_provider, mock_image_provider, mock_tts_provider):
"""Mock 所有 AI 供应商。"""
"""Group all mocked providers."""
return {
"text_primary": mock_text_provider,
"image_primary": mock_image_provider,

View File

@@ -1,26 +1,41 @@
"""故事 API 测试。"""
"""Tests for story-related API endpoints."""
from pathlib import Path
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from app.core.config import settings
from app.services.adapters.storybook.primary import Storybook, StorybookPage
# ── 注意 ──────────────────────────────────────────────────────────────────────
# 以下路由尚未实现 (stories.py 中没有对应端点),相关测试标记为 skip:
# GET /api/stories (列表)
# GET /api/stories/{id} (详情)
# DELETE /api/stories/{id} (删除)
# POST /api/image/generate/{id} (封面图片生成)
# GET /api/audio/{id} (音频)
# 实现后请取消 skip 标记。
def build_storybook_output() -> Storybook:
"""Create a reusable mocked storybook payload."""
return Storybook(
title="森林里的发光冒险",
main_character="小兔子露露",
art_style="温暖水彩",
cover_prompt="A glowing forest storybook cover",
pages=[
StorybookPage(
page_number=1,
text="露露第一次走进会发光的森林。",
image_prompt="Lulu entering a glowing forest",
),
StorybookPage(
page_number=2,
text="她遇到了一只会唱歌的萤火虫。",
image_prompt="Lulu meeting a singing firefly",
),
],
)
class TestStoryGenerate:
"""故事生成测试。"""
"""Tests for basic story generation."""
def test_generate_without_auth(self, client: TestClient):
"""未登录时生成故事。"""
response = client.post(
"/api/stories/generate",
json={"type": "keywords", "data": "小兔子, 森林"},
@@ -28,7 +43,6 @@ class TestStoryGenerate:
assert response.status_code == 401
def test_generate_with_empty_data(self, auth_client: TestClient):
"""空数据生成故事。"""
response = auth_client.post(
"/api/stories/generate",
json={"type": "keywords", "data": ""},
@@ -36,7 +50,6 @@ class TestStoryGenerate:
assert response.status_code == 422
def test_generate_with_invalid_type(self, auth_client: TestClient):
"""无效类型生成故事。"""
response = auth_client.post(
"/api/stories/generate",
json={"type": "invalid", "data": "test"},
@@ -44,7 +57,6 @@ class TestStoryGenerate:
assert response.status_code == 422
def test_generate_story_success(self, auth_client: TestClient, mock_text_provider):
"""成功生成故事。"""
response = auth_client.post(
"/api/stories/generate",
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
@@ -55,82 +67,96 @@ class TestStoryGenerate:
assert "title" in data
assert "story_text" in data
assert data["mode"] == "generated"
assert data["generation_status"] == "narrative_ready"
assert data["image_status"] == "not_requested"
assert data["audio_status"] == "not_requested"
assert data["last_error"] is None
class TestStoryList:
"""故事列表测试。"""
"""Tests for story listing."""
def test_list_without_auth(self, client: TestClient):
"""未登录时获取列表。"""
response = client.get("/api/stories")
assert response.status_code == 401
def test_list_empty(self, auth_client: TestClient):
"""空列表。"""
response = auth_client.get("/api/stories")
assert response.status_code == 200
assert response.json() == []
def test_list_with_stories(self, auth_client: TestClient, test_story):
"""有故事时获取列表。"""
response = auth_client.get("/api/stories")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0]["id"] == test_story.id
assert data[0]["title"] == test_story.title
assert data[0]["generation_status"] == "narrative_ready"
assert data[0]["image_status"] == "not_requested"
assert data[0]["audio_status"] == "not_requested"
def test_list_pagination(self, auth_client: TestClient, test_story):
"""分页测试。"""
response = auth_client.get("/api/stories?limit=1&offset=0")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert len(response.json()) == 1
response = auth_client.get("/api/stories?limit=1&offset=1")
assert response.status_code == 200
data = response.json()
assert len(data) == 0
assert len(response.json()) == 0
class TestStoryDetail:
"""故事详情测试。"""
"""Tests for story detail retrieval."""
def test_get_story_without_auth(self, client: TestClient, test_story):
"""未登录时获取详情。"""
response = client.get(f"/api/stories/{test_story.id}")
assert response.status_code == 401
def test_get_story_not_found(self, auth_client: TestClient):
"""故事不存在。"""
response = auth_client.get("/api/stories/99999")
assert response.status_code == 404
def test_get_story_success(self, auth_client: TestClient, test_story):
"""成功获取详情。"""
response = auth_client.get(f"/api/stories/{test_story.id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == test_story.id
assert data["title"] == test_story.title
assert data["story_text"] == test_story.story_text
assert data["generation_status"] == "narrative_ready"
assert data["image_status"] == "not_requested"
assert data["audio_status"] == "not_requested"
assert data["last_error"] is None
def test_get_storybook_success(self, auth_client: TestClient, storybook_story):
response = auth_client.get(f"/api/stories/{storybook_story.id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == storybook_story.id
assert data["mode"] == "storybook"
assert data["story_text"] is None
assert len(data["pages"]) == 2
assert data["pages"][0]["page_number"] == 1
assert data["image_url"] == "https://example.com/storybook-cover.png"
assert data["generation_status"] == "degraded_completed"
assert data["image_status"] == "failed"
assert data["audio_status"] == "not_requested"
assert "第 2 页" in data["last_error"]
class TestStoryDelete:
"""故事删除测试。"""
"""Tests for story deletion."""
def test_delete_without_auth(self, client: TestClient, test_story):
"""未登录时删除。"""
response = client.delete(f"/api/stories/{test_story.id}")
assert response.status_code == 401
def test_delete_not_found(self, auth_client: TestClient):
"""删除不存在的故事。"""
response = auth_client.delete("/api/stories/99999")
assert response.status_code == 404
def test_delete_success(self, auth_client: TestClient, test_story):
"""成功删除故事。"""
response = auth_client.delete(f"/api/stories/{test_story.id}")
assert response.status_code == 200
assert response.json()["message"] == "Deleted"
@@ -140,11 +166,14 @@ class TestStoryDelete:
class TestRateLimit:
"""Rate limit 测试。"""
"""Tests for story generation rate limiting."""
def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, mock_text_provider, bypass_rate_limit):
"""正常请求不触发限流。"""
# bypass_rate_limit 默认 incr 返回 1不触发限流
def test_rate_limit_allows_normal_requests(
self,
auth_client: TestClient,
mock_text_provider,
bypass_rate_limit,
):
for _ in range(3):
response = auth_client.post(
"/api/stories/generate",
@@ -152,9 +181,11 @@ class TestRateLimit:
)
assert response.status_code == 200
def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, bypass_rate_limit):
"""超限请求被阻止。"""
# 让 incr 返回超限值 (> RATE_LIMIT_REQUESTS)
def test_rate_limit_blocks_excess_requests(
self,
auth_client: TestClient,
bypass_rate_limit,
):
bypass_rate_limit.incr.return_value = 11
response = auth_client.post(
@@ -166,52 +197,118 @@ class TestRateLimit:
class TestImageGenerate:
"""封面图片生成测试。"""
"""Tests for cover generation endpoint."""
def test_generate_image_without_auth(self, client: TestClient, test_story):
"""未登录时生成图片。"""
response = client.post(f"/api/image/generate/{test_story.id}")
assert response.status_code == 401
def test_generate_image_not_found(self, auth_client: TestClient):
"""故事不存在。"""
response = auth_client.post("/api/image/generate/99999")
assert response.status_code == 404
class TestAudio:
"""语音朗读测试。"""
"""Tests for story audio endpoint."""
def test_get_audio_without_auth(self, client: TestClient, test_story):
"""未登录时获取音频。"""
response = client.get(f"/api/audio/{test_story.id}")
assert response.status_code == 401
def test_get_audio_not_found(self, auth_client: TestClient):
"""故事不存在。"""
response = auth_client.get("/api/audio/99999")
assert response.status_code == 404
def test_get_audio_success(self, auth_client: TestClient, test_story, mock_tts_provider):
"""成功获取音频。"""
def test_get_audio_success(
self,
auth_client: TestClient,
test_story,
mock_tts_provider,
):
response = auth_client.get(f"/api/audio/{test_story.id}")
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert response.content == b"fake-audio-bytes"
cached_audio_path = Path(settings.story_audio_cache_dir) / f"story-{test_story.id}.mp3"
assert cached_audio_path.is_file()
second_response = auth_client.get(f"/api/audio/{test_story.id}")
assert second_response.status_code == 200
assert second_response.content == b"fake-audio-bytes"
mock_tts_provider.assert_awaited_once()
detail_response = auth_client.get(f"/api/stories/{test_story.id}")
detail = detail_response.json()
assert detail["audio_status"] == "ready"
assert detail["generation_status"] == "completed"
assert detail["last_error"] is None
def test_get_audio_regenerates_when_cache_file_is_missing(
self,
auth_client: TestClient,
test_story,
mock_tts_provider,
):
first_response = auth_client.get(f"/api/audio/{test_story.id}")
assert first_response.status_code == 200
cached_audio_path = Path(settings.story_audio_cache_dir) / f"story-{test_story.id}.mp3"
cached_audio_path.unlink()
mock_tts_provider.reset_mock()
second_response = auth_client.get(f"/api/audio/{test_story.id}")
assert second_response.status_code == 200
assert second_response.content == b"fake-audio-bytes"
assert cached_audio_path.is_file()
mock_tts_provider.assert_awaited_once()
def test_get_audio_failure_updates_status(self, auth_client: TestClient, test_story):
with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock_tts:
mock_tts.side_effect = Exception("TTS provider timeout")
response = auth_client.get(f"/api/audio/{test_story.id}")
assert response.status_code == 500
detail_response = auth_client.get(f"/api/stories/{test_story.id}")
detail = detail_response.json()
assert detail["audio_status"] == "failed"
assert detail["generation_status"] == "degraded_completed"
assert "TTS provider timeout" in detail["last_error"]
def test_get_audio_success_preserves_existing_image_error(
self,
auth_client: TestClient,
degraded_story_with_text,
mock_tts_provider,
):
response = auth_client.get(f"/api/audio/{degraded_story_with_text.id}")
assert response.status_code == 200
assert response.content == b"fake-audio-bytes"
mock_tts_provider.assert_awaited_once()
detail_response = auth_client.get(f"/api/stories/{degraded_story_with_text.id}")
detail = detail_response.json()
assert detail["audio_status"] == "ready"
assert detail["generation_status"] == "degraded_completed"
assert detail["last_error"] == "封面生成失败"
class TestGenerateFull:
"""完整故事生成测试(/api/stories/generate/full"""
"""Tests for complete story generation."""
def test_generate_full_without_auth(self, client: TestClient):
"""未登录时生成完整故事。"""
response = client.post(
"/api/stories/generate/full",
json={"type": "keywords", "data": "小兔子, 森林"},
)
assert response.status_code == 401
def test_generate_full_success(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
"""成功生成完整故事(含图片)。"""
def test_generate_full_success(
self,
auth_client: TestClient,
mock_text_provider,
mock_image_provider,
):
response = auth_client.post(
"/api/stories/generate/full",
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
@@ -223,11 +320,14 @@ class TestGenerateFull:
assert "story_text" in data
assert data["mode"] == "generated"
assert data["image_url"] == "https://example.com/image.png"
assert data["audio_ready"] is False # 音频按需生成
assert data["audio_ready"] is False
assert data["errors"] == {}
assert data["generation_status"] == "completed"
assert data["image_status"] == "ready"
assert data["audio_status"] == "not_requested"
assert data["last_error"] is None
def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider):
"""图片生成失败时返回部分成功。"""
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_img:
mock_img.side_effect = Exception("Image API error")
response = auth_client.post(
@@ -239,9 +339,17 @@ class TestGenerateFull:
assert data["image_url"] is None
assert "image" in data["errors"]
assert "Image API error" in data["errors"]["image"]
assert data["generation_status"] == "degraded_completed"
assert data["image_status"] == "failed"
assert data["audio_status"] == "not_requested"
assert "Image API error" in data["last_error"]
def test_generate_full_with_education_theme(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
"""带教育主题生成故事。"""
def test_generate_full_with_education_theme(
self,
auth_client: TestClient,
mock_text_provider,
mock_image_provider,
):
response = auth_client.post(
"/api/stories/generate/full",
json={
@@ -257,11 +365,80 @@ class TestGenerateFull:
class TestImageGenerateSuccess:
"""封面图片生成成功测试。"""
"""Tests for successful cover generation."""
def test_generate_image_success(self, auth_client: TestClient, test_story, mock_image_provider):
"""成功生成图片。"""
def test_generate_image_success(
self,
auth_client: TestClient,
test_story,
mock_image_provider,
):
response = auth_client.post(f"/api/image/generate/{test_story.id}")
assert response.status_code == 200
data = response.json()
assert data["image_url"] == "https://example.com/image.png"
assert data["generation_status"] == "completed"
assert data["image_status"] == "ready"
assert data["audio_status"] == "not_requested"
assert data["last_error"] is None
class TestStorybookGenerate:
"""Tests for storybook generation status handling."""
def test_generate_storybook_success(self, auth_client: TestClient):
with patch("app.services.story_service.generate_storybook", new_callable=AsyncMock) as mock_storybook:
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_image:
mock_storybook.return_value = build_storybook_output()
mock_image.side_effect = [
"https://example.com/storybook-cover.png",
"https://example.com/storybook-page-1.png",
"https://example.com/storybook-page-2.png",
]
response = auth_client.post(
"/api/storybook/generate",
json={
"keywords": "森林, 发光, 友情",
"page_count": 6,
"generate_images": True,
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] is not None
assert data["generation_status"] == "completed"
assert data["image_status"] == "ready"
assert data["audio_status"] == "not_requested"
assert data["last_error"] is None
assert len(data["pages"]) == 2
assert data["cover_url"] == "https://example.com/storybook-cover.png"
def test_generate_storybook_partial_image_failure(self, auth_client: TestClient):
async def image_side_effect(prompt: str, **kwargs):
if "singing firefly" in prompt:
raise Exception("Image API error")
slug = prompt.split()[0].lower()
return f"https://example.com/{slug}.png"
with patch("app.services.story_service.generate_storybook", new_callable=AsyncMock) as mock_storybook:
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_image:
mock_storybook.return_value = build_storybook_output()
mock_image.side_effect = image_side_effect
response = auth_client.post(
"/api/storybook/generate",
json={
"keywords": "森林, 发光, 友情",
"page_count": 6,
"generate_images": True,
},
)
assert response.status_code == 200
data = response.json()
assert data["generation_status"] == "degraded_completed"
assert data["image_status"] == "failed"
assert data["audio_status"] == "not_requested"
assert "第 2 页插图生成失败" in data["last_error"]