diff --git a/backend/.gitignore b/backend/.gitignore index b6be84e..edc8ecc 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -1,27 +1,28 @@ -# Python -__pycache__/ -*.py[cod] -*$py.class -*.so -.Python -.venv/ -venv/ -ENV/ - -# IDE -.idea/ -.vscode/ -*.swp -*.swo - -# 环境变量 -.env - -# 测试 -.pytest_cache/ -.coverage -htmlcov/ - -# 其他 +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +.venv/ +venv/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# 环境变量 +.env + +# 测试 +.pytest_cache/ +.coverage +htmlcov/ + +# 其他 *.log .DS_Store +storage/ diff --git a/backend/alembic/versions/0009_add_story_generation_statuses.py b/backend/alembic/versions/0009_add_story_generation_statuses.py new file mode 100644 index 0000000..47a5f3e --- /dev/null +++ b/backend/alembic/versions/0009_add_story_generation_statuses.py @@ -0,0 +1,151 @@ +"""add story generation status fields + +Revision ID: 0009_add_story_generation_statuses +Revises: 0008_add_pages_to_stories +Create Date: 2026-04-17 + +""" + +import sqlalchemy as sa +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "0009_add_story_generation_statuses" +down_revision = "0008_add_pages_to_stories" +branch_labels = None +depends_on = None + + +stories = sa.table( + "stories", + sa.column("id", sa.Integer), + sa.column("story_text", sa.Text), + sa.column("pages", sa.JSON), + sa.column("cover_prompt", sa.Text), + sa.column("image_url", sa.String(length=500)), + sa.column("generation_status", sa.String(length=32)), + sa.column("image_status", sa.String(length=32)), + sa.column("audio_status", sa.String(length=32)), +) + + +def _resolve_image_status(row: dict) -> str: + pages = row.get("pages") or [] + + expected_assets = 0 + ready_assets = 0 + + if row.get("cover_prompt") or row.get("image_url"): + expected_assets += 1 + if row.get("image_url"): + ready_assets += 1 + + for page in pages: + if not isinstance(page, dict): + continue + if not page.get("image_prompt") and not page.get("image_url"): + continue + expected_assets += 1 + if page.get("image_url"): + ready_assets += 1 + + if expected_assets == 0: + return "not_requested" + + if ready_assets == expected_assets: + return "ready" + + return "failed" + + +def _resolve_generation_status( + *, + story_text: str | None, + pages: list[dict] | None, + image_status: str, + audio_status: str, +) -> str: + has_narrative = bool(story_text) or bool(pages) + if not has_narrative: + return "failed" + + if "generating" in {image_status, audio_status}: + return "assets_generating" + + if "failed" in {image_status, audio_status}: + return "degraded_completed" + + if image_status == "not_requested" and audio_status == "not_requested": + return "narrative_ready" + + return "completed" + + +def upgrade() -> None: + op.add_column( + "stories", + sa.Column( + "generation_status", + sa.String(length=32), + nullable=False, + server_default="narrative_ready", + ), + ) + op.add_column( + "stories", + sa.Column( + "image_status", + sa.String(length=32), + nullable=False, + server_default="not_requested", + ), + ) + op.add_column( + "stories", + sa.Column( + "audio_status", + sa.String(length=32), + nullable=False, + server_default="not_requested", + ), + ) + op.add_column("stories", sa.Column("last_error", sa.Text(), nullable=True)) + + connection = op.get_bind() + rows = connection.execute( + sa.select( + stories.c.id, + stories.c.story_text, + stories.c.pages, + stories.c.cover_prompt, + stories.c.image_url, + ) + ).mappings() + + for row in rows: + image_status = _resolve_image_status(row) + audio_status = "not_requested" + generation_status = _resolve_generation_status( + story_text=row.get("story_text"), + pages=row.get("pages"), + image_status=image_status, + audio_status=audio_status, + ) + + connection.execute( + stories.update() + .where(stories.c.id == row["id"]) + .values( + generation_status=generation_status, + image_status=image_status, + audio_status=audio_status, + ) + ) + + +def downgrade() -> None: + op.drop_column("stories", "last_error") + op.drop_column("stories", "audio_status") + op.drop_column("stories", "image_status") + op.drop_column("stories", "generation_status") diff --git a/backend/alembic/versions/0010_add_story_audio_cache_path.py b/backend/alembic/versions/0010_add_story_audio_cache_path.py new file mode 100644 index 0000000..2bb09c9 --- /dev/null +++ b/backend/alembic/versions/0010_add_story_audio_cache_path.py @@ -0,0 +1,25 @@ +"""add audio cache path to stories + +Revision ID: 0010_add_story_audio_cache_path +Revises: 0009_add_story_generation_statuses +Create Date: 2026-04-17 + +""" + +import sqlalchemy as sa +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "0010_add_story_audio_cache_path" +down_revision = "0009_add_story_generation_statuses" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column("stories", sa.Column("audio_path", sa.String(length=500), nullable=True)) + + +def downgrade() -> None: + op.drop_column("stories", "audio_path") diff --git a/backend/app/api/stories.py b/backend/app/api/stories.py index cd84d74..24419b8 100644 --- a/backend/app/api/stories.py +++ b/backend/app/api/stories.py @@ -1,27 +1,28 @@ """Story related APIs.""" -import asyncio import json import uuid from typing import AsyncGenerator -from fastapi import APIRouter, Depends, HTTPException, Request, Response -from sse_starlette.sse import EventSourceResponse -from sqlalchemy.ext.asyncio import AsyncSession - -from app.core.deps import require_user -from app.core.logging import get_logger -from app.core.rate_limiter import check_rate_limit -from app.db.database import get_db -from app.db.models import User +from fastapi import APIRouter, Depends, Response +from sse_starlette.sse import EventSourceResponse +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.deps import require_user +from app.core.logging import get_logger +from app.core.rate_limiter import check_rate_limit +from app.db.database import get_db +from app.db.models import User from app.schemas.story_schemas import ( - GenerateRequest, - StoryResponse, + AchievementItem, FullStoryResponse, + GenerateRequest, + StoryDetailResponse, + StoryImageResponse, + StoryListItem, + StoryResponse, StorybookRequest, StorybookResponse, - StoryListItem, - AchievementItem, ) from app.services import story_service from app.services.memory_service import build_enhanced_memory_context @@ -29,153 +30,202 @@ from app.services.provider_router import ( generate_story_content, generate_image, ) - -logger = get_logger(__name__) -router = APIRouter() - -RATE_LIMIT_WINDOW = 60 # seconds -RATE_LIMIT_REQUESTS = 10 - - -@router.post("/stories/generate", response_model=StoryResponse) -async def generate_story( - request: GenerateRequest, - user: User = Depends(require_user), - db: AsyncSession = Depends(get_db), -): - """Generate or enhance a story.""" - await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) - return await story_service.generate_and_save_story(request, user.id, db) - - -@router.post("/stories/generate/full", response_model=FullStoryResponse) -async def generate_story_full( - request: GenerateRequest, - user: User = Depends(require_user), - db: AsyncSession = Depends(get_db), -): - """Generate complete story (story + parallel image/audio generation).""" - await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) - return await story_service.generate_full_story_service(request, user.id, db) - - -@router.post("/stories/generate/stream") +from app.services.story_status import StoryAssetStatus, sync_story_status + +logger = get_logger(__name__) +router = APIRouter() + +RATE_LIMIT_WINDOW = 60 # seconds +RATE_LIMIT_REQUESTS = 10 + + +@router.post("/stories/generate", response_model=StoryResponse) +async def generate_story( + request: GenerateRequest, + user: User = Depends(require_user), + db: AsyncSession = Depends(get_db), +): + """Generate or enhance a story.""" + await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) + return await story_service.generate_and_save_story(request, user.id, db) + + +@router.post("/stories/generate/full", response_model=FullStoryResponse) +async def generate_story_full( + request: GenerateRequest, + user: User = Depends(require_user), + db: AsyncSession = Depends(get_db), +): + """Generate complete story (story + parallel image/audio generation).""" + await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) + return await story_service.generate_full_story_service(request, user.id, db) + + +@router.post("/stories/generate/stream") async def generate_story_stream( request: GenerateRequest, - req: Request, user: User = Depends(require_user), db: AsyncSession = Depends(get_db), ): - """流式生成故事(SSE)。""" - await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) - - # Validation - profile_id, universe_id = await story_service.validate_profile_and_universe( - request.child_profile_id, request.universe_id, user.id, db - ) - - # Build Context - memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) - - async def event_generator() -> AsyncGenerator[dict, None]: - story_id = str(uuid.uuid4()) - yield {"event": "started", "data": json.dumps({"story_id": story_id})} - - # Step 1: Generate Content - try: - result = await generate_story_content( - input_type=request.type, - data=request.data, - education_theme=request.education_theme, - memory_context=memory_context, - db=db, - ) - except Exception as e: - logger.error("sse_story_generation_failed", error=str(e)) - yield {"event": "story_failed", "data": json.dumps({"error": str(e)})} - return - - # Save Story - story = await story_service.create_story_from_result( - result, user.id, profile_id, universe_id, db - ) - + """流式生成故事(SSE)。""" + await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) + + # Validation + profile_id, universe_id = await story_service.validate_profile_and_universe( + request.child_profile_id, request.universe_id, user.id, db + ) + + # Build Context + memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) + + async def event_generator() -> AsyncGenerator[dict, None]: + story_id = str(uuid.uuid4()) + yield {"event": "started", "data": json.dumps({"story_id": story_id})} + + # Step 1: Generate Content + try: + result = await generate_story_content( + input_type=request.type, + data=request.data, + education_theme=request.education_theme, + memory_context=memory_context, + db=db, + ) + except Exception as e: + logger.error("sse_story_generation_failed", error=str(e)) + yield {"event": "story_failed", "data": json.dumps({"error": str(e)})} + return + + # Save Story + story = await story_service.create_story_from_result( + result, user.id, profile_id, universe_id, db + ) + yield { "event": "story_ready", "data": json.dumps({ "id": story.id, "title": story.title, - "content": story.story_text, + "content": story.story_text, "cover_prompt": story.cover_prompt, "mode": story.mode, "child_profile_id": story.child_profile_id, "universe_id": story.universe_id, + "generation_status": story.generation_status, + "image_status": story.image_status, + "audio_status": story.audio_status, + "last_error": story.last_error, }), } # Step 2: Generate Image if story.cover_prompt: + sync_story_status(story, image_status=StoryAssetStatus.GENERATING) + await db.commit() try: # Direct call to provider router's generate_image, sharing db session image_url = await generate_image(story.cover_prompt, db=db) story.image_url = image_url + sync_story_status( + story, + image_status=StoryAssetStatus.READY, + ) await db.commit() - yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})} + yield { + "event": "image_ready", + "data": json.dumps( + { + "image_url": image_url, + "generation_status": story.generation_status, + "image_status": story.image_status, + "audio_status": story.audio_status, + "last_error": story.last_error, + } + ), + } except Exception as e: + sync_story_status( + story, + image_status=StoryAssetStatus.FAILED, + last_error=str(e), + ) + await db.commit() logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e)) - yield {"event": "image_failed", "data": json.dumps({"error": str(e)})} + yield { + "event": "image_failed", + "data": json.dumps( + { + "error": str(e), + "generation_status": story.generation_status, + "image_status": story.image_status, + "audio_status": story.audio_status, + "last_error": story.last_error, + } + ), + } - yield {"event": "complete", "data": json.dumps({"story_id": story.id})} - - return EventSourceResponse(event_generator()) - - -@router.post("/storybook/generate", response_model=StorybookResponse) -async def generate_storybook_api( - request: StorybookRequest, - user: User = Depends(require_user), - db: AsyncSession = Depends(get_db), -): - """Generate storybook.""" - await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) - return await story_service.generate_storybook_service(request, user.id, db) - - -# ==================== Missing Endpoints (Issue #5) ==================== - -@router.get("/stories", response_model=list[StoryListItem]) -async def list_stories( - limit: int = 20, - offset: int = 0, - user: User = Depends(require_user), - db: AsyncSession = Depends(get_db), -): - """List stories.""" - return await story_service.list_stories(user.id, limit, offset, db) - - -@router.get("/stories/{story_id}", response_model=StoryResponse) + yield { + "event": "complete", + "data": json.dumps( + { + "story_id": story.id, + "generation_status": story.generation_status, + "image_status": story.image_status, + "audio_status": story.audio_status, + "last_error": story.last_error, + } + ), + } + + return EventSourceResponse(event_generator()) + + +@router.post("/storybook/generate", response_model=StorybookResponse) +async def generate_storybook_api( + request: StorybookRequest, + user: User = Depends(require_user), + db: AsyncSession = Depends(get_db), +): + """Generate storybook.""" + await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) + return await story_service.generate_storybook_service(request, user.id, db) + + +# ==================== Missing Endpoints (Issue #5) ==================== + +@router.get("/stories", response_model=list[StoryListItem]) +async def list_stories( + limit: int = 20, + offset: int = 0, + user: User = Depends(require_user), + db: AsyncSession = Depends(get_db), +): + """List stories.""" + return await story_service.list_stories(user.id, limit, offset, db) + + +@router.get("/stories/{story_id}", response_model=StoryDetailResponse) async def get_story( story_id: int, user: User = Depends(require_user), db: AsyncSession = Depends(get_db), ): - """Get story detail.""" - return await story_service.get_story_detail(story_id, user.id, db) - - -@router.delete("/stories/{story_id}") -async def delete_story( - story_id: int, - user: User = Depends(require_user), - db: AsyncSession = Depends(get_db), -): - """Delete story.""" - await story_service.delete_story(story_id, user.id, db) - return {"message": "Deleted"} - - -@router.post("/image/generate/{story_id}") + """Get story detail.""" + return await story_service.get_story_detail(story_id, user.id, db) + + +@router.delete("/stories/{story_id}") +async def delete_story( + story_id: int, + user: User = Depends(require_user), + db: AsyncSession = Depends(get_db), +): + """Delete story.""" + await story_service.delete_story(story_id, user.id, db) + return {"message": "Deleted"} + + +@router.post("/image/generate/{story_id}", response_model=StoryImageResponse) async def generate_story_image( story_id: int, user: User = Depends(require_user), @@ -183,25 +233,32 @@ async def generate_story_image( ): """Generate cover image for story.""" url = await story_service.generate_story_cover(story_id, user.id, db) - return {"image_url": url} - - -@router.get("/audio/{story_id}") -async def get_story_audio( - story_id: int, - user: User = Depends(require_user), - db: AsyncSession = Depends(get_db), -): - """Get story audio (MP3).""" - audio_bytes = await story_service.generate_story_audio(story_id, user.id, db) - return Response(content=audio_bytes, media_type="audio/mpeg") - - -@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem]) -async def get_story_achievements( - story_id: int, - user: User = Depends(require_user), - db: AsyncSession = Depends(get_db), -): - """Get story achievements.""" - return await story_service.get_story_achievements(story_id, user.id, db) + story = await story_service.get_story_detail(story_id, user.id, db) + return { + "image_url": url, + "generation_status": story.generation_status, + "image_status": story.image_status, + "audio_status": story.audio_status, + "last_error": story.last_error, + } + + +@router.get("/audio/{story_id}") +async def get_story_audio( + story_id: int, + user: User = Depends(require_user), + db: AsyncSession = Depends(get_db), +): + """Get story audio (MP3).""" + audio_bytes = await story_service.generate_story_audio(story_id, user.id, db) + return Response(content=audio_bytes, media_type="audio/mpeg") + + +@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem]) +async def get_story_achievements( + story_id: int, + user: User = Depends(require_user), + db: AsyncSession = Depends(get_db), +): + """Get story achievements.""" + return await story_service.get_story_achievements(story_id, user.id, db) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index d8013d0..f8d6f80 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -1,130 +1,134 @@ -from pydantic import Field, model_validator -from pydantic_settings import BaseSettings, SettingsConfigDict - - -class Settings(BaseSettings): - """应用全局配置""" - - model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") - - # 应用基础配置 - app_name: str = "DreamWeaver" - debug: bool = False - secret_key: str = Field(..., description="JWT 签名密钥") - base_url: str = Field("http://localhost:8000", description="后端对外回调地址") - - # 数据库 - database_url: str = Field(..., description="SQLAlchemy async URL") - - # OAuth - GitHub - github_client_id: str = "" - github_client_secret: str = "" - - # OAuth - Google - google_client_id: str = "" - google_client_secret: str = "" - - # AI Capability Keys - text_api_key: str = "" - tts_api_base: str = "" - tts_api_key: str = "" - image_api_key: str = "" - - # Additional Provider API Keys - openai_api_key: str = "" - elevenlabs_api_key: str = "" - cqtai_api_key: str = "" - minimax_api_key: str = "" - minimax_group_id: str = "" - antigravity_api_key: str = "" - antigravity_api_base: str = "" - - # AI Model Configuration - text_model: str = "gemini-2.0-flash" - openai_model: str = "gpt-4o-mini" - tts_model: str = "" - image_model: str = "nano-banana-pro" - tts_minimax_model: str = "speech-2.6-turbo" - tts_elevenlabs_model: str = "eleven_multilingual_v2" - tts_edge_voice: str = "zh-CN-XiaoxiaoNeural" - antigravity_model: str = "gemini-3-pro-image" - +from pydantic import Field, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """应用全局配置""" + + model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") + + # 应用基础配置 + app_name: str = "DreamWeaver" + debug: bool = False + secret_key: str = Field(..., description="JWT 签名密钥") + base_url: str = Field("http://localhost:8000", description="后端对外回调地址") + + # 数据库 + database_url: str = Field(..., description="SQLAlchemy async URL") + + # OAuth - GitHub + github_client_id: str = "" + github_client_secret: str = "" + + # OAuth - Google + google_client_id: str = "" + google_client_secret: str = "" + + # AI Capability Keys + text_api_key: str = "" + tts_api_base: str = "" + tts_api_key: str = "" + image_api_key: str = "" + + # Additional Provider API Keys + openai_api_key: str = "" + elevenlabs_api_key: str = "" + cqtai_api_key: str = "" + minimax_api_key: str = "" + minimax_group_id: str = "" + antigravity_api_key: str = "" + antigravity_api_base: str = "" + + # AI Model Configuration + text_model: str = "gemini-2.0-flash" + openai_model: str = "gpt-4o-mini" + tts_model: str = "" + image_model: str = "nano-banana-pro" + tts_minimax_model: str = "speech-2.6-turbo" + tts_elevenlabs_model: str = "eleven_multilingual_v2" + tts_edge_voice: str = "zh-CN-XiaoxiaoNeural" + antigravity_model: str = "gemini-3-pro-image" + # Provider routing (ordered lists) text_providers: list[str] = Field(default_factory=lambda: ["gemini"]) image_providers: list[str] = Field(default_factory=lambda: ["cqtai"]) tts_providers: list[str] = Field(default_factory=lambda: ["minimax", "elevenlabs", "edge_tts"]) + story_audio_cache_dir: str = Field( + "storage/audio", + description="Directory for cached story audio files", + ) # Celery (Redis) celery_broker_url: str = Field("redis://localhost:6379/0") - celery_result_backend: str = Field("redis://localhost:6379/0") - - # Generic Redis - redis_url: str = Field("redis://localhost:6379/0", description="Redis connection URL") - redis_sentinel_enabled: bool = Field(False, description="Whether to enable Redis Sentinel") - redis_sentinel_nodes: str = Field( - "", - description="Comma-separated Redis Sentinel nodes, e.g. host1:26379,host2:26379", - ) - redis_sentinel_master_name: str = Field("mymaster", description="Redis Sentinel master name") - redis_sentinel_password: str = Field("", description="Password for Redis Sentinel (optional)") - redis_sentinel_db: int = Field(0, description="Redis DB index when using Sentinel") - redis_sentinel_socket_timeout: float = Field( - 0.5, - description="Socket timeout in seconds for Sentinel clients", - ) - - # Admin console - enable_admin_console: bool = False - admin_username: str = "admin" - admin_password: str = "admin123" # 建议通过环境变量覆盖 - - # CORS - cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173"]) - - @model_validator(mode="after") - def _require_core_settings(self) -> "Settings": # type: ignore[override] - missing = [] - if not self.secret_key or self.secret_key == "change-me-in-production": - missing.append("SECRET_KEY") - if not self.database_url: - missing.append("DATABASE_URL") - if self.redis_sentinel_enabled and not self.redis_sentinel_nodes.strip(): - missing.append("REDIS_SENTINEL_NODES") - if missing: - raise ValueError(f"Missing required settings: {', '.join(missing)}") - return self - - @property - def redis_sentinel_hosts(self) -> list[tuple[str, int]]: - """Parse Redis Sentinel nodes into (host, port) tuples.""" - nodes = [] - raw = self.redis_sentinel_nodes.strip() - if not raw: - return nodes - - for item in raw.split(","): - value = item.strip() - if not value: - continue - if ":" not in value: - raise ValueError(f"Invalid sentinel node format: {value}") - host, port_text = value.rsplit(":", 1) - if not host: - raise ValueError(f"Invalid sentinel node host: {value}") - try: - port = int(port_text) - except ValueError as exc: - raise ValueError(f"Invalid sentinel node port: {value}") from exc - nodes.append((host, port)) - return nodes - - @property - def redis_sentinel_urls(self) -> list[str]: - """Build Celery-compatible Sentinel URLs with DB index.""" - return [ - f"sentinel://{host}:{port}/{self.redis_sentinel_db}" - for host, port in self.redis_sentinel_hosts - ] - - -settings = Settings() + celery_result_backend: str = Field("redis://localhost:6379/0") + + # Generic Redis + redis_url: str = Field("redis://localhost:6379/0", description="Redis connection URL") + redis_sentinel_enabled: bool = Field(False, description="Whether to enable Redis Sentinel") + redis_sentinel_nodes: str = Field( + "", + description="Comma-separated Redis Sentinel nodes, e.g. host1:26379,host2:26379", + ) + redis_sentinel_master_name: str = Field("mymaster", description="Redis Sentinel master name") + redis_sentinel_password: str = Field("", description="Password for Redis Sentinel (optional)") + redis_sentinel_db: int = Field(0, description="Redis DB index when using Sentinel") + redis_sentinel_socket_timeout: float = Field( + 0.5, + description="Socket timeout in seconds for Sentinel clients", + ) + + # Admin console + enable_admin_console: bool = False + admin_username: str = "admin" + admin_password: str = "admin123" # 建议通过环境变量覆盖 + + # CORS + cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173"]) + + @model_validator(mode="after") + def _require_core_settings(self) -> "Settings": # type: ignore[override] + missing = [] + if not self.secret_key or self.secret_key == "change-me-in-production": + missing.append("SECRET_KEY") + if not self.database_url: + missing.append("DATABASE_URL") + if self.redis_sentinel_enabled and not self.redis_sentinel_nodes.strip(): + missing.append("REDIS_SENTINEL_NODES") + if missing: + raise ValueError(f"Missing required settings: {', '.join(missing)}") + return self + + @property + def redis_sentinel_hosts(self) -> list[tuple[str, int]]: + """Parse Redis Sentinel nodes into (host, port) tuples.""" + nodes = [] + raw = self.redis_sentinel_nodes.strip() + if not raw: + return nodes + + for item in raw.split(","): + value = item.strip() + if not value: + continue + if ":" not in value: + raise ValueError(f"Invalid sentinel node format: {value}") + host, port_text = value.rsplit(":", 1) + if not host: + raise ValueError(f"Invalid sentinel node host: {value}") + try: + port = int(port_text) + except ValueError as exc: + raise ValueError(f"Invalid sentinel node port: {value}") from exc + nodes.append((host, port)) + return nodes + + @property + def redis_sentinel_urls(self) -> list[str]: + """Build Celery-compatible Sentinel URLs with DB index.""" + return [ + f"sentinel://{host}:{port}/{self.redis_sentinel_db}" + for host, port in self.redis_sentinel_hosts + ] + + +settings = Settings() diff --git a/backend/app/db/models.py b/backend/app/db/models.py index 394d325..93ef2a2 100644 --- a/backend/app/db/models.py +++ b/backend/app/db/models.py @@ -27,10 +27,10 @@ class User(Base): __tablename__ = "users" - id: Mapped[str] = mapped_column(String(255), primary_key=True) # OAuth provider user ID + id: Mapped[str] = mapped_column(String(255), primary_key=True) name: Mapped[str] = mapped_column(String(255), nullable=False) avatar_url: Mapped[str | None] = mapped_column(String(500)) - provider: Mapped[str] = mapped_column(String(50), nullable=False) # github / google + provider: Mapped[str] = mapped_column(String(50), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) @@ -59,11 +59,22 @@ class Story(Base): String(36), ForeignKey("story_universes.id", ondelete="SET NULL"), nullable=True ) title: Mapped[str] = mapped_column(String(255), nullable=False) - story_text: Mapped[str] = mapped_column(Text, nullable=True) # 允许为空(绘本模式下) - pages: Mapped[list[dict] | None] = mapped_column(JSON, default=list) # 绘本分页数据 + story_text: Mapped[str | None] = mapped_column(Text, nullable=True) + pages: Mapped[list[dict] | None] = mapped_column(JSON, default=list) cover_prompt: Mapped[str | None] = mapped_column(Text) image_url: Mapped[str | None] = mapped_column(String(500)) mode: Mapped[str] = mapped_column(String(20), nullable=False, default="generated") + generation_status: Mapped[str] = mapped_column( + String(32), nullable=False, default="narrative_ready" + ) + image_status: Mapped[str] = mapped_column( + String(32), nullable=False, default="not_requested" + ) + audio_status: Mapped[str] = mapped_column( + String(32), nullable=False, default="not_requested" + ) + audio_path: Mapped[str | None] = mapped_column(String(500)) + last_error: Mapped[str | None] = mapped_column(Text) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) @@ -123,6 +134,7 @@ class ChildProfile(Base): class StoryUniverse(Base): """Story universe entity.""" + __tablename__ = "story_universes" id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid) @@ -142,7 +154,9 @@ class StoryUniverse(Base): DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) - child_profile: Mapped["ChildProfile"] = relationship("ChildProfile", back_populates="story_universes") + child_profile: Mapped["ChildProfile"] = relationship( + "ChildProfile", back_populates="story_universes" + ) class ReadingEvent(Base): @@ -163,6 +177,7 @@ class ReadingEvent(Base): DateTime(timezone=True), server_default=func.now(), index=True ) + class PushConfig(Base): """Push configuration entity.""" diff --git a/backend/app/schemas/story_schemas.py b/backend/app/schemas/story_schemas.py index 80a8c80..aff4e04 100644 --- a/backend/app/schemas/story_schemas.py +++ b/backend/app/schemas/story_schemas.py @@ -1,4 +1,4 @@ -"""故事相关 Pydantic 模型。""" +"""Story-related Pydantic schemas.""" from datetime import datetime from typing import Literal @@ -11,7 +11,13 @@ MAX_EDU_THEME_LENGTH = 200 MAX_TTS_LENGTH = 4000 -# ==================== 故事模型 ==================== +class StoryStatusMixin(BaseModel): + """Shared generation status fields returned by story APIs.""" + + generation_status: str + image_status: str + audio_status: str + last_error: str | None = None class GenerateRequest(BaseModel): @@ -24,8 +30,8 @@ class GenerateRequest(BaseModel): universe_id: str | None = None -class StoryResponse(BaseModel): - """Story response.""" +class StoryResponse(StoryStatusMixin): + """Story generation response.""" id: int title: str @@ -37,7 +43,7 @@ class StoryResponse(BaseModel): universe_id: str | None = None -class StoryListItem(BaseModel): +class StoryListItem(StoryStatusMixin): """Story list item.""" id: int @@ -47,8 +53,8 @@ class StoryListItem(BaseModel): mode: str -class FullStoryResponse(BaseModel): - """完整故事响应(含图片和音频状态)。""" +class FullStoryResponse(StoryStatusMixin): + """Full story response with asset status.""" id: int title: str @@ -62,22 +68,19 @@ class FullStoryResponse(BaseModel): universe_id: str | None = None -# ==================== 绘本模型 ==================== - - class StorybookRequest(BaseModel): - """Storybook 生成请求。""" + """Storybook generation request.""" keywords: str = Field(..., min_length=1, max_length=200) page_count: int = Field(default=6, ge=4, le=12) education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH) - generate_images: bool = Field(default=False, description="是否同时生成插图") + generate_images: bool = Field(default=False, description="Whether to generate images too.") child_profile_id: str | None = None universe_id: str | None = None class StorybookPageResponse(BaseModel): - """故事书单页响应。""" + """One storybook page.""" page_number: int text: str @@ -85,8 +88,8 @@ class StorybookPageResponse(BaseModel): image_url: str | None = None -class StorybookResponse(BaseModel): - """故事书响应。""" +class StorybookResponse(StoryStatusMixin): + """Storybook generation response.""" id: int | None = None title: str @@ -97,10 +100,29 @@ class StorybookResponse(BaseModel): cover_url: str | None = None -# ==================== 成就模型 ==================== +class StoryDetailResponse(StoryStatusMixin): + """Story detail response for both stories and storybooks.""" + + id: int + title: str + story_text: str | None = None + pages: list[StorybookPageResponse] | None = None + cover_prompt: str | None + image_url: str | None + mode: str + child_profile_id: str | None = None + universe_id: str | None = None + + +class StoryImageResponse(StoryStatusMixin): + """Cover image generation response.""" + + image_url: str | None class AchievementItem(BaseModel): + """Achievement item returned for a story.""" + type: str description: str obtained_at: str | None = None diff --git a/backend/app/services/audio_storage.py b/backend/app/services/audio_storage.py new file mode 100644 index 0000000..82919f3 --- /dev/null +++ b/backend/app/services/audio_storage.py @@ -0,0 +1,38 @@ +"""Story audio cache storage helpers.""" + +from __future__ import annotations + +from pathlib import Path + +from app.core.config import settings + + +def build_story_audio_path(story_id: int) -> str: + """Build the cache path for a story audio file.""" + + return str(Path(settings.story_audio_cache_dir) / f"story-{story_id}.mp3") + + +def audio_cache_exists(audio_path: str | None) -> bool: + """Whether the cached audio file exists on disk.""" + + return bool(audio_path) and Path(audio_path).is_file() + + +def read_audio_cache(audio_path: str) -> bytes: + """Read cached story audio bytes.""" + + return Path(audio_path).read_bytes() + + +def write_story_audio_cache(story_id: int, audio_data: bytes) -> str: + """Persist story audio and return the saved file path.""" + + final_path = Path(build_story_audio_path(story_id)) + final_path.parent.mkdir(parents=True, exist_ok=True) + + temp_path = final_path.with_suffix(".tmp") + temp_path.write_bytes(audio_data) + temp_path.replace(final_path) + + return str(final_path) diff --git a/backend/app/services/story_service.py b/backend/app/services/story_service.py index 838275f..ebac765 100644 --- a/backend/app/services/story_service.py +++ b/backend/app/services/story_service.py @@ -1,12 +1,9 @@ """Story business logic service.""" import asyncio -import json -import uuid -from typing import Literal from fastapi import HTTPException -from sqlalchemy import select, desc +from sqlalchemy import desc, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -15,90 +12,151 @@ from app.db.models import ChildProfile, Story, StoryUniverse from app.schemas.story_schemas import ( GenerateRequest, StorybookRequest, - FullStoryResponse, - StorybookResponse, + FullStoryResponse, + StorybookResponse, StorybookPageResponse, AchievementItem, ) +from app.services.audio_storage import ( + audio_cache_exists, + read_audio_cache, + write_story_audio_cache, +) from app.services.memory_service import build_enhanced_memory_context from app.services.provider_router import ( generate_story_content, generate_image, generate_storybook, ) +from app.services.story_status import ( + StoryAssetStatus, + sync_story_status, +) from app.tasks.achievements import extract_story_achievements logger = get_logger(__name__) -async def validate_profile_and_universe( - profile_id: str | None, - universe_id: str | None, - user_id: str, - db: AsyncSession, -) -> tuple[str | None, str | None]: - """Validate child profile and universe ownership/relationship.""" - if not profile_id and not universe_id: - return None, None +def _build_storybook_error_message( + *, + cover_failed: bool, + failed_pages: list[int], +) -> str | None: + """Summarize storybook image generation errors for the latest attempt.""" - if profile_id: - result = await db.execute( - select(ChildProfile).where( - ChildProfile.id == profile_id, - ChildProfile.user_id == user_id, - ) - ) - profile = result.scalar_one_or_none() - if not profile: - raise HTTPException(status_code=404, detail="孩子档案不存在") - - if universe_id: - result = await db.execute( - select(StoryUniverse) - .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id) - .where( - StoryUniverse.id == universe_id, - ChildProfile.user_id == user_id, - ) - ) - universe = result.scalar_one_or_none() - if not universe: - raise HTTPException(status_code=404, detail="故事宇宙不存在") - if profile_id and universe.child_profile_id != profile_id: - raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配") - if not profile_id: - profile_id = universe.child_profile_id - - return profile_id, universe_id + parts: list[str] = [] + if cover_failed: + parts.append("封面生成失败") + if failed_pages: + pages = "、".join(str(page) for page in sorted(failed_pages)) + parts.append(f"第 {pages} 页插图生成失败") + return ";".join(parts) if parts else None +def _resolve_storybook_image_status( + *, + generate_images: bool, + cover_prompt: str | None, + cover_url: str | None, + pages_data: list[dict], +) -> StoryAssetStatus: + """Resolve the persisted image status for a storybook.""" + + if not generate_images: + return StoryAssetStatus.NOT_REQUESTED + + expected_assets = 0 + ready_assets = 0 + + if cover_prompt or cover_url: + expected_assets += 1 + if cover_url: + ready_assets += 1 + + for page in pages_data: + if not page.get("image_prompt") and not page.get("image_url"): + continue + expected_assets += 1 + if page.get("image_url"): + ready_assets += 1 + + if expected_assets == 0: + return StoryAssetStatus.NOT_REQUESTED + + if ready_assets == expected_assets: + return StoryAssetStatus.READY + + return StoryAssetStatus.FAILED + + +async def validate_profile_and_universe( + profile_id: str | None, + universe_id: str | None, + user_id: str, + db: AsyncSession, +) -> tuple[str | None, str | None]: + """Validate child profile and universe ownership/relationship.""" + if not profile_id and not universe_id: + return None, None + + if profile_id: + result = await db.execute( + select(ChildProfile).where( + ChildProfile.id == profile_id, + ChildProfile.user_id == user_id, + ) + ) + profile = result.scalar_one_or_none() + if not profile: + raise HTTPException(status_code=404, detail="孩子档案不存在") + + if universe_id: + result = await db.execute( + select(StoryUniverse) + .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id) + .where( + StoryUniverse.id == universe_id, + ChildProfile.user_id == user_id, + ) + ) + universe = result.scalar_one_or_none() + if not universe: + raise HTTPException(status_code=404, detail="故事宇宙不存在") + if profile_id and universe.child_profile_id != profile_id: + raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配") + if not profile_id: + profile_id = universe.child_profile_id + + return profile_id, universe_id + + async def generate_and_save_story( request: GenerateRequest, user_id: str, db: AsyncSession, ) -> Story: - """Generate generic story content and save to DB.""" - # 1. Validate - profile_id, universe_id = await validate_profile_and_universe( - request.child_profile_id, request.universe_id, user_id, db - ) - - # 2. Build Context - memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) - - # 3. Generate - try: - result = await generate_story_content( - input_type=request.type, - data=request.data, - education_theme=request.education_theme, - memory_context=memory_context, - db=db, - ) - except Exception as exc: - raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc - - # 4. Save + """Generate generic story content and save to DB.""" + # 1. Validate + profile_id, universe_id = await validate_profile_and_universe( + request.child_profile_id, request.universe_id, user_id, db + ) + + # 2. Build Context + memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) + + # 3. Generate + try: + result = await generate_story_content( + input_type=request.type, + data=request.data, + education_theme=request.education_theme, + memory_context=memory_context, + db=db, + ) + except Exception as exc: + raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc + + # 4. Save story = Story( user_id=user_id, child_profile_id=profile_id, @@ -108,170 +166,209 @@ async def generate_and_save_story( cover_prompt=result.cover_prompt_suggestion, mode=result.mode, ) + sync_story_status( + story, + image_status=StoryAssetStatus.NOT_REQUESTED, + audio_status=StoryAssetStatus.NOT_REQUESTED, + last_error=None, + ) db.add(story) await db.commit() await db.refresh(story) - - # 5. Trigger Async Tasks - if universe_id: - extract_story_achievements.delay(story.id, universe_id) - - return story - - + + # 5. Trigger Async Tasks + if universe_id: + extract_story_achievements.delay(story.id, universe_id) + + return story + + async def generate_full_story_service( request: GenerateRequest, user_id: str, db: AsyncSession, ) -> FullStoryResponse: - """Generate story with parallel image generation.""" - # 1. Generate text part - # We can reuse logic or call generate_story_content directly if we want finer control - # reusing generate_and_save_story to ensure consistency (it handles validation + saving) - story = await generate_and_save_story(request, user_id, db) - - # 2. Generate Image (Parallel/Async step in this flow) + """Generate story with parallel image generation.""" + # 1. Generate text part + # We can reuse logic or call generate_story_content directly if we want finer control + # reusing generate_and_save_story to ensure consistency (it handles validation + saving) + story = await generate_and_save_story(request, user_id, db) + + # 2. Generate Image (Parallel/Async step in this flow) image_url: str | None = None errors: dict[str, str | None] = {} if story.cover_prompt: + sync_story_status(story, image_status=StoryAssetStatus.GENERATING) + await db.commit() try: image_url = await generate_image(story.cover_prompt, db=db) story.image_url = image_url + sync_story_status( + story, + image_status=StoryAssetStatus.READY, + ) await db.commit() except Exception as exc: errors["image"] = str(exc) + sync_story_status( + story, + image_status=StoryAssetStatus.FAILED, + last_error=str(exc), + ) + await db.commit() logger.warning("image_generation_failed", story_id=story.id, error=str(exc)) return FullStoryResponse( id=story.id, title=story.title, - story_text=story.story_text, - cover_prompt=story.cover_prompt, - image_url=image_url, + story_text=story.story_text, + cover_prompt=story.cover_prompt, + image_url=image_url, audio_ready=False, mode=story.mode, errors=errors, child_profile_id=story.child_profile_id, universe_id=story.universe_id, + generation_status=story.generation_status, + image_status=story.image_status, + audio_status=story.audio_status, + last_error=story.last_error, ) - - + + async def generate_storybook_service( request: StorybookRequest, user_id: str, db: AsyncSession, ) -> StorybookResponse: - """Generate storybook with parallel image generation for pages.""" - # 1. Validate - profile_id, universe_id = await validate_profile_and_universe( - request.child_profile_id, request.universe_id, user_id, db - ) - - logger.info( - "storybook_request", - user_id=user_id, - keywords=request.keywords, - page_count=request.page_count, - profile_id=profile_id, - universe_id=universe_id, - ) - - # 2. Context - memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) - - # 3. Generate Text Structure - try: - storybook = await generate_storybook( - keywords=request.keywords, - page_count=request.page_count, - education_theme=request.education_theme, - memory_context=memory_context, - db=db, - ) - except Exception as e: - logger.error("storybook_generation_failed", error=str(e)) - raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}") - + """Generate storybook with parallel image generation for pages.""" + # 1. Validate + profile_id, universe_id = await validate_profile_and_universe( + request.child_profile_id, request.universe_id, user_id, db + ) + + logger.info( + "storybook_request", + user_id=user_id, + keywords=request.keywords, + page_count=request.page_count, + profile_id=profile_id, + universe_id=universe_id, + ) + + # 2. Context + memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) + + # 3. Generate Text Structure + try: + storybook = await generate_storybook( + keywords=request.keywords, + page_count=request.page_count, + education_theme=request.education_theme, + memory_context=memory_context, + db=db, + ) + except Exception as e: + logger.error("storybook_generation_failed", error=str(e)) + raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}") + # 4. Parallel Image Generation final_cover_url = storybook.cover_url + cover_failed = False + failed_pages: list[int] = [] + if request.generate_images: logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages)) - + tasks = [] - # Cover Task async def _gen_cover(): + nonlocal cover_failed + if storybook.cover_prompt and not storybook.cover_url: try: return await generate_image(storybook.cover_prompt, db=db) - except Exception as e: - logger.warning("cover_gen_failed", error=str(e)) + except Exception as exc: + cover_failed = True + logger.warning("cover_gen_failed", error=str(exc)) return storybook.cover_url + tasks.append(_gen_cover()) - # Page Tasks async def _gen_page(page): if page.image_prompt and not page.image_url: try: - url = await generate_image(page.image_prompt, db=db) - page.image_url = url - except Exception as e: - logger.warning("page_gen_failed", page=page.page_number, error=str(e)) + page.image_url = await generate_image(page.image_prompt, db=db) + except Exception as exc: + failed_pages.append(page.page_number) + logger.warning("page_gen_failed", page=page.page_number, error=str(exc)) for page in storybook.pages: tasks.append(_gen_page(page)) - # Execute results = await asyncio.gather(*tasks, return_exceptions=True) - - # Update cover result + cover_res = results[0] if isinstance(cover_res, str): final_cover_url = cover_res logger.info("storybook_parallel_generation_complete") - - # 5. Save to DB - pages_data = [ - { - "page_number": p.page_number, - "text": p.text, - "image_prompt": p.image_prompt, - "image_url": p.image_url, - } - for p in storybook.pages - ] - + + # 5. Save to DB + pages_data = [ + { + "page_number": p.page_number, + "text": p.text, + "image_prompt": p.image_prompt, + "image_url": p.image_url, + } + for p in storybook.pages + ] + story = Story( user_id=user_id, child_profile_id=profile_id, universe_id=universe_id, title=storybook.title, - mode="storybook", - pages=pages_data, + mode="storybook", + pages=pages_data, story_text=None, cover_prompt=storybook.cover_prompt, image_url=final_cover_url, ) + sync_story_status( + story, + image_status=_resolve_storybook_image_status( + generate_images=request.generate_images, + cover_prompt=storybook.cover_prompt, + cover_url=final_cover_url, + pages_data=pages_data, + ), + audio_status=StoryAssetStatus.NOT_REQUESTED, + last_error=_build_storybook_error_message( + cover_failed=cover_failed, + failed_pages=failed_pages, + ), + ) db.add(story) await db.commit() await db.refresh(story) - - if universe_id: - extract_story_achievements.delay(story.id, universe_id) - - # 6. Build Response - response_pages = [ - StorybookPageResponse( - page_number=p["page_number"], - text=p["text"], - image_prompt=p["image_prompt"], - image_url=p.get("image_url"), - ) - for p in pages_data - ] - + + if universe_id: + extract_story_achievements.delay(story.id, universe_id) + + # 6. Build Response + response_pages = [ + StorybookPageResponse( + page_number=p["page_number"], + text=p["text"], + image_prompt=p["image_prompt"], + image_url=p.get("image_url"), + ) + for p in pages_data + ] + return StorybookResponse( id=story.id, title=storybook.title, @@ -280,155 +377,209 @@ async def generate_storybook_service( pages=response_pages, cover_prompt=storybook.cover_prompt, cover_url=final_cover_url, + generation_status=story.generation_status, + image_status=story.image_status, + audio_status=story.audio_status, + last_error=story.last_error, ) - - -# ==================== Missing Endpoints Logic (for Issue #5) ==================== - -async def list_stories( - user_id: str, - limit: int, - offset: int, - db: AsyncSession, -) -> list[Story]: - """List stories for user.""" - result = await db.execute( - select(Story) - .where(Story.user_id == user_id) - .order_by(desc(Story.created_at)) - .offset(offset) - .limit(limit) - ) - return result.scalars().all() - - -async def get_story_detail( - story_id: int, - user_id: str, - db: AsyncSession, -) -> Story: - """Get story detail.""" - result = await db.execute( - select(Story).where(Story.id == story_id, Story.user_id == user_id) - ) - story = result.scalar_one_or_none() - if not story: - raise HTTPException(status_code=404, detail="Story not found") - return story - - -async def delete_story( - story_id: int, - user_id: str, - db: AsyncSession, -) -> None: - """Delete a story.""" - story = await get_story_detail(story_id, user_id, db) - await db.delete(story) - await db.commit() - - + + +# ==================== Missing Endpoints Logic (for Issue #5) ==================== + +async def list_stories( + user_id: str, + limit: int, + offset: int, + db: AsyncSession, +) -> list[Story]: + """List stories for user.""" + result = await db.execute( + select(Story) + .where(Story.user_id == user_id) + .order_by(desc(Story.created_at)) + .offset(offset) + .limit(limit) + ) + return result.scalars().all() + + +async def get_story_detail( + story_id: int, + user_id: str, + db: AsyncSession, +) -> Story: + """Get story detail.""" + result = await db.execute( + select(Story).where(Story.id == story_id, Story.user_id == user_id) + ) + story = result.scalar_one_or_none() + if not story: + raise HTTPException(status_code=404, detail="Story not found") + return story + + +async def delete_story( + story_id: int, + user_id: str, + db: AsyncSession, +) -> None: + """Delete a story.""" + story = await get_story_detail(story_id, user_id, db) + await db.delete(story) + await db.commit() + + async def create_story_from_result( result, # StoryOutput user_id: str, profile_id: str | None, universe_id: str | None, - db: AsyncSession, -) -> Story: - """Save a generated story to DB (helper for stream endpoint).""" - story = Story( - user_id=user_id, - child_profile_id=profile_id, - universe_id=universe_id, + db: AsyncSession, +) -> Story: + """Save a generated story to DB (helper for stream endpoint).""" + story = Story( + user_id=user_id, + child_profile_id=profile_id, + universe_id=universe_id, title=result.title, story_text=result.story_text, cover_prompt=result.cover_prompt_suggestion, mode=result.mode, ) + sync_story_status( + story, + image_status=StoryAssetStatus.NOT_REQUESTED, + audio_status=StoryAssetStatus.NOT_REQUESTED, + last_error=None, + ) db.add(story) await db.commit() await db.refresh(story) - - if universe_id: - extract_story_achievements.delay(story.id, universe_id) - - return story - - + + if universe_id: + extract_story_achievements.delay(story.id, universe_id) + + return story + + async def generate_story_cover( story_id: int, user_id: str, db: AsyncSession, ) -> str: - """Generate cover image for an existing story.""" - story = await get_story_detail(story_id, user_id, db) - + """Generate cover image for an existing story.""" + story = await get_story_detail(story_id, user_id, db) + if not story.cover_prompt: raise HTTPException(status_code=400, detail="Story has no cover prompt") - + + sync_story_status(story, image_status=StoryAssetStatus.GENERATING) + await db.commit() + try: image_url = await generate_image(story.cover_prompt, db=db) story.image_url = image_url + sync_story_status( + story, + image_status=StoryAssetStatus.READY, + ) await db.commit() return image_url except Exception as e: + sync_story_status( + story, + image_status=StoryAssetStatus.FAILED, + last_error=str(e), + ) + await db.commit() logger.error("cover_generation_failed", story_id=story_id, error=str(e)) raise HTTPException(status_code=500, detail=f"Image generation failed: {e}") - - + + async def generate_story_audio( story_id: int, user_id: str, db: AsyncSession, ) -> bytes: - """Generate audio for a story.""" - story = await get_story_detail(story_id, user_id, db) - + """Generate audio for a story.""" + story = await get_story_detail(story_id, user_id, db) + if not story.story_text: raise HTTPException(status_code=400, detail="Story has no text") - # TODO: Check if audio is already cached/saved? - # For now, generate on the fly via provider + if story.audio_path and audio_cache_exists(story.audio_path): + if story.audio_status != StoryAssetStatus.READY.value: + sync_story_status(story, audio_status=StoryAssetStatus.READY) + await db.commit() + return read_audio_cache(story.audio_path) + + if story.audio_path and not audio_cache_exists(story.audio_path): + logger.warning( + "story_audio_cache_missing", + story_id=story_id, + audio_path=story.audio_path, + ) + story.audio_path = None + if story.audio_status == StoryAssetStatus.READY.value: + sync_story_status(story, audio_status=StoryAssetStatus.NOT_REQUESTED) + await db.commit() + from app.services.provider_router import text_to_speech - + + sync_story_status(story, audio_status=StoryAssetStatus.GENERATING) + await db.commit() + try: audio_data = await text_to_speech(story.story_text, db=db) + story.audio_path = write_story_audio_cache(story.id, audio_data) + sync_story_status( + story, + audio_status=StoryAssetStatus.READY, + ) + await db.commit() return audio_data except Exception as e: + story.audio_path = None + sync_story_status( + story, + audio_status=StoryAssetStatus.FAILED, + last_error=str(e), + ) + await db.commit() logger.error("audio_generation_failed", story_id=story_id, error=str(e)) raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}") - - -async def get_story_achievements( - story_id: int, - user_id: str, - db: AsyncSession, -) -> list[AchievementItem]: - """Get achievements unlocked by a specific story.""" - result = await db.execute( - select(Story) - .options(joinedload(Story.story_universe)) - .where(Story.id == story_id, Story.user_id == user_id) - ) - story = result.scalar_one_or_none() - - if not story: - raise HTTPException(status_code=404, detail="Story not found") - - if not story.universe_id or not story.story_universe: - return [] - - universe = story.story_universe - if not universe.achievements: - return [] - - results = [] - for ach in universe.achievements: - if isinstance(ach, dict) and ach.get("source_story_id") == story_id: - results.append(AchievementItem( - type=ach.get("type", "Unknown"), - description=ach.get("description", ""), - obtained_at=ach.get("obtained_at") - )) - return results - + + +async def get_story_achievements( + story_id: int, + user_id: str, + db: AsyncSession, +) -> list[AchievementItem]: + """Get achievements unlocked by a specific story.""" + result = await db.execute( + select(Story) + .options(joinedload(Story.story_universe)) + .where(Story.id == story_id, Story.user_id == user_id) + ) + story = result.scalar_one_or_none() + + if not story: + raise HTTPException(status_code=404, detail="Story not found") + + if not story.universe_id or not story.story_universe: + return [] + + universe = story.story_universe + if not universe.achievements: + return [] + + results = [] + for ach in universe.achievements: + if isinstance(ach, dict) and ach.get("source_story_id") == story_id: + results.append(AchievementItem( + type=ach.get("type", "Unknown"), + description=ach.get("description", ""), + obtained_at=ach.get("obtained_at") + )) + return results + diff --git a/backend/app/services/story_status.py b/backend/app/services/story_status.py new file mode 100644 index 0000000..ee62564 --- /dev/null +++ b/backend/app/services/story_status.py @@ -0,0 +1,112 @@ +"""Story generation status helpers.""" + +from __future__ import annotations + +from enum import Enum +from typing import Protocol + + +class StoryGenerationStatus(str, Enum): + """Overall story generation lifecycle.""" + + NARRATIVE_READY = "narrative_ready" + ASSETS_GENERATING = "assets_generating" + COMPLETED = "completed" + DEGRADED_COMPLETED = "degraded_completed" + FAILED = "failed" + + +class StoryAssetStatus(str, Enum): + """Asset generation state for image and audio.""" + + NOT_REQUESTED = "not_requested" + GENERATING = "generating" + READY = "ready" + FAILED = "failed" + + +class StoryLike(Protocol): + """Protocol for story-like objects used by status helpers.""" + + story_text: str | None + pages: list[dict] | None + generation_status: str + image_status: str + audio_status: str + last_error: str | None + + +_ERROR_UNSET = object() + + +def _normalize_asset_status(value: str | None) -> StoryAssetStatus: + if not value: + return StoryAssetStatus.NOT_REQUESTED + + try: + return StoryAssetStatus(value) + except ValueError: + return StoryAssetStatus.NOT_REQUESTED + + +def has_narrative_content(story: StoryLike) -> bool: + """Whether the story already has readable content.""" + + return bool(story.story_text) or bool(story.pages) + + +def resolve_story_generation_status(story: StoryLike) -> StoryGenerationStatus: + """Derive the overall status from narrative and asset states.""" + + if not has_narrative_content(story): + return StoryGenerationStatus.FAILED + + image_status = _normalize_asset_status(story.image_status) + audio_status = _normalize_asset_status(story.audio_status) + + if StoryAssetStatus.GENERATING in (image_status, audio_status): + return StoryGenerationStatus.ASSETS_GENERATING + + if StoryAssetStatus.FAILED in (image_status, audio_status): + return StoryGenerationStatus.DEGRADED_COMPLETED + + if ( + image_status == StoryAssetStatus.NOT_REQUESTED + and audio_status == StoryAssetStatus.NOT_REQUESTED + ): + return StoryGenerationStatus.NARRATIVE_READY + + return StoryGenerationStatus.COMPLETED + + +def has_failed_assets(story: StoryLike) -> bool: + """Whether any persisted asset is still in a failed state.""" + + image_status = _normalize_asset_status(story.image_status) + audio_status = _normalize_asset_status(story.audio_status) + return StoryAssetStatus.FAILED in (image_status, audio_status) + + +def sync_story_status( + story: StoryLike, + *, + image_status: StoryAssetStatus | None = None, + audio_status: StoryAssetStatus | None = None, + last_error: str | None | object = _ERROR_UNSET, +) -> None: + """Update asset statuses and refresh overall generation status.""" + + if image_status is not None: + story.image_status = image_status.value + + if audio_status is not None: + story.audio_status = audio_status.value + + if last_error is not _ERROR_UNSET: + story.last_error = last_error + + generation_status = resolve_story_generation_status(story) + story.generation_status = generation_status.value + + if last_error is _ERROR_UNSET and not has_failed_assets(story): + story.last_error = None diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 19786a0..e9ebe62 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,4 +1,4 @@ -"""测试配置和 fixtures。""" +"""Pytest fixtures for backend tests.""" import os from collections.abc import AsyncGenerator @@ -11,6 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing") os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:") +from app.core.config import settings from app.core.security import create_access_token from app.db.database import get_db from app.db.models import Base, Story, User @@ -19,7 +20,8 @@ from app.main import app @pytest.fixture async def async_engine(): - """创建内存数据库引擎。""" + """Create an in-memory database engine.""" + engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) @@ -29,7 +31,8 @@ async def async_engine(): @pytest.fixture async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]: - """创建数据库会话。""" + """Create a database session.""" + session_factory = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False ) @@ -39,7 +42,8 @@ async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]: @pytest.fixture async def test_user(db_session: AsyncSession) -> User: - """创建测试用户。""" + """Create a test user.""" + user = User( id="github:12345", name="Test User", @@ -54,13 +58,74 @@ async def test_user(db_session: AsyncSession) -> User: @pytest.fixture async def test_story(db_session: AsyncSession, test_user: User) -> Story: - """创建测试故事。""" + """Create a plain generated story.""" + story = Story( user_id=test_user.id, title="测试故事", - story_text="从前有一只小兔子...", + story_text="从前有一只小兔子。", cover_prompt="A cute rabbit in a forest", mode="generated", + generation_status="narrative_ready", + image_status="not_requested", + audio_status="not_requested", + ) + db_session.add(story) + await db_session.commit() + await db_session.refresh(story) + return story + + +@pytest.fixture +async def storybook_story(db_session: AsyncSession, test_user: User) -> Story: + """Create a storybook-mode story.""" + + story = Story( + user_id=test_user.id, + title="森林绘本冒险", + story_text=None, + pages=[ + { + "page_number": 1, + "text": "小兔子走进了会发光的森林。", + "image_prompt": "A glowing forest with a curious rabbit", + "image_url": "https://example.com/page-1.png", + }, + { + "page_number": 2, + "text": "它遇见了一位会唱歌的萤火虫朋友。", + "image_prompt": "A rabbit meeting a singing firefly", + "image_url": None, + }, + ], + cover_prompt="A magical forest storybook cover", + image_url="https://example.com/storybook-cover.png", + mode="storybook", + generation_status="degraded_completed", + image_status="failed", + audio_status="not_requested", + last_error="第 2 页插图生成失败", + ) + db_session.add(story) + await db_session.commit() + await db_session.refresh(story) + return story + + +@pytest.fixture +async def degraded_story_with_text(db_session: AsyncSession, test_user: User) -> Story: + """Create a readable story whose image generation already failed.""" + + story = Story( + user_id=test_user.id, + title="部分完成的测试故事", + story_text="从前有一只小兔子继续冒险。", + cover_prompt="A rabbit under the moon", + mode="generated", + generation_status="degraded_completed", + image_status="failed", + audio_status="not_requested", + last_error="封面生成失败", ) db_session.add(story) await db_session.commit() @@ -70,13 +135,14 @@ async def test_story(db_session: AsyncSession, test_user: User) -> Story: @pytest.fixture def auth_token(test_user: User) -> str: - """生成测试用户的 JWT token。""" + """Create a JWT token for the test user.""" + return create_access_token({"sub": test_user.id}) @pytest.fixture def client(db_session: AsyncSession) -> TestClient: - """创建测试客户端。""" + """Create a test client.""" async def override_get_db(): yield db_session @@ -89,35 +155,45 @@ def client(db_session: AsyncSession) -> TestClient: @pytest.fixture def auth_client(client: TestClient, auth_token: str) -> TestClient: - """带认证的测试客户端。""" + """Create an authenticated test client.""" + client.cookies.set("access_token", auth_token) return client @pytest.fixture(autouse=True) def bypass_rate_limit(): - """默认绕过限流,让非限流测试正常运行。""" + """Bypass rate limiting in most tests.""" + with patch("app.core.rate_limiter.get_redis", new_callable=AsyncMock) as mock_redis: - # 创建一个模拟的 Redis 客户端,所有操作返回安全默认值 redis_instance = AsyncMock() - redis_instance.incr.return_value = 1 # 始终返回 1 (不触发限流) + redis_instance.incr.return_value = 1 redis_instance.expire.return_value = True - redis_instance.get.return_value = None # 无锁定记录 + redis_instance.get.return_value = None redis_instance.ttl.return_value = 0 redis_instance.delete.return_value = 1 mock_redis.return_value = redis_instance yield redis_instance +@pytest.fixture(autouse=True) +def isolated_story_audio_cache(tmp_path, monkeypatch): + """Use an isolated directory for cached story audio files.""" + + monkeypatch.setattr(settings, "story_audio_cache_dir", str(tmp_path / "audio")) + yield + + @pytest.fixture def mock_text_provider(): - """Mock 文本生成适配器 API 调用。""" + """Mock text generation.""" + from app.services.adapters.text.models import StoryOutput mock_result = StoryOutput( mode="generated", title="小兔子的冒险", - story_text="从前有一只小兔子...", + story_text="从前有一只小兔子。", cover_prompt_suggestion="A cute rabbit", ) @@ -128,7 +204,8 @@ def mock_text_provider(): @pytest.fixture def mock_image_provider(): - """Mock 图像生成。""" + """Mock image generation.""" + with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock: mock.return_value = "https://example.com/image.png" yield mock @@ -136,7 +213,8 @@ def mock_image_provider(): @pytest.fixture def mock_tts_provider(): - """Mock TTS。""" + """Mock text-to-speech generation.""" + with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock: mock.return_value = b"fake-audio-bytes" yield mock @@ -144,7 +222,8 @@ def mock_tts_provider(): @pytest.fixture def mock_all_providers(mock_text_provider, mock_image_provider, mock_tts_provider): - """Mock 所有 AI 供应商。""" + """Group all mocked providers.""" + return { "text_primary": mock_text_provider, "image_primary": mock_image_provider, diff --git a/backend/tests/test_stories.py b/backend/tests/test_stories.py index 121f3fd..21d6b21 100644 --- a/backend/tests/test_stories.py +++ b/backend/tests/test_stories.py @@ -1,26 +1,41 @@ -"""故事 API 测试。""" +"""Tests for story-related API endpoints.""" +from pathlib import Path from unittest.mock import AsyncMock, patch -import pytest from fastapi.testclient import TestClient +from app.core.config import settings +from app.services.adapters.storybook.primary import Storybook, StorybookPage -# ── 注意 ────────────────────────────────────────────────────────────────────── -# 以下路由尚未实现 (stories.py 中没有对应端点),相关测试标记为 skip: -# GET /api/stories (列表) -# GET /api/stories/{id} (详情) -# DELETE /api/stories/{id} (删除) -# POST /api/image/generate/{id} (封面图片生成) -# GET /api/audio/{id} (音频) -# 实现后请取消 skip 标记。 + +def build_storybook_output() -> Storybook: + """Create a reusable mocked storybook payload.""" + + return Storybook( + title="森林里的发光冒险", + main_character="小兔子露露", + art_style="温暖水彩", + cover_prompt="A glowing forest storybook cover", + pages=[ + StorybookPage( + page_number=1, + text="露露第一次走进会发光的森林。", + image_prompt="Lulu entering a glowing forest", + ), + StorybookPage( + page_number=2, + text="她遇到了一只会唱歌的萤火虫。", + image_prompt="Lulu meeting a singing firefly", + ), + ], + ) class TestStoryGenerate: - """故事生成测试。""" + """Tests for basic story generation.""" def test_generate_without_auth(self, client: TestClient): - """未登录时生成故事。""" response = client.post( "/api/stories/generate", json={"type": "keywords", "data": "小兔子, 森林"}, @@ -28,7 +43,6 @@ class TestStoryGenerate: assert response.status_code == 401 def test_generate_with_empty_data(self, auth_client: TestClient): - """空数据生成故事。""" response = auth_client.post( "/api/stories/generate", json={"type": "keywords", "data": ""}, @@ -36,7 +50,6 @@ class TestStoryGenerate: assert response.status_code == 422 def test_generate_with_invalid_type(self, auth_client: TestClient): - """无效类型生成故事。""" response = auth_client.post( "/api/stories/generate", json={"type": "invalid", "data": "test"}, @@ -44,7 +57,6 @@ class TestStoryGenerate: assert response.status_code == 422 def test_generate_story_success(self, auth_client: TestClient, mock_text_provider): - """成功生成故事。""" response = auth_client.post( "/api/stories/generate", json={"type": "keywords", "data": "小兔子, 森林, 勇气"}, @@ -55,82 +67,96 @@ class TestStoryGenerate: assert "title" in data assert "story_text" in data assert data["mode"] == "generated" + assert data["generation_status"] == "narrative_ready" + assert data["image_status"] == "not_requested" + assert data["audio_status"] == "not_requested" + assert data["last_error"] is None class TestStoryList: - """故事列表测试。""" + """Tests for story listing.""" def test_list_without_auth(self, client: TestClient): - """未登录时获取列表。""" response = client.get("/api/stories") assert response.status_code == 401 def test_list_empty(self, auth_client: TestClient): - """空列表。""" response = auth_client.get("/api/stories") assert response.status_code == 200 assert response.json() == [] def test_list_with_stories(self, auth_client: TestClient, test_story): - """有故事时获取列表。""" response = auth_client.get("/api/stories") assert response.status_code == 200 data = response.json() assert len(data) == 1 assert data[0]["id"] == test_story.id assert data[0]["title"] == test_story.title + assert data[0]["generation_status"] == "narrative_ready" + assert data[0]["image_status"] == "not_requested" + assert data[0]["audio_status"] == "not_requested" def test_list_pagination(self, auth_client: TestClient, test_story): - """分页测试。""" response = auth_client.get("/api/stories?limit=1&offset=0") assert response.status_code == 200 - data = response.json() - assert len(data) == 1 + assert len(response.json()) == 1 response = auth_client.get("/api/stories?limit=1&offset=1") assert response.status_code == 200 - data = response.json() - assert len(data) == 0 + assert len(response.json()) == 0 class TestStoryDetail: - """故事详情测试。""" + """Tests for story detail retrieval.""" def test_get_story_without_auth(self, client: TestClient, test_story): - """未登录时获取详情。""" response = client.get(f"/api/stories/{test_story.id}") assert response.status_code == 401 def test_get_story_not_found(self, auth_client: TestClient): - """故事不存在。""" response = auth_client.get("/api/stories/99999") assert response.status_code == 404 def test_get_story_success(self, auth_client: TestClient, test_story): - """成功获取详情。""" response = auth_client.get(f"/api/stories/{test_story.id}") assert response.status_code == 200 data = response.json() assert data["id"] == test_story.id assert data["title"] == test_story.title assert data["story_text"] == test_story.story_text + assert data["generation_status"] == "narrative_ready" + assert data["image_status"] == "not_requested" + assert data["audio_status"] == "not_requested" + assert data["last_error"] is None + + def test_get_storybook_success(self, auth_client: TestClient, storybook_story): + response = auth_client.get(f"/api/stories/{storybook_story.id}") + assert response.status_code == 200 + data = response.json() + assert data["id"] == storybook_story.id + assert data["mode"] == "storybook" + assert data["story_text"] is None + assert len(data["pages"]) == 2 + assert data["pages"][0]["page_number"] == 1 + assert data["image_url"] == "https://example.com/storybook-cover.png" + assert data["generation_status"] == "degraded_completed" + assert data["image_status"] == "failed" + assert data["audio_status"] == "not_requested" + assert "第 2 页" in data["last_error"] class TestStoryDelete: - """故事删除测试。""" + """Tests for story deletion.""" def test_delete_without_auth(self, client: TestClient, test_story): - """未登录时删除。""" response = client.delete(f"/api/stories/{test_story.id}") assert response.status_code == 401 def test_delete_not_found(self, auth_client: TestClient): - """删除不存在的故事。""" response = auth_client.delete("/api/stories/99999") assert response.status_code == 404 def test_delete_success(self, auth_client: TestClient, test_story): - """成功删除故事。""" response = auth_client.delete(f"/api/stories/{test_story.id}") assert response.status_code == 200 assert response.json()["message"] == "Deleted" @@ -140,11 +166,14 @@ class TestStoryDelete: class TestRateLimit: - """Rate limit 测试。""" + """Tests for story generation rate limiting.""" - def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, mock_text_provider, bypass_rate_limit): - """正常请求不触发限流。""" - # bypass_rate_limit 默认 incr 返回 1,不触发限流 + def test_rate_limit_allows_normal_requests( + self, + auth_client: TestClient, + mock_text_provider, + bypass_rate_limit, + ): for _ in range(3): response = auth_client.post( "/api/stories/generate", @@ -152,9 +181,11 @@ class TestRateLimit: ) assert response.status_code == 200 - def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, bypass_rate_limit): - """超限请求被阻止。""" - # 让 incr 返回超限值 (> RATE_LIMIT_REQUESTS) + def test_rate_limit_blocks_excess_requests( + self, + auth_client: TestClient, + bypass_rate_limit, + ): bypass_rate_limit.incr.return_value = 11 response = auth_client.post( @@ -166,52 +197,118 @@ class TestRateLimit: class TestImageGenerate: - """封面图片生成测试。""" + """Tests for cover generation endpoint.""" def test_generate_image_without_auth(self, client: TestClient, test_story): - """未登录时生成图片。""" response = client.post(f"/api/image/generate/{test_story.id}") assert response.status_code == 401 def test_generate_image_not_found(self, auth_client: TestClient): - """故事不存在。""" response = auth_client.post("/api/image/generate/99999") assert response.status_code == 404 class TestAudio: - """语音朗读测试。""" + """Tests for story audio endpoint.""" def test_get_audio_without_auth(self, client: TestClient, test_story): - """未登录时获取音频。""" response = client.get(f"/api/audio/{test_story.id}") assert response.status_code == 401 def test_get_audio_not_found(self, auth_client: TestClient): - """故事不存在。""" response = auth_client.get("/api/audio/99999") assert response.status_code == 404 - def test_get_audio_success(self, auth_client: TestClient, test_story, mock_tts_provider): - """成功获取音频。""" + def test_get_audio_success( + self, + auth_client: TestClient, + test_story, + mock_tts_provider, + ): response = auth_client.get(f"/api/audio/{test_story.id}") assert response.status_code == 200 assert response.headers["content-type"] == "audio/mpeg" + assert response.content == b"fake-audio-bytes" + + cached_audio_path = Path(settings.story_audio_cache_dir) / f"story-{test_story.id}.mp3" + assert cached_audio_path.is_file() + + second_response = auth_client.get(f"/api/audio/{test_story.id}") + assert second_response.status_code == 200 + assert second_response.content == b"fake-audio-bytes" + mock_tts_provider.assert_awaited_once() + + detail_response = auth_client.get(f"/api/stories/{test_story.id}") + detail = detail_response.json() + assert detail["audio_status"] == "ready" + assert detail["generation_status"] == "completed" + assert detail["last_error"] is None + + def test_get_audio_regenerates_when_cache_file_is_missing( + self, + auth_client: TestClient, + test_story, + mock_tts_provider, + ): + first_response = auth_client.get(f"/api/audio/{test_story.id}") + assert first_response.status_code == 200 + + cached_audio_path = Path(settings.story_audio_cache_dir) / f"story-{test_story.id}.mp3" + cached_audio_path.unlink() + mock_tts_provider.reset_mock() + + second_response = auth_client.get(f"/api/audio/{test_story.id}") + assert second_response.status_code == 200 + assert second_response.content == b"fake-audio-bytes" + assert cached_audio_path.is_file() + mock_tts_provider.assert_awaited_once() + + def test_get_audio_failure_updates_status(self, auth_client: TestClient, test_story): + with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock_tts: + mock_tts.side_effect = Exception("TTS provider timeout") + response = auth_client.get(f"/api/audio/{test_story.id}") + assert response.status_code == 500 + + detail_response = auth_client.get(f"/api/stories/{test_story.id}") + detail = detail_response.json() + assert detail["audio_status"] == "failed" + assert detail["generation_status"] == "degraded_completed" + assert "TTS provider timeout" in detail["last_error"] + + def test_get_audio_success_preserves_existing_image_error( + self, + auth_client: TestClient, + degraded_story_with_text, + mock_tts_provider, + ): + response = auth_client.get(f"/api/audio/{degraded_story_with_text.id}") + assert response.status_code == 200 + assert response.content == b"fake-audio-bytes" + mock_tts_provider.assert_awaited_once() + + detail_response = auth_client.get(f"/api/stories/{degraded_story_with_text.id}") + detail = detail_response.json() + assert detail["audio_status"] == "ready" + assert detail["generation_status"] == "degraded_completed" + assert detail["last_error"] == "封面生成失败" class TestGenerateFull: - """完整故事生成测试(/api/stories/generate/full)。""" + """Tests for complete story generation.""" def test_generate_full_without_auth(self, client: TestClient): - """未登录时生成完整故事。""" response = client.post( "/api/stories/generate/full", json={"type": "keywords", "data": "小兔子, 森林"}, ) assert response.status_code == 401 - def test_generate_full_success(self, auth_client: TestClient, mock_text_provider, mock_image_provider): - """成功生成完整故事(含图片)。""" + def test_generate_full_success( + self, + auth_client: TestClient, + mock_text_provider, + mock_image_provider, + ): response = auth_client.post( "/api/stories/generate/full", json={"type": "keywords", "data": "小兔子, 森林, 勇气"}, @@ -223,11 +320,14 @@ class TestGenerateFull: assert "story_text" in data assert data["mode"] == "generated" assert data["image_url"] == "https://example.com/image.png" - assert data["audio_ready"] is False # 音频按需生成 + assert data["audio_ready"] is False assert data["errors"] == {} + assert data["generation_status"] == "completed" + assert data["image_status"] == "ready" + assert data["audio_status"] == "not_requested" + assert data["last_error"] is None def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider): - """图片生成失败时返回部分成功。""" with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_img: mock_img.side_effect = Exception("Image API error") response = auth_client.post( @@ -239,9 +339,17 @@ class TestGenerateFull: assert data["image_url"] is None assert "image" in data["errors"] assert "Image API error" in data["errors"]["image"] + assert data["generation_status"] == "degraded_completed" + assert data["image_status"] == "failed" + assert data["audio_status"] == "not_requested" + assert "Image API error" in data["last_error"] - def test_generate_full_with_education_theme(self, auth_client: TestClient, mock_text_provider, mock_image_provider): - """带教育主题生成故事。""" + def test_generate_full_with_education_theme( + self, + auth_client: TestClient, + mock_text_provider, + mock_image_provider, + ): response = auth_client.post( "/api/stories/generate/full", json={ @@ -257,11 +365,80 @@ class TestGenerateFull: class TestImageGenerateSuccess: - """封面图片生成成功测试。""" + """Tests for successful cover generation.""" - def test_generate_image_success(self, auth_client: TestClient, test_story, mock_image_provider): - """成功生成图片。""" + def test_generate_image_success( + self, + auth_client: TestClient, + test_story, + mock_image_provider, + ): response = auth_client.post(f"/api/image/generate/{test_story.id}") assert response.status_code == 200 data = response.json() assert data["image_url"] == "https://example.com/image.png" + assert data["generation_status"] == "completed" + assert data["image_status"] == "ready" + assert data["audio_status"] == "not_requested" + assert data["last_error"] is None + + +class TestStorybookGenerate: + """Tests for storybook generation status handling.""" + + def test_generate_storybook_success(self, auth_client: TestClient): + with patch("app.services.story_service.generate_storybook", new_callable=AsyncMock) as mock_storybook: + with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_image: + mock_storybook.return_value = build_storybook_output() + mock_image.side_effect = [ + "https://example.com/storybook-cover.png", + "https://example.com/storybook-page-1.png", + "https://example.com/storybook-page-2.png", + ] + + response = auth_client.post( + "/api/storybook/generate", + json={ + "keywords": "森林, 发光, 友情", + "page_count": 6, + "generate_images": True, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["id"] is not None + assert data["generation_status"] == "completed" + assert data["image_status"] == "ready" + assert data["audio_status"] == "not_requested" + assert data["last_error"] is None + assert len(data["pages"]) == 2 + assert data["cover_url"] == "https://example.com/storybook-cover.png" + + def test_generate_storybook_partial_image_failure(self, auth_client: TestClient): + async def image_side_effect(prompt: str, **kwargs): + if "singing firefly" in prompt: + raise Exception("Image API error") + slug = prompt.split()[0].lower() + return f"https://example.com/{slug}.png" + + with patch("app.services.story_service.generate_storybook", new_callable=AsyncMock) as mock_storybook: + with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_image: + mock_storybook.return_value = build_storybook_output() + mock_image.side_effect = image_side_effect + + response = auth_client.post( + "/api/storybook/generate", + json={ + "keywords": "森林, 发光, 友情", + "page_count": 6, + "generate_images": True, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["generation_status"] == "degraded_completed" + assert data["image_status"] == "failed" + assert data["audio_status"] == "not_requested" + assert "第 2 页插图生成失败" in data["last_error"] diff --git a/frontend/src/stores/storybook.ts b/frontend/src/stores/storybook.ts index 345c93a..5d00f3f 100644 --- a/frontend/src/stores/storybook.ts +++ b/frontend/src/stores/storybook.ts @@ -1,14 +1,14 @@ - -import { defineStore } from 'pinia' -import { ref } from 'vue' - -export interface StorybookPage { - page_number: number - text: string - image_prompt: string - image_url?: string -} - + +import { defineStore } from 'pinia' +import { ref } from 'vue' + +export interface StorybookPage { + page_number: number + text: string + image_prompt: string + image_url?: string +} + export interface Storybook { id?: number // 新增 title: string @@ -17,22 +17,26 @@ export interface Storybook { pages: StorybookPage[] cover_prompt: string cover_url?: string + generation_status?: string + image_status?: string + audio_status?: string + last_error?: string | null } - -export const useStorybookStore = defineStore('storybook', () => { - const currentStorybook = ref(null) - - function setStorybook(storybook: Storybook) { - currentStorybook.value = storybook - } - - function clearStorybook() { - currentStorybook.value = null - } - - return { - currentStorybook, - setStorybook, - clearStorybook, - } -}) + +export const useStorybookStore = defineStore('storybook', () => { + const currentStorybook = ref(null) + + function setStorybook(storybook: Storybook) { + currentStorybook.value = storybook + } + + function clearStorybook() { + currentStorybook.value = null + } + + return { + currentStorybook, + setStorybook, + clearStorybook, + } +}) diff --git a/frontend/src/utils/storyStatus.ts b/frontend/src/utils/storyStatus.ts new file mode 100644 index 0000000..363a279 --- /dev/null +++ b/frontend/src/utils/storyStatus.ts @@ -0,0 +1,79 @@ +export type StoryGenerationStatus = + | 'narrative_ready' + | 'assets_generating' + | 'completed' + | 'degraded_completed' + | 'failed' + +export type StoryAssetStatus = + | 'not_requested' + | 'generating' + | 'ready' + | 'failed' + +interface StatusMeta { + label: string + description: string + badgeClass: string +} + +const generationStatusMetaMap: Record = { + narrative_ready: { + label: '文本已完成', + description: '故事内容已经生成,可以继续补充封面或音频。', + badgeClass: 'bg-sky-50 text-sky-700 border border-sky-100', + }, + assets_generating: { + label: '资源生成中', + description: '封面或音频正在生成中,请稍候查看结果。', + badgeClass: 'bg-amber-50 text-amber-700 border border-amber-100', + }, + completed: { + label: '内容可用', + description: '当前内容已经达到可阅读状态。', + badgeClass: 'bg-emerald-50 text-emerald-700 border border-emerald-100', + }, + degraded_completed: { + label: '部分降级完成', + description: '核心内容可用,但有部分资源生成失败。', + badgeClass: 'bg-orange-50 text-orange-700 border border-orange-100', + }, + failed: { + label: '生成失败', + description: '当前内容还未成功生成,请稍后重试。', + badgeClass: 'bg-rose-50 text-rose-700 border border-rose-100', + }, +} + +const assetStatusMetaMap: Record = { + not_requested: { + label: '未请求', + description: '还没有发起该资源生成。', + badgeClass: 'bg-slate-100 text-slate-600 border border-slate-200', + }, + generating: { + label: '生成中', + description: '资源正在生成,请稍候。', + badgeClass: 'bg-amber-50 text-amber-700 border border-amber-100', + }, + ready: { + label: '已就绪', + description: '该资源可使用。', + badgeClass: 'bg-emerald-50 text-emerald-700 border border-emerald-100', + }, + failed: { + label: '失败', + description: '最近一次生成失败,可以稍后重试。', + badgeClass: 'bg-rose-50 text-rose-700 border border-rose-100', + }, +} + +export function getGenerationStatusMeta(status?: string): StatusMeta { + return generationStatusMetaMap[(status ?? 'narrative_ready') as StoryGenerationStatus] + ?? generationStatusMetaMap.narrative_ready +} + +export function getAssetStatusMeta(status?: string): StatusMeta { + return assetStatusMetaMap[(status ?? 'not_requested') as StoryAssetStatus] + ?? assetStatusMetaMap.not_requested +} diff --git a/frontend/src/views/MyStories.vue b/frontend/src/views/MyStories.vue index 3b46ebb..6d12a3a 100644 --- a/frontend/src/views/MyStories.vue +++ b/frontend/src/views/MyStories.vue @@ -1,19 +1,20 @@