diff --git a/backend/Dockerfile b/backend/Dockerfile index 23de17c..3abb42c 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -2,26 +2,32 @@ FROM python:3.11-slim WORKDIR /app -# 安装系统依赖 (如果需要) -# RUN apt-get update && apt-get install -y gcc libpq-dev && rm -rf /var/lib/apt/lists/* +# 设置环境变量 +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 -# 复制项目文件 +# 安装系统工具 (curl用于可能的健康检查) +RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/* + +# 1. 缓存层:仅复制依赖定义并安装 +# 创建伪造的 app 目录以满足 pip install . 的要求 COPY pyproject.toml . -# 复制源码 +RUN mkdir app && touch app/__init__.py +RUN pip install --no-cache-dir . + +# 2. 源码层:复制真实代码 COPY app ./app COPY alembic ./alembic COPY alembic.ini . -# 安装依赖 -# 使用 pip 安装当前目录 (.),会自动解析 pyproject.toml -RUN pip install --no-cache-dir . +# 再次安装本身(不带依赖),确保源码更新被标记为已安装 +RUN pip install --no-cache-dir --no-deps . -# 创建静态文件目录 (用于存放生成的图片) +# 创建静态文件目录 RUN mkdir -p static/images # 暴露端口 EXPOSE 8000 -# 启动命令 -# 生产环境建议使用 gunicorn 或 uvicorn --workers +# 默认启动命令(可被 docker-compose 覆盖) CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/backend/app/api/stories.py b/backend/app/api/stories.py index ec34b65..cd84d74 100644 --- a/backend/app/api/stories.py +++ b/backend/app/api/stories.py @@ -2,136 +2,41 @@ import asyncio import json -import time import uuid -from typing import AsyncGenerator, Literal +from typing import AsyncGenerator -from fastapi import APIRouter, Depends, HTTPException, Request -from fastapi.responses import Response -from pydantic import BaseModel, Field -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload +from fastapi import APIRouter, Depends, HTTPException, Request, Response from sse_starlette.sse import EventSourceResponse +from sqlalchemy.ext.asyncio import AsyncSession from app.core.deps import require_user from app.core.logging import get_logger -from app.db.database import get_db -from app.db.models import ChildProfile, Story, StoryUniverse, User -from app.services.provider_router import ( - generate_image, - generate_story_content, - generate_storybook, - text_to_speech, -) from app.core.rate_limiter import check_rate_limit -from app.tasks.achievements import extract_story_achievements +from app.db.database import get_db +from app.db.models import User +from app.schemas.story_schemas import ( + GenerateRequest, + StoryResponse, + FullStoryResponse, + StorybookRequest, + StorybookResponse, + StoryListItem, + AchievementItem, +) +from app.services import story_service +from app.services.memory_service import build_enhanced_memory_context +from app.services.provider_router import ( + generate_story_content, + generate_image, +) logger = get_logger(__name__) - router = APIRouter() -MAX_DATA_LENGTH = 2000 -MAX_EDU_THEME_LENGTH = 200 -MAX_TTS_LENGTH = 4000 - RATE_LIMIT_WINDOW = 60 # seconds RATE_LIMIT_REQUESTS = 10 -class GenerateRequest(BaseModel): - """Story generation request.""" - - type: Literal["keywords", "full_story"] - data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH) - education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH) - child_profile_id: str | None = None - universe_id: str | None = None - - -class StoryResponse(BaseModel): - """Story response.""" - - id: int - title: str - story_text: str - cover_prompt: str | None - image_url: str | None - mode: str - child_profile_id: str | None = None - universe_id: str | None = None - - -class StoryListItem(BaseModel): - """Story list item.""" - - id: int - title: str - image_url: str | None - created_at: str - mode: str - - -class FullStoryResponse(BaseModel): - """完整故事响应(含图片和音频状态)。""" - - id: int - title: str - story_text: str - cover_prompt: str | None - image_url: str | None - audio_ready: bool - mode: str - errors: dict[str, str | None] = Field(default_factory=dict) - child_profile_id: str | None = None - universe_id: str | None = None - - -from app.services.memory_service import build_enhanced_memory_context - - -async def _validate_profile_and_universe( - request: GenerateRequest, - user: User, - db: AsyncSession, -) -> tuple[str | None, str | None]: - if not request.child_profile_id and not request.universe_id: - return None, None - - profile_id = request.child_profile_id - universe_id = request.universe_id - - if profile_id: - result = await db.execute( - select(ChildProfile).where( - ChildProfile.id == profile_id, - ChildProfile.user_id == user.id, - ) - ) - profile = result.scalar_one_or_none() - if not profile: - raise HTTPException(status_code=404, detail="孩子档案不存在") - - if universe_id: - result = await db.execute( - select(StoryUniverse) - .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id) - .where( - StoryUniverse.id == universe_id, - ChildProfile.user_id == user.id, - ) - ) - universe = result.scalar_one_or_none() - if not universe: - raise HTTPException(status_code=404, detail="故事宇宙不存在") - if profile_id and universe.child_profile_id != profile_id: - raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配") - if not profile_id: - profile_id = universe.child_profile_id - - return profile_id, universe_id - - @router.post("/stories/generate", response_model=StoryResponse) async def generate_story( request: GenerateRequest, @@ -140,47 +45,7 @@ async def generate_story( ): """Generate or enhance a story.""" await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) - profile_id, universe_id = await _validate_profile_and_universe(request, user, db) - memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) - - try: - result = await generate_story_content( - input_type=request.type, - data=request.data, - education_theme=request.education_theme, - memory_context=memory_context, - ) - except HTTPException: - raise - except Exception: - raise HTTPException(status_code=502, detail="Story generation failed, please try again.") - - story = Story( - user_id=user.id, - child_profile_id=profile_id, - universe_id=universe_id, - title=result.title, - story_text=result.story_text, - cover_prompt=result.cover_prompt_suggestion, - mode=result.mode, - ) - db.add(story) - await db.commit() - await db.refresh(story) - - if universe_id: - extract_story_achievements.delay(story.id, universe_id) - - return StoryResponse( - id=story.id, - title=story.title, - story_text=story.story_text, - cover_prompt=story.cover_prompt, - image_url=story.image_url, - mode=story.mode, - child_profile_id=story.child_profile_id, - universe_id=story.universe_id, - ) + return await story_service.generate_and_save_story(request, user.id, db) @router.post("/stories/generate/full", response_model=FullStoryResponse) @@ -189,71 +54,9 @@ async def generate_story_full( user: User = Depends(require_user), db: AsyncSession = Depends(get_db), ): - """生成完整故事(故事 + 并行生成图片和音频)。 - - 部分成功策略:故事必须成功,图片/音频失败不影响整体。 - """ + """Generate complete story (story + parallel image/audio generation).""" await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) - profile_id, universe_id = await _validate_profile_and_universe(request, user, db) - memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) - - # Step 1: 故事生成(必须成功) - try: - result = await generate_story_content( - input_type=request.type, - data=request.data, - education_theme=request.education_theme, - memory_context=memory_context, - ) - except Exception as exc: - logger.error("story_generation_failed", error=str(exc)) - raise HTTPException(status_code=502, detail="Story generation failed, please try again.") - - # 保存故事 - story = Story( - user_id=user.id, - child_profile_id=profile_id, - universe_id=universe_id, - title=result.title, - story_text=result.story_text, - cover_prompt=result.cover_prompt_suggestion, - mode=result.mode, - ) - db.add(story) - await db.commit() - await db.refresh(story) - - if universe_id: - extract_story_achievements.delay(story.id, universe_id) - - # Step 2: 生成封面图片(音频按需生成,避免浪费) - errors: dict[str, str | None] = {} - image_url: str | None = None - - if story.cover_prompt: - try: - image_url = await generate_image(story.cover_prompt) - story.image_url = image_url - await db.commit() - except Exception as exc: - errors["image"] = str(exc) - logger.warning("image_generation_failed", story_id=story.id, error=str(exc)) - - # 注意:音频不在此处预生成,用户通过 /api/audio/{id} 按需获取 - # 这样避免生成后丢弃造成的成本浪费 - - return FullStoryResponse( - id=story.id, - title=story.title, - story_text=story.story_text, - cover_prompt=story.cover_prompt, - image_url=image_url, - audio_ready=False, # 音频需要用户主动请求 - mode=story.mode, - errors=errors, - child_profile_id=story.child_profile_id, - universe_id=story.universe_id, - ) + return await story_service.generate_full_story_service(request, user.id, db) @router.post("/stories/generate/stream") @@ -263,53 +66,39 @@ async def generate_story_stream( user: User = Depends(require_user), db: AsyncSession = Depends(get_db), ): - """流式生成故事(SSE)。 - - 事件流程: - - started: 返回 story_id - - story_ready: 返回 title, content - - story_failed: 返回 error - - image_ready: 返回 image_url - - image_failed: 返回 error - - complete: 结束流 - """ + """流式生成故事(SSE)。""" await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) - profile_id, universe_id = await _validate_profile_and_universe(request, user, db) + + # Validation + profile_id, universe_id = await story_service.validate_profile_and_universe( + request.child_profile_id, request.universe_id, user.id, db + ) + + # Build Context memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) async def event_generator() -> AsyncGenerator[dict, None]: story_id = str(uuid.uuid4()) yield {"event": "started", "data": json.dumps({"story_id": story_id})} - # Step 1: 生成故事 + # Step 1: Generate Content try: result = await generate_story_content( input_type=request.type, data=request.data, education_theme=request.education_theme, memory_context=memory_context, + db=db, ) except Exception as e: logger.error("sse_story_generation_failed", error=str(e)) yield {"event": "story_failed", "data": json.dumps({"error": str(e)})} return - # 保存故事 - story = Story( - user_id=user.id, - child_profile_id=profile_id, - universe_id=universe_id, - title=result.title, - story_text=result.story_text, - cover_prompt=result.cover_prompt_suggestion, - mode=result.mode, + # Save Story + story = await story_service.create_story_from_result( + result, user.id, profile_id, universe_id, db ) - db.add(story) - await db.commit() - await db.refresh(story) - - if universe_id: - extract_story_achievements.delay(story.id, universe_id) yield { "event": "story_ready", @@ -324,10 +113,11 @@ async def generate_story_stream( }), } - # Step 2: 并行生成图片(音频按需) + # Step 2: Generate Image if story.cover_prompt: try: - image_url = await generate_image(story.cover_prompt) + # Direct call to provider router's generate_image, sharing db session + image_url = await generate_image(story.cover_prompt, db=db) story.image_url = image_url await db.commit() yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})} @@ -340,217 +130,71 @@ async def generate_story_stream( return EventSourceResponse(event_generator()) -# ==================== Storybook API ==================== - - -class StorybookRequest(BaseModel): - """Storybook 生成请求。""" - - keywords: str = Field(..., min_length=1, max_length=200) - page_count: int = Field(default=6, ge=4, le=12) - education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH) - generate_images: bool = Field(default=False, description="是否同时生成插图") - child_profile_id: str | None = None - universe_id: str | None = None - - -class StorybookPageResponse(BaseModel): - """故事书单页响应。""" - - page_number: int - text: str - image_prompt: str - image_url: str | None = None - - -class StorybookResponse(BaseModel): - """故事书响应。""" - - id: int | None = None - title: str - main_character: str - art_style: str - pages: list[StorybookPageResponse] - cover_prompt: str - cover_url: str | None = None - - @router.post("/storybook/generate", response_model=StorybookResponse) async def generate_storybook_api( request: StorybookRequest, user: User = Depends(require_user), db: AsyncSession = Depends(get_db), ): - """生成分页故事书并保存。 - - 返回故事书结构,包含每页文字和图像提示词。 - """ + """Generate storybook.""" await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) - - # 验证档案和宇宙 - # 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数 - # 这里我们直接手动验证,或重构验证函数。为了简单,手动调用部分逻辑。 - - # 构建临时的 GenerateRequest 用于验证验证函数签名(或者直接手动查库更好) - profile_id = request.child_profile_id - universe_id = request.universe_id - - if profile_id: - result = await db.execute( - select(ChildProfile).where( - ChildProfile.id == profile_id, - ChildProfile.user_id == user.id, - ) - ) - if not result.scalar_one_or_none(): - raise HTTPException(status_code=404, detail="孩子档案不存在") - - if universe_id: - result = await db.execute( - select(StoryUniverse) - .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id) - .where( - StoryUniverse.id == universe_id, - ChildProfile.user_id == user.id, - ) - ) - universe = result.scalar_one_or_none() - if not universe: - raise HTTPException(status_code=404, detail="故事宇宙不存在") - if profile_id and universe.child_profile_id != profile_id: - raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配") - if not profile_id: - profile_id = universe.child_profile_id - - logger.info( - "storybook_request", - user_id=user.id, - keywords=request.keywords, - page_count=request.page_count, - profile_id=profile_id, - universe_id=universe_id, - ) - - memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) - - try: - # 注意:generate_storybook 目前可能不支持记忆上下文注入 - # 我们需要看看 generate_storybook 的签名 - # 如果不支持,记忆功能在绘本模式下暂不可用,但基本参数传递是支持的 - storybook = await generate_storybook( - keywords=request.keywords, - page_count=request.page_count, - education_theme=request.education_theme, - memory_context=memory_context, - db=db, - ) - except Exception as e: - logger.error("storybook_generation_failed", error=str(e)) - raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}") - - # ============================================================================== - # 核心升级: 并行全量生成 (Parallel Full Rendering) - # ============================================================================== - final_cover_url = storybook.cover_url - - if request.generate_images: - logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages)) - - # 1. 准备所有生图任务 (封面 + 所有内页) - tasks = [] - - # 封面任务 - async def _gen_cover(): - if storybook.cover_prompt and not storybook.cover_url: - try: - return await generate_image(storybook.cover_prompt, db=db) - except Exception as e: - logger.warning("cover_gen_failed", error=str(e)) - return storybook.cover_url - - tasks.append(_gen_cover()) - - # 内页任务 - async def _gen_page(page): - if page.image_prompt and not page.image_url: - try: - url = await generate_image(page.image_prompt, db=db) - page.image_url = url - except Exception as e: - logger.warning("page_gen_failed", page=page.page_number, error=str(e)) - - for page in storybook.pages: - tasks.append(_gen_page(page)) - - # 2. 并发执行所有任务 - # 使用 return_exceptions=True 防止单张失败影响整体 - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 3. 更新封面结果 (results[0] 是封面任务的返回值) - cover_res = results[0] - if isinstance(cover_res, str): - final_cover_url = cover_res - - logger.info("storybook_parallel_generation_complete") - - # ============================================================================== - - # 构建并保存 Story 对象 - # 将 pages 对象转换为字典列表以存入 JSON 字段 - pages_data = [ - { - "page_number": p.page_number, - "text": p.text, - "image_prompt": p.image_prompt, - "image_url": p.image_url, - } - for p in storybook.pages - ] - - story = Story( - user_id=user.id, - child_profile_id=profile_id, - universe_id=universe_id, - title=storybook.title, - mode="storybook", - pages=pages_data, # 存入 JSON 字段 - story_text=None, # 绘本模式下,主文本可为空,或者可以存个摘要 - cover_prompt=storybook.cover_prompt, - image_url=final_cover_url, - ) - db.add(story) - await db.commit() - await db.refresh(story) - - if universe_id: - extract_story_achievements.delay(story.id, universe_id) - - # 构建响应 (使用更新后的 pages_data) - response_pages = [ - StorybookPageResponse( - page_number=p["page_number"], - text=p["text"], - image_prompt=p["image_prompt"], - image_url=p.get("image_url"), - ) - for p in pages_data - ] - - return StorybookResponse( - id=story.id, - title=storybook.title, - main_character=storybook.main_character, - art_style=storybook.art_style, - pages=response_pages, - cover_prompt=storybook.cover_prompt, - cover_url=final_cover_url, - ) + return await story_service.generate_storybook_service(request, user.id, db) -class AchievementItem(BaseModel): - type: str - description: str - obtained_at: str | None = None +# ==================== Missing Endpoints (Issue #5) ==================== + +@router.get("/stories", response_model=list[StoryListItem]) +async def list_stories( + limit: int = 20, + offset: int = 0, + user: User = Depends(require_user), + db: AsyncSession = Depends(get_db), +): + """List stories.""" + return await story_service.list_stories(user.id, limit, offset, db) + + +@router.get("/stories/{story_id}", response_model=StoryResponse) +async def get_story( + story_id: int, + user: User = Depends(require_user), + db: AsyncSession = Depends(get_db), +): + """Get story detail.""" + return await story_service.get_story_detail(story_id, user.id, db) + + +@router.delete("/stories/{story_id}") +async def delete_story( + story_id: int, + user: User = Depends(require_user), + db: AsyncSession = Depends(get_db), +): + """Delete story.""" + await story_service.delete_story(story_id, user.id, db) + return {"message": "Deleted"} + + +@router.post("/image/generate/{story_id}") +async def generate_story_image( + story_id: int, + user: User = Depends(require_user), + db: AsyncSession = Depends(get_db), +): + """Generate cover image for story.""" + url = await story_service.generate_story_cover(story_id, user.id, db) + return {"image_url": url} + + +@router.get("/audio/{story_id}") +async def get_story_audio( + story_id: int, + user: User = Depends(require_user), + db: AsyncSession = Depends(get_db), +): + """Get story audio (MP3).""" + audio_bytes = await story_service.generate_story_audio(story_id, user.id, db) + return Response(content=audio_bytes, media_type="audio/mpeg") @router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem]) @@ -559,32 +203,5 @@ async def get_story_achievements( user: User = Depends(require_user), db: AsyncSession = Depends(get_db), ): - """Get achievements unlocked by a specific story.""" - # 使用 joinedload 避免 N+1 查询 - result = await db.execute( - select(Story) - .options(joinedload(Story.story_universe)) - .where(Story.id == story_id, Story.user_id == user.id) - ) - story = result.scalar_one_or_none() - - if not story: - raise HTTPException(status_code=404, detail="Story not found") - - if not story.universe_id or not story.story_universe: - return [] - - universe = story.story_universe - if not universe.achievements: - return [] - - results = [] - for ach in universe.achievements: - if isinstance(ach, dict) and ach.get("source_story_id") == story_id: - results.append(AchievementItem( - type=ach.get("type", "Unknown"), - description=ach.get("description", ""), - obtained_at=ach.get("obtained_at") - )) - - return results + """Get story achievements.""" + return await story_service.get_story_achievements(story_id, user.id, db) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index ca038b9..8823b15 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -52,6 +52,9 @@ class Settings(BaseSettings): # Celery (Redis) celery_broker_url: str = Field("redis://localhost:6379/0") celery_result_backend: str = Field("redis://localhost:6379/0") + + # Generic Redis + redis_url: str = Field("redis://localhost:6379/0", description="Redis connection URL") # Admin console enable_admin_console: bool = False diff --git a/backend/app/core/redis.py b/backend/app/core/redis.py new file mode 100644 index 0000000..f8e6962 --- /dev/null +++ b/backend/app/core/redis.py @@ -0,0 +1,25 @@ +"""Redis client module.""" + +from typing import AsyncGenerator + +from redis.asyncio import Redis, from_url + +from app.core.config import settings + +_redis_pool: Redis | None = None + + +async def get_redis() -> Redis: + """Get global Redis client instance.""" + global _redis_pool + if _redis_pool is None: + _redis_pool = from_url(settings.redis_url, encoding="utf-8", decode_responses=True) + return _redis_pool + + +async def close_redis(): + """Close Redis connection.""" + global _redis_pool + if _redis_pool: + await _redis_pool.close() + _redis_pool = None diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py new file mode 100644 index 0000000..d2beb0e --- /dev/null +++ b/backend/app/schemas/__init__.py @@ -0,0 +1 @@ +"""故事相关 Schema 模块。""" diff --git a/backend/app/schemas/story_schemas.py b/backend/app/schemas/story_schemas.py new file mode 100644 index 0000000..80a8c80 --- /dev/null +++ b/backend/app/schemas/story_schemas.py @@ -0,0 +1,106 @@ +"""故事相关 Pydantic 模型。""" + +from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, Field + + +MAX_DATA_LENGTH = 2000 +MAX_EDU_THEME_LENGTH = 200 +MAX_TTS_LENGTH = 4000 + + +# ==================== 故事模型 ==================== + + +class GenerateRequest(BaseModel): + """Story generation request.""" + + type: Literal["keywords", "full_story"] + data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH) + education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH) + child_profile_id: str | None = None + universe_id: str | None = None + + +class StoryResponse(BaseModel): + """Story response.""" + + id: int + title: str + story_text: str + cover_prompt: str | None + image_url: str | None + mode: str + child_profile_id: str | None = None + universe_id: str | None = None + + +class StoryListItem(BaseModel): + """Story list item.""" + + id: int + title: str + image_url: str | None + created_at: datetime + mode: str + + +class FullStoryResponse(BaseModel): + """完整故事响应(含图片和音频状态)。""" + + id: int + title: str + story_text: str + cover_prompt: str | None + image_url: str | None + audio_ready: bool + mode: str + errors: dict[str, str | None] = Field(default_factory=dict) + child_profile_id: str | None = None + universe_id: str | None = None + + +# ==================== 绘本模型 ==================== + + +class StorybookRequest(BaseModel): + """Storybook 生成请求。""" + + keywords: str = Field(..., min_length=1, max_length=200) + page_count: int = Field(default=6, ge=4, le=12) + education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH) + generate_images: bool = Field(default=False, description="是否同时生成插图") + child_profile_id: str | None = None + universe_id: str | None = None + + +class StorybookPageResponse(BaseModel): + """故事书单页响应。""" + + page_number: int + text: str + image_prompt: str + image_url: str | None = None + + +class StorybookResponse(BaseModel): + """故事书响应。""" + + id: int | None = None + title: str + main_character: str + art_style: str + pages: list[StorybookPageResponse] + cover_prompt: str + cover_url: str | None = None + + +# ==================== 成就模型 ==================== + + +class AchievementItem(BaseModel): + type: str + description: str + obtained_at: str | None = None diff --git a/backend/app/services/provider_cache.py b/backend/app/services/provider_cache.py index bd59404..e1a0db3 100644 --- a/backend/app/services/provider_cache.py +++ b/backend/app/services/provider_cache.py @@ -1,31 +1,109 @@ -"""In-memory cache for providers loaded from DB.""" +"""Redis-backed cache for providers loaded from DB.""" +import json from collections import defaultdict from typing import Literal +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.core.logging import get_logger +from app.core.redis import get_redis from app.db.admin_models import Provider +logger = get_logger(__name__) + ProviderType = Literal["text", "image", "tts", "storybook"] -_cache: dict[ProviderType, list[Provider]] = defaultdict(list) + +class CachedProvider(BaseModel): + """Serializable provider configuration matching DB model fields.""" + id: str + name: str + type: str + adapter: str + model: str | None = None + api_base: str | None = None + api_key: str | None = None + timeout_ms: int = 60000 + max_retries: int = 1 + weight: int = 1 + priority: int = 0 + enabled: bool = True + config_json: dict | None = None + config_ref: str | None = None -async def reload_providers(db: AsyncSession): - result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712 - providers = result.scalars().all() - grouped: dict[ProviderType, list[Provider]] = defaultdict(list) - for p in providers: - grouped[p.type].append(p) - # sort by priority desc, then weight desc - for k in grouped: - grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True) - _cache.clear() - _cache.update(grouped) - return _cache +# Local memory fallback (L1 cache) +_local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list) +CACHE_KEY = "dreamweaver:providers:config" -def get_providers(provider_type: ProviderType) -> list[Provider]: - return _cache.get(provider_type, []) +async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]: + """Reload providers from DB and update Redis cache.""" + try: + result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712 + providers = result.scalars().all() + + # Convert to Pydantic models + cached_list = [] + for p in providers: + cached_list.append(CachedProvider( + id=p.id, + name=p.name, + type=p.type, + adapter=p.adapter, + model=p.model, + api_base=p.api_base, + api_key=p.api_key, + timeout_ms=p.timeout_ms, + max_retries=p.max_retries, + weight=p.weight, + priority=p.priority, + enabled=p.enabled, + config_json=p.config_json, + config_ref=p.config_ref + )) + + # Group by type + grouped: dict[str, list[CachedProvider]] = defaultdict(list) + for cp in cached_list: + grouped[cp.type].append(cp) + + # Sort + for k in grouped: + grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True) + + # Update Redis + redis = await get_redis() + # Serialize entire dict structure + # Pydantic -> dict -> json + json_data = {k: [p.model_dump() for p in v] for k, v in grouped.items()} + await redis.set(CACHE_KEY, json.dumps(json_data)) + + # Update local cache + _local_cache.clear() + _local_cache.update(grouped) + return grouped + + except Exception as e: + logger.error("failed_to_reload_providers", error=str(e)) + raise + + +async def get_providers(provider_type: ProviderType) -> list[CachedProvider]: + """Get providers from Redis (preferred) or local fallback.""" + try: + redis = await get_redis() + data = await redis.get(CACHE_KEY) + if data: + raw_dict = json.loads(data) + if provider_type in raw_dict: + return [CachedProvider(**item) for item in raw_dict[provider_type]] + return [] + except Exception as e: + logger.warning("redis_cache_read_failed", error=str(e)) + + # Fallback to local memory + return _local_cache.get(provider_type, []) diff --git a/backend/app/services/provider_router.py b/backend/app/services/provider_router.py index 259a152..d49947a 100644 --- a/backend/app/services/provider_router.py +++ b/backend/app/services/provider_router.py @@ -177,7 +177,7 @@ def _build_config_from_provider(provider: "Provider") -> AdapterConfig: ) -def _get_providers_with_config( +async def _get_providers_with_config( provider_type: ProviderType, ) -> list[tuple[str, AdapterConfig, "Provider | None"]]: """获取供应商列表及其配置。 @@ -185,7 +185,7 @@ def _get_providers_with_config( Returns: [(adapter_name, config, provider_or_none), ...] 按优先级排序 """ - db_providers = get_providers(provider_type) + db_providers = await get_providers(provider_type) if db_providers: return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers] @@ -265,7 +265,7 @@ async def _route_with_failover( user_id: 用户 ID(可选,用于成本追踪和预算检查) **kwargs: 传递给适配器的参数 """ - providers = _get_providers_with_config(provider_type) + providers = await _get_providers_with_config(provider_type) if not providers: raise ValueError(f"No {provider_type} providers configured.") diff --git a/backend/app/services/story_service.py b/backend/app/services/story_service.py new file mode 100644 index 0000000..838275f --- /dev/null +++ b/backend/app/services/story_service.py @@ -0,0 +1,434 @@ +"""Story business logic service.""" + +import asyncio +import json +import uuid +from typing import Literal + +from fastapi import HTTPException +from sqlalchemy import select, desc +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from app.core.logging import get_logger +from app.db.models import ChildProfile, Story, StoryUniverse +from app.schemas.story_schemas import ( + GenerateRequest, + StorybookRequest, + FullStoryResponse, + StorybookResponse, + StorybookPageResponse, + AchievementItem, +) +from app.services.memory_service import build_enhanced_memory_context +from app.services.provider_router import ( + generate_story_content, + generate_image, + generate_storybook, +) +from app.tasks.achievements import extract_story_achievements + +logger = get_logger(__name__) + + +async def validate_profile_and_universe( + profile_id: str | None, + universe_id: str | None, + user_id: str, + db: AsyncSession, +) -> tuple[str | None, str | None]: + """Validate child profile and universe ownership/relationship.""" + if not profile_id and not universe_id: + return None, None + + if profile_id: + result = await db.execute( + select(ChildProfile).where( + ChildProfile.id == profile_id, + ChildProfile.user_id == user_id, + ) + ) + profile = result.scalar_one_or_none() + if not profile: + raise HTTPException(status_code=404, detail="孩子档案不存在") + + if universe_id: + result = await db.execute( + select(StoryUniverse) + .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id) + .where( + StoryUniverse.id == universe_id, + ChildProfile.user_id == user_id, + ) + ) + universe = result.scalar_one_or_none() + if not universe: + raise HTTPException(status_code=404, detail="故事宇宙不存在") + if profile_id and universe.child_profile_id != profile_id: + raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配") + if not profile_id: + profile_id = universe.child_profile_id + + return profile_id, universe_id + + +async def generate_and_save_story( + request: GenerateRequest, + user_id: str, + db: AsyncSession, +) -> Story: + """Generate generic story content and save to DB.""" + # 1. Validate + profile_id, universe_id = await validate_profile_and_universe( + request.child_profile_id, request.universe_id, user_id, db + ) + + # 2. Build Context + memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) + + # 3. Generate + try: + result = await generate_story_content( + input_type=request.type, + data=request.data, + education_theme=request.education_theme, + memory_context=memory_context, + db=db, + ) + except Exception as exc: + raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc + + # 4. Save + story = Story( + user_id=user_id, + child_profile_id=profile_id, + universe_id=universe_id, + title=result.title, + story_text=result.story_text, + cover_prompt=result.cover_prompt_suggestion, + mode=result.mode, + ) + db.add(story) + await db.commit() + await db.refresh(story) + + # 5. Trigger Async Tasks + if universe_id: + extract_story_achievements.delay(story.id, universe_id) + + return story + + +async def generate_full_story_service( + request: GenerateRequest, + user_id: str, + db: AsyncSession, +) -> FullStoryResponse: + """Generate story with parallel image generation.""" + # 1. Generate text part + # We can reuse logic or call generate_story_content directly if we want finer control + # reusing generate_and_save_story to ensure consistency (it handles validation + saving) + story = await generate_and_save_story(request, user_id, db) + + # 2. Generate Image (Parallel/Async step in this flow) + image_url: str | None = None + errors: dict[str, str | None] = {} + + if story.cover_prompt: + try: + image_url = await generate_image(story.cover_prompt, db=db) + story.image_url = image_url + await db.commit() + except Exception as exc: + errors["image"] = str(exc) + logger.warning("image_generation_failed", story_id=story.id, error=str(exc)) + + return FullStoryResponse( + id=story.id, + title=story.title, + story_text=story.story_text, + cover_prompt=story.cover_prompt, + image_url=image_url, + audio_ready=False, + mode=story.mode, + errors=errors, + child_profile_id=story.child_profile_id, + universe_id=story.universe_id, + ) + + +async def generate_storybook_service( + request: StorybookRequest, + user_id: str, + db: AsyncSession, +) -> StorybookResponse: + """Generate storybook with parallel image generation for pages.""" + # 1. Validate + profile_id, universe_id = await validate_profile_and_universe( + request.child_profile_id, request.universe_id, user_id, db + ) + + logger.info( + "storybook_request", + user_id=user_id, + keywords=request.keywords, + page_count=request.page_count, + profile_id=profile_id, + universe_id=universe_id, + ) + + # 2. Context + memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) + + # 3. Generate Text Structure + try: + storybook = await generate_storybook( + keywords=request.keywords, + page_count=request.page_count, + education_theme=request.education_theme, + memory_context=memory_context, + db=db, + ) + except Exception as e: + logger.error("storybook_generation_failed", error=str(e)) + raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}") + + # 4. Parallel Image Generation + final_cover_url = storybook.cover_url + if request.generate_images: + logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages)) + + tasks = [] + + # Cover Task + async def _gen_cover(): + if storybook.cover_prompt and not storybook.cover_url: + try: + return await generate_image(storybook.cover_prompt, db=db) + except Exception as e: + logger.warning("cover_gen_failed", error=str(e)) + return storybook.cover_url + tasks.append(_gen_cover()) + + # Page Tasks + async def _gen_page(page): + if page.image_prompt and not page.image_url: + try: + url = await generate_image(page.image_prompt, db=db) + page.image_url = url + except Exception as e: + logger.warning("page_gen_failed", page=page.page_number, error=str(e)) + + for page in storybook.pages: + tasks.append(_gen_page(page)) + + # Execute + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Update cover result + cover_res = results[0] + if isinstance(cover_res, str): + final_cover_url = cover_res + + logger.info("storybook_parallel_generation_complete") + + # 5. Save to DB + pages_data = [ + { + "page_number": p.page_number, + "text": p.text, + "image_prompt": p.image_prompt, + "image_url": p.image_url, + } + for p in storybook.pages + ] + + story = Story( + user_id=user_id, + child_profile_id=profile_id, + universe_id=universe_id, + title=storybook.title, + mode="storybook", + pages=pages_data, + story_text=None, + cover_prompt=storybook.cover_prompt, + image_url=final_cover_url, + ) + db.add(story) + await db.commit() + await db.refresh(story) + + if universe_id: + extract_story_achievements.delay(story.id, universe_id) + + # 6. Build Response + response_pages = [ + StorybookPageResponse( + page_number=p["page_number"], + text=p["text"], + image_prompt=p["image_prompt"], + image_url=p.get("image_url"), + ) + for p in pages_data + ] + + return StorybookResponse( + id=story.id, + title=storybook.title, + main_character=storybook.main_character, + art_style=storybook.art_style, + pages=response_pages, + cover_prompt=storybook.cover_prompt, + cover_url=final_cover_url, + ) + + +# ==================== Missing Endpoints Logic (for Issue #5) ==================== + +async def list_stories( + user_id: str, + limit: int, + offset: int, + db: AsyncSession, +) -> list[Story]: + """List stories for user.""" + result = await db.execute( + select(Story) + .where(Story.user_id == user_id) + .order_by(desc(Story.created_at)) + .offset(offset) + .limit(limit) + ) + return result.scalars().all() + + +async def get_story_detail( + story_id: int, + user_id: str, + db: AsyncSession, +) -> Story: + """Get story detail.""" + result = await db.execute( + select(Story).where(Story.id == story_id, Story.user_id == user_id) + ) + story = result.scalar_one_or_none() + if not story: + raise HTTPException(status_code=404, detail="Story not found") + return story + + +async def delete_story( + story_id: int, + user_id: str, + db: AsyncSession, +) -> None: + """Delete a story.""" + story = await get_story_detail(story_id, user_id, db) + await db.delete(story) + await db.commit() + + +async def create_story_from_result( + result, # StoryOutput + user_id: str, + profile_id: str | None, + universe_id: str | None, + db: AsyncSession, +) -> Story: + """Save a generated story to DB (helper for stream endpoint).""" + story = Story( + user_id=user_id, + child_profile_id=profile_id, + universe_id=universe_id, + title=result.title, + story_text=result.story_text, + cover_prompt=result.cover_prompt_suggestion, + mode=result.mode, + ) + db.add(story) + await db.commit() + await db.refresh(story) + + if universe_id: + extract_story_achievements.delay(story.id, universe_id) + + return story + + +async def generate_story_cover( + story_id: int, + user_id: str, + db: AsyncSession, +) -> str: + """Generate cover image for an existing story.""" + story = await get_story_detail(story_id, user_id, db) + + if not story.cover_prompt: + raise HTTPException(status_code=400, detail="Story has no cover prompt") + + try: + image_url = await generate_image(story.cover_prompt, db=db) + story.image_url = image_url + await db.commit() + return image_url + except Exception as e: + logger.error("cover_generation_failed", story_id=story_id, error=str(e)) + raise HTTPException(status_code=500, detail=f"Image generation failed: {e}") + + +async def generate_story_audio( + story_id: int, + user_id: str, + db: AsyncSession, +) -> bytes: + """Generate audio for a story.""" + story = await get_story_detail(story_id, user_id, db) + + if not story.story_text: + raise HTTPException(status_code=400, detail="Story has no text") + + # TODO: Check if audio is already cached/saved? + # For now, generate on the fly via provider + from app.services.provider_router import text_to_speech + + try: + audio_data = await text_to_speech(story.story_text, db=db) + return audio_data + except Exception as e: + logger.error("audio_generation_failed", story_id=story_id, error=str(e)) + raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}") + + +async def get_story_achievements( + story_id: int, + user_id: str, + db: AsyncSession, +) -> list[AchievementItem]: + """Get achievements unlocked by a specific story.""" + result = await db.execute( + select(Story) + .options(joinedload(Story.story_universe)) + .where(Story.id == story_id, Story.user_id == user_id) + ) + story = result.scalar_one_or_none() + + if not story: + raise HTTPException(status_code=404, detail="Story not found") + + if not story.universe_id or not story.story_universe: + return [] + + universe = story.story_universe + if not universe.achievements: + return [] + + results = [] + for ach in universe.achievements: + if isinstance(ach, dict) and ach.get("source_story_id") == story_id: + results.append(AchievementItem( + type=ach.get("type", "Unknown"), + description=ach.get("description", ""), + obtained_at=ach.get("obtained_at") + )) + return results + diff --git a/backend/docs/memory_system_prd.md b/backend/docs/memory_system_prd.md deleted file mode 100644 index 4076ad8..0000000 --- a/backend/docs/memory_system_prd.md +++ /dev/null @@ -1,93 +0,0 @@ -# 梦语织机 (DreamWeaver) 记忆系统升级 PRD -> 版本: v1.0 | 状态: 规划中 | 优先级: High - -## 1. 核心愿景 (Vision) - -将当前的"数据存储"升级为有温度的**"情感连接系统"**。 -我们不只是在记住数据,而是在**维护孩子与故事世界的关系**。让每一个故事不再是孤立的碎片,而是构建孩子专属"故事宇宙"的砖瓦。 - ---- - -## 2. 产品痛点与解决方案 - -| 用户角色 | 核心痛点 | 解决方案 | 预期价值 | -|---------|---------|---------|---------| -| **孩子** | "上次的小兔子怎么不认识我了?"
故事之间缺乏连续性,只有单次体验。 | **角色一致性与记忆注入**
故事开头主动提及往事,角色性格延续。 | 建立情感依恋,提升沉浸感。 | -| **家长** | "这App除了生成故事还能干嘛?"
无法感知产品的长期教育价值。 | **显性化成长轨迹**
词汇量统计、主题变化、成就徽章可视化。 | 提高付费意愿,提供社交货币。 | -| **平台** | 用户用完即走,缺乏留存壁垒。 | **沉没成本与情感资产**
积累的记忆越多,越舍不得离开。 | 提升长期留存率 (LTV)。 | - ---- - -## 3. 功能架构:记忆分层模型 - -### 3.1 层级 1: 核心档案 (Identity Layer) -*性质:永久、静态、显性* -* **数据**: 姓名、年龄、性别。 -* **输入**: 家长在 Onboarding 阶段手动输入。 -* **作用**: 决定故事的基础适龄性和称呼。 - -### 3.2 层级 2: 故事宇宙 (Universe Layer) -*性质:长期、动态积累、半显性* -* **主角设定**: 姓名、性格特征(勇敢/害羞)、外貌特征(戴眼镜/卷发)。 -* **常驻配角**: 从随机故事中涌现出的固定伙伴(如"爱吃胡萝卜的松鼠奇奇")。 -* **世界观**: 故事发生的背景(魔法森林、未来城市、海底世界)。 -* **成就系统**: 孩子获得的虚拟奖励(勇气勋章、小小探险家)。 - -### 3.3 层级 3: 工作记忆 (Working Memory) -*性质:短期、自动衰减、隐性* -* **关键情节**: 最近 3 个故事的结局和核心冲突。 -* **情感标记**: 孩子对特定内容的反应(根据“重播”、“跳过”推断)。 -* **新学词汇**: 故事中出现的高级词汇。 - ---- - -## 4. 关键功能特性 (Feature Specs) - -### 4.1 智能开场白 (Memory Injection) -在生成新故事时,Prompt 必须包含一段"记忆唤醒"指令。 -* **示例**: "小明,还记得上周我们帮小松鼠找回了松果吗?今天,小松鼠带来了一位新朋友..." -* **策略**: 提取权重最高的 Top 3 记忆注入 Prompt。 - -### 4.2 成长时间轴 (Growth Timeline) -一个可视化的 H5 页面或 App 模块,以时间轴形式展示里程碑。 -* **节点类型**: - * 🌟 **初次相遇**: 创建角色的第一天。 - * 📖 **阅读打卡**: 累计阅读 10/50/100 本。 - * 🏅 **获得成就**: 获得"诚实勋章"。 - * 🧠 **能力解锁**: 第一次阅读"科幻"题材。 - -### 4.3 成就仪式感 (Achievement Ceremony) -* **触发**: 故事生成并分析后,如果获得新成就。 -* **表现**: 弹窗动画 + 音效 + "恭喜获得 [勇气] 徽章"。 -* **分享**:允许生成带二维码的成就海报。 - ---- - -## 5. 记忆类型扩展 (Memory Types) - -| 类型 Key | 描述 | 来源 | 过期策略 | -|---------|------|------|---------| -| `recent_story` | 最近读过的故事梗概 | 阅读事件 | 30天衰减 | -| `favorite_character` | 孩子喜欢的角色 | 重播/高评分 | 长期有效 | -| `scary_element` | 孩子害怕/不喜欢的元素 | 跳过/负反馈 | 长期有效 (避雷) | -| `vocabulary_growth` | 新掌握的词汇 | 故事分析 | 90天衰减 | -| `emotional_highlight` | 高光时刻 (如: 特别开心的情节) | 互动数据 | 60天衰减 | - ---- - -## 6. 实施路线图 (Roadmap) - -### Phase 1: 基础建设 (v0.3.0) -* [x] 数据库 `MemoryItem` 表 (已存在)。 -* [ ] 扩展 `MemoryItem` 类型字段,支持更多维度。 -* [ ] 优化 `_build_memory_context`,支持更自然的 Prompt 注入。 -* [ ] 前端:简单的"近期回忆"展示列表。 - -### Phase 2: 可视化与成就 (v0.4.0) -* [ ] 实现"成就提取器" (Achievement Extractor) 的闭环通知。 -* [ ] 前端:开发"我的成就"和"成长时间轴"页面。 -* [ ] 增加故事开场白的动态生成逻辑。 - -### Phase 3: 深度智能 (v0.5.0+) -* [ ] 引入向量数据库,实现基于语义的记忆检索 (不仅是时间最近)。 -* [ ] 情感分析模型:分析用户行为推断情感倾向。 diff --git a/backend/docs/refactoring_plan.md b/backend/docs/refactoring_plan.md new file mode 100644 index 0000000..73a59f2 --- /dev/null +++ b/backend/docs/refactoring_plan.md @@ -0,0 +1,98 @@ +# DreamWeaver 重构实施计划 + +## 1. 概述 + +本文档基于对当前架构的深入分析,制定了从稳定性、可维护性到可扩展性的分阶段重构计划。 + +**目标**: +- **短期**:解决单点故障风险,优化开发体验,清理关键技术债。 +- **中期**:提升系统高可用能力,增强监控与可观测性。 +- **长期**:架构演进,支持大规模并发与复杂业务场景。 + +--- + +## 2. 短期优化计划 (1-2周) + +**重点**:消除即时风险,提升部署效率。 + +### 2.1 统一镜像构建 (High Priority) +目前 `backend`, `backend-admin`, `worker`, `celery-beat` 重复构建 4 次,浪费资源且镜像版本可能不一致。 + +- **Action Items**: + - [ ] 修改 `backend/Dockerfile` 为通用基础镜像。 + - [ ] 更新 `docker-compose.yml`,定义 `backend-base` 服务或使用 `image` 标签共享镜像。 + - [ ] 确保所有 Python 服务共用同一构建产物,仅启动命令不同。 + +### 2.2 修复 Provider 缓存与限流 (High Priority) +内存缓存 (`TTLCache`, `_latency_cache`) 在多进程/多实例下失效。 + +- **Action Items**: + - [ ] 引入 Redis 作为共享缓存后端。 + - [ ] 重构 `_load_provider_cache`,将 Provider 配置缓存至 Redis。 + - [ ] 重构 `stories.py` 中的限流逻辑,使用 `redis-cell` 或简单的 Redis 计数器替代 `TTLCache`。 + +### 2.3 拆分 `stories.py` (Medium Priority) +`app/api/stories.py` 超过 600 行,包含 API 定义、业务逻辑、验证逻辑,维护困难。 + +- **Action Items**: + - [ ] 创建 `app/services/story_service.py`,迁移生成、润色、PDF生成等核心逻辑。 + - [ ] 创建 `app/schemas/story_schema.py`,迁移 Pydantic 模型(`GenerateRequest`, `StoryResponse` 等)。 + - [ ] API 层 `stories.py` 仅保留路由定义和依赖注入,调用 Service 层。 + +--- + +## 3. 中期优化计划 (1-2月) + +**重点**:高可用 (HA) 与系统韧性。 + +### 3.1 数据库高可用 (Critical) +当前 PostgreSQL 为单点,且 Admin/User 混合使用。 + +- **Action Items**: + - [ ] 部署 PostgreSQL 主从复制 (Master-Slave)。 + - [ ] 配置 `PgBouncer` 或 SQLAlchemy 读写分离,减轻主库压力。 + - [ ] 实施数据库自动备份策略 (如 `pg_dump` 定时上传 S3)。 + +### 3.2 消息队列高可用 (Critical) +Redis 单点故障将导致 Celery 任务全盘停摆。 + +- **Action Items**: + - [ ] 迁移至 Redis Sentinel 或 Redis Cluster 模式。 + - [ ] 更新 Celery 配置以支持 Sentinel/Cluster 连接串。 + +### 3.3 增强可观测性 (Important) +目前仅有简单的日志,缺乏系统级指标。 + +- **Action Items**: + - [ ] 集成 Prometheus Client,暴露 `/metrics` 端点。 + - [ ] 部署 Grafana + Prometheus,监控 API 延迟、QPS、Celery 队列积压情况。 + - [ ] 完善 `ProviderMetrics`,增加可视化大盘,实时监控 AI 供应商的成本与成功率。 + +--- + +## 4. 长期架构演进 (季度规划) + +**重点**:业务解耦与规模化。 + +### 4.1 统一 API 网关 +- **当前**:前端直连后端端口,CORS 配置分散。 +- **演进**:引入 Traefik 或 Nginx 作为统一网关,管理路由、SSL、全局限流、统一鉴权。 + +### 4.2 前端工程合并 +- **当前**:User App 和 Admin Console 是完全独立的两个项目,但在组件和工具链上高度重复。 +- **演进**:使用一种 Monorepo 策略或基于路由的单一应用策略,共享组件库和类型定义,减少维护成本。 + +### 4.3 事件驱动架构完善 +- **当前**:部分业务逻辑耦合在 API 中。 +- **演进**:扩展事件总线,将“阅读记录”、“成就解锁”、“通知推送”等非核心链路完全异步化,通过 Domain Events 解耦。 + +--- + +## 5. 实施路线图 + +| 阶段 | 时间估算 | 关键里程碑 | +| :--- | :--- | :--- | +| **Phase 1: 基础夯实** | Week 1-2 | Docker 构建优化上线,Redis 替代内存缓存。 | +| **Phase 2: 代码重构** | Week 3-4 | `stories.py` 拆分完成,Service 层建立。 | +| **Phase 3: 高可用建设** | Month 2 | 数据库与 Redis 实现主备/集群模式。 | +| **Phase 4: 监控体系** | Month 2 | Grafana 监控大盘上线,关键指标报警配置完毕。 | diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index a048af1..19786a0 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -121,7 +121,7 @@ def mock_text_provider(): cover_prompt_suggestion="A cute rabbit", ) - with patch("app.api.stories.generate_story_content", new_callable=AsyncMock) as mock: + with patch("app.services.story_service.generate_story_content", new_callable=AsyncMock) as mock: mock.return_value = mock_result yield mock @@ -129,7 +129,7 @@ def mock_text_provider(): @pytest.fixture def mock_image_provider(): """Mock 图像生成。""" - with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock: + with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock: mock.return_value = "https://example.com/image.png" yield mock @@ -137,7 +137,7 @@ def mock_image_provider(): @pytest.fixture def mock_tts_provider(): """Mock TTS。""" - with patch("app.api.stories.text_to_speech", new_callable=AsyncMock) as mock: + with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock: mock.return_value = b"fake-audio-bytes" yield mock diff --git a/backend/tests/test_stories.py b/backend/tests/test_stories.py index b34b840..121f3fd 100644 --- a/backend/tests/test_stories.py +++ b/backend/tests/test_stories.py @@ -57,7 +57,6 @@ class TestStoryGenerate: assert data["mode"] == "generated" -@pytest.mark.skip(reason="GET /api/stories (列表) 端点尚未实现") class TestStoryList: """故事列表测试。""" @@ -94,7 +93,6 @@ class TestStoryList: assert len(data) == 0 -@pytest.mark.skip(reason="GET /api/stories/{id} (详情) 端点尚未实现") class TestStoryDetail: """故事详情测试。""" @@ -118,7 +116,6 @@ class TestStoryDetail: assert data["story_text"] == test_story.story_text -@pytest.mark.skip(reason="DELETE /api/stories/{id} (删除) 端点尚未实现") class TestStoryDelete: """故事删除测试。""" @@ -168,7 +165,6 @@ class TestRateLimit: assert "Too many requests" in response.json()["detail"] -@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现") class TestImageGenerate: """封面图片生成测试。""" @@ -183,7 +179,6 @@ class TestImageGenerate: assert response.status_code == 404 -@pytest.mark.skip(reason="GET /api/audio/{id} 端点尚未实现") class TestAudio: """语音朗读测试。""" @@ -233,7 +228,7 @@ class TestGenerateFull: def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider): """图片生成失败时返回部分成功。""" - with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock_img: + with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_img: mock_img.side_effect = Exception("Image API error") response = auth_client.post( "/api/stories/generate/full", @@ -261,7 +256,6 @@ class TestGenerateFull: assert call_kwargs["education_theme"] == "勇气与友谊" -@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现") class TestImageGenerateSuccess: """封面图片生成成功测试。"""