Refactor Phase 2: Split stories.py into Schema/Service/Controller, add missing endpoints, fix async bug

This commit is contained in:
zhangtuo
2026-02-10 17:14:54 +08:00
parent c351d16d3e
commit 9cdff18336
13 changed files with 881 additions and 612 deletions

View File

@@ -2,136 +2,41 @@
import asyncio
import json
import time
import uuid
from typing import AsyncGenerator, Literal
from typing import AsyncGenerator
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sse_starlette.sse import EventSourceResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.core.logging import get_logger
from app.db.database import get_db
from app.db.models import ChildProfile, Story, StoryUniverse, User
from app.services.provider_router import (
generate_image,
generate_story_content,
generate_storybook,
text_to_speech,
)
from app.core.rate_limiter import check_rate_limit
from app.tasks.achievements import extract_story_achievements
from app.db.database import get_db
from app.db.models import User
from app.schemas.story_schemas import (
GenerateRequest,
StoryResponse,
FullStoryResponse,
StorybookRequest,
StorybookResponse,
StoryListItem,
AchievementItem,
)
from app.services import story_service
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
generate_story_content,
generate_image,
)
logger = get_logger(__name__)
router = APIRouter()
MAX_DATA_LENGTH = 2000
MAX_EDU_THEME_LENGTH = 200
MAX_TTS_LENGTH = 4000
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10
class GenerateRequest(BaseModel):
"""Story generation request."""
type: Literal["keywords", "full_story"]
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
child_profile_id: str | None = None
universe_id: str | None = None
class StoryResponse(BaseModel):
"""Story response."""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
mode: str
child_profile_id: str | None = None
universe_id: str | None = None
class StoryListItem(BaseModel):
"""Story list item."""
id: int
title: str
image_url: str | None
created_at: str
mode: str
class FullStoryResponse(BaseModel):
"""完整故事响应(含图片和音频状态)。"""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
audio_ready: bool
mode: str
errors: dict[str, str | None] = Field(default_factory=dict)
child_profile_id: str | None = None
universe_id: str | None = None
from app.services.memory_service import build_enhanced_memory_context
async def _validate_profile_and_universe(
request: GenerateRequest,
user: User,
db: AsyncSession,
) -> tuple[str | None, str | None]:
if not request.child_profile_id and not request.universe_id:
return None, None
profile_id = request.child_profile_id
universe_id = request.universe_id
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
return profile_id, universe_id
@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
request: GenerateRequest,
@@ -140,47 +45,7 @@ async def generate_story(
):
"""Generate or enhance a story."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return StoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=story.image_url,
mode=story.mode,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
return await story_service.generate_and_save_story(request, user.id, db)
@router.post("/stories/generate/full", response_model=FullStoryResponse)
@@ -189,71 +54,9 @@ async def generate_story_full(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""生成完整故事(故事 + 并行生成图片和音频)。
部分成功策略:故事必须成功,图片/音频失败不影响整体。
"""
"""Generate complete story (story + parallel image/audio generation)."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# Step 1: 故事生成(必须成功)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except Exception as exc:
logger.error("story_generation_failed", error=str(exc))
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
# 保存故事
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# Step 2: 生成封面图片(音频按需生成,避免浪费)
errors: dict[str, str | None] = {}
image_url: str | None = None
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt)
story.image_url = image_url
await db.commit()
except Exception as exc:
errors["image"] = str(exc)
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
# 注意:音频不在此处预生成,用户通过 /api/audio/{id} 按需获取
# 这样避免生成后丢弃造成的成本浪费
return FullStoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=image_url,
audio_ready=False, # 音频需要用户主动请求
mode=story.mode,
errors=errors,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
return await story_service.generate_full_story_service(request, user.id, db)
@router.post("/stories/generate/stream")
@@ -263,53 +66,39 @@ async def generate_story_stream(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""流式生成故事SSE
事件流程:
- started: 返回 story_id
- story_ready: 返回 title, content
- story_failed: 返回 error
- image_ready: 返回 image_url
- image_failed: 返回 error
- complete: 结束流
"""
"""流式生成故事SSE"""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
# Validation
profile_id, universe_id = await story_service.validate_profile_and_universe(
request.child_profile_id, request.universe_id, user.id, db
)
# Build Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
async def event_generator() -> AsyncGenerator[dict, None]:
story_id = str(uuid.uuid4())
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
# Step 1: 生成故事
# Step 1: Generate Content
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("sse_story_generation_failed", error=str(e))
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
return
# 保存故事
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
# Save Story
story = await story_service.create_story_from_result(
result, user.id, profile_id, universe_id, db
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
yield {
"event": "story_ready",
@@ -324,10 +113,11 @@ async def generate_story_stream(
}),
}
# Step 2: 并行生成图片(音频按需)
# Step 2: Generate Image
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt)
# Direct call to provider router's generate_image, sharing db session
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
await db.commit()
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
@@ -340,217 +130,71 @@ async def generate_story_stream(
return EventSourceResponse(event_generator())
# ==================== Storybook API ====================
class StorybookRequest(BaseModel):
"""Storybook 生成请求。"""
keywords: str = Field(..., min_length=1, max_length=200)
page_count: int = Field(default=6, ge=4, le=12)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
generate_images: bool = Field(default=False, description="是否同时生成插图")
child_profile_id: str | None = None
universe_id: str | None = None
class StorybookPageResponse(BaseModel):
"""故事书单页响应。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
class StorybookResponse(BaseModel):
"""故事书响应。"""
id: int | None = None
title: str
main_character: str
art_style: str
pages: list[StorybookPageResponse]
cover_prompt: str
cover_url: str | None = None
@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
request: StorybookRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""生成分页故事书并保存。
返回故事书结构,包含每页文字和图像提示词。
"""
"""Generate storybook."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
# 验证档案和宇宙
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数
# 这里我们直接手动验证,或重构验证函数。为了简单,手动调用部分逻辑。
# 构建临时的 GenerateRequest 用于验证验证函数签名(或者直接手动查库更好)
profile_id = request.child_profile_id
universe_id = request.universe_id
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
if not result.scalar_one_or_none():
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
logger.info(
"storybook_request",
user_id=user.id,
keywords=request.keywords,
page_count=request.page_count,
profile_id=profile_id,
universe_id=universe_id,
)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
try:
# 注意generate_storybook 目前可能不支持记忆上下文注入
# 我们需要看看 generate_storybook 的签名
# 如果不支持,记忆功能在绘本模式下暂不可用,但基本参数传递是支持的
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("storybook_generation_failed", error=str(e))
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
# ==============================================================================
# 核心升级: 并行全量生成 (Parallel Full Rendering)
# ==============================================================================
final_cover_url = storybook.cover_url
if request.generate_images:
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
# 1. 准备所有生图任务 (封面 + 所有内页)
tasks = []
# 封面任务
async def _gen_cover():
if storybook.cover_prompt and not storybook.cover_url:
try:
return await generate_image(storybook.cover_prompt, db=db)
except Exception as e:
logger.warning("cover_gen_failed", error=str(e))
return storybook.cover_url
tasks.append(_gen_cover())
# 内页任务
async def _gen_page(page):
if page.image_prompt and not page.image_url:
try:
url = await generate_image(page.image_prompt, db=db)
page.image_url = url
except Exception as e:
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
for page in storybook.pages:
tasks.append(_gen_page(page))
# 2. 并发执行所有任务
# 使用 return_exceptions=True 防止单张失败影响整体
results = await asyncio.gather(*tasks, return_exceptions=True)
# 3. 更新封面结果 (results[0] 是封面任务的返回值)
cover_res = results[0]
if isinstance(cover_res, str):
final_cover_url = cover_res
logger.info("storybook_parallel_generation_complete")
# ==============================================================================
# 构建并保存 Story 对象
# 将 pages 对象转换为字典列表以存入 JSON 字段
pages_data = [
{
"page_number": p.page_number,
"text": p.text,
"image_prompt": p.image_prompt,
"image_url": p.image_url,
}
for p in storybook.pages
]
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=storybook.title,
mode="storybook",
pages=pages_data, # 存入 JSON 字段
story_text=None, # 绘本模式下,主文本可为空,或者可以存个摘要
cover_prompt=storybook.cover_prompt,
image_url=final_cover_url,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# 构建响应 (使用更新后的 pages_data)
response_pages = [
StorybookPageResponse(
page_number=p["page_number"],
text=p["text"],
image_prompt=p["image_prompt"],
image_url=p.get("image_url"),
)
for p in pages_data
]
return StorybookResponse(
id=story.id,
title=storybook.title,
main_character=storybook.main_character,
art_style=storybook.art_style,
pages=response_pages,
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
)
return await story_service.generate_storybook_service(request, user.id, db)
class AchievementItem(BaseModel):
type: str
description: str
obtained_at: str | None = None
# ==================== Missing Endpoints (Issue #5) ====================
@router.get("/stories", response_model=list[StoryListItem])
async def list_stories(
limit: int = 20,
offset: int = 0,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List stories."""
return await story_service.list_stories(user.id, limit, offset, db)
@router.get("/stories/{story_id}", response_model=StoryResponse)
async def get_story(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get story detail."""
return await story_service.get_story_detail(story_id, user.id, db)
@router.delete("/stories/{story_id}")
async def delete_story(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Delete story."""
await story_service.delete_story(story_id, user.id, db)
return {"message": "Deleted"}
@router.post("/image/generate/{story_id}")
async def generate_story_image(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate cover image for story."""
url = await story_service.generate_story_cover(story_id, user.id, db)
return {"image_url": url}
@router.get("/audio/{story_id}")
async def get_story_audio(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get story audio (MP3)."""
audio_bytes = await story_service.generate_story_audio(story_id, user.id, db)
return Response(content=audio_bytes, media_type="audio/mpeg")
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
@@ -559,32 +203,5 @@ async def get_story_achievements(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get achievements unlocked by a specific story."""
# 使用 joinedload 避免 N+1 查询
result = await db.execute(
select(Story)
.options(joinedload(Story.story_universe))
.where(Story.id == story_id, Story.user_id == user.id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
if not story.universe_id or not story.story_universe:
return []
universe = story.story_universe
if not universe.achievements:
return []
results = []
for ach in universe.achievements:
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
results.append(AchievementItem(
type=ach.get("type", "Unknown"),
description=ach.get("description", ""),
obtained_at=ach.get("obtained_at")
))
return results
"""Get story achievements."""
return await story_service.get_story_achievements(story_id, user.id, db)

View File

@@ -52,6 +52,9 @@ class Settings(BaseSettings):
# Celery (Redis)
celery_broker_url: str = Field("redis://localhost:6379/0")
celery_result_backend: str = Field("redis://localhost:6379/0")
# Generic Redis
redis_url: str = Field("redis://localhost:6379/0", description="Redis connection URL")
# Admin console
enable_admin_console: bool = False

25
backend/app/core/redis.py Normal file
View File

@@ -0,0 +1,25 @@
"""Redis client module."""
from typing import AsyncGenerator
from redis.asyncio import Redis, from_url
from app.core.config import settings
_redis_pool: Redis | None = None
async def get_redis() -> Redis:
"""Get global Redis client instance."""
global _redis_pool
if _redis_pool is None:
_redis_pool = from_url(settings.redis_url, encoding="utf-8", decode_responses=True)
return _redis_pool
async def close_redis():
"""Close Redis connection."""
global _redis_pool
if _redis_pool:
await _redis_pool.close()
_redis_pool = None

View File

@@ -0,0 +1 @@
"""故事相关 Schema 模块。"""

View File

@@ -0,0 +1,106 @@
"""故事相关 Pydantic 模型。"""
from datetime import datetime
from typing import Literal
from pydantic import BaseModel, Field
MAX_DATA_LENGTH = 2000
MAX_EDU_THEME_LENGTH = 200
MAX_TTS_LENGTH = 4000
# ==================== 故事模型 ====================
class GenerateRequest(BaseModel):
"""Story generation request."""
type: Literal["keywords", "full_story"]
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
child_profile_id: str | None = None
universe_id: str | None = None
class StoryResponse(BaseModel):
"""Story response."""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
mode: str
child_profile_id: str | None = None
universe_id: str | None = None
class StoryListItem(BaseModel):
"""Story list item."""
id: int
title: str
image_url: str | None
created_at: datetime
mode: str
class FullStoryResponse(BaseModel):
"""完整故事响应(含图片和音频状态)。"""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
audio_ready: bool
mode: str
errors: dict[str, str | None] = Field(default_factory=dict)
child_profile_id: str | None = None
universe_id: str | None = None
# ==================== 绘本模型 ====================
class StorybookRequest(BaseModel):
"""Storybook 生成请求。"""
keywords: str = Field(..., min_length=1, max_length=200)
page_count: int = Field(default=6, ge=4, le=12)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
generate_images: bool = Field(default=False, description="是否同时生成插图")
child_profile_id: str | None = None
universe_id: str | None = None
class StorybookPageResponse(BaseModel):
"""故事书单页响应。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
class StorybookResponse(BaseModel):
"""故事书响应。"""
id: int | None = None
title: str
main_character: str
art_style: str
pages: list[StorybookPageResponse]
cover_prompt: str
cover_url: str | None = None
# ==================== 成就模型 ====================
class AchievementItem(BaseModel):
type: str
description: str
obtained_at: str | None = None

View File

@@ -1,31 +1,109 @@
"""In-memory cache for providers loaded from DB."""
"""Redis-backed cache for providers loaded from DB."""
import json
from collections import defaultdict
from typing import Literal
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.db.admin_models import Provider
logger = get_logger(__name__)
ProviderType = Literal["text", "image", "tts", "storybook"]
_cache: dict[ProviderType, list[Provider]] = defaultdict(list)
class CachedProvider(BaseModel):
"""Serializable provider configuration matching DB model fields."""
id: str
name: str
type: str
adapter: str
model: str | None = None
api_base: str | None = None
api_key: str | None = None
timeout_ms: int = 60000
max_retries: int = 1
weight: int = 1
priority: int = 0
enabled: bool = True
config_json: dict | None = None
config_ref: str | None = None
async def reload_providers(db: AsyncSession):
result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712
providers = result.scalars().all()
grouped: dict[ProviderType, list[Provider]] = defaultdict(list)
for p in providers:
grouped[p.type].append(p)
# sort by priority desc, then weight desc
for k in grouped:
grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
_cache.clear()
_cache.update(grouped)
return _cache
# Local memory fallback (L1 cache)
_local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list)
CACHE_KEY = "dreamweaver:providers:config"
def get_providers(provider_type: ProviderType) -> list[Provider]:
return _cache.get(provider_type, [])
async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]:
"""Reload providers from DB and update Redis cache."""
try:
result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712
providers = result.scalars().all()
# Convert to Pydantic models
cached_list = []
for p in providers:
cached_list.append(CachedProvider(
id=p.id,
name=p.name,
type=p.type,
adapter=p.adapter,
model=p.model,
api_base=p.api_base,
api_key=p.api_key,
timeout_ms=p.timeout_ms,
max_retries=p.max_retries,
weight=p.weight,
priority=p.priority,
enabled=p.enabled,
config_json=p.config_json,
config_ref=p.config_ref
))
# Group by type
grouped: dict[str, list[CachedProvider]] = defaultdict(list)
for cp in cached_list:
grouped[cp.type].append(cp)
# Sort
for k in grouped:
grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
# Update Redis
redis = await get_redis()
# Serialize entire dict structure
# Pydantic -> dict -> json
json_data = {k: [p.model_dump() for p in v] for k, v in grouped.items()}
await redis.set(CACHE_KEY, json.dumps(json_data))
# Update local cache
_local_cache.clear()
_local_cache.update(grouped)
return grouped
except Exception as e:
logger.error("failed_to_reload_providers", error=str(e))
raise
async def get_providers(provider_type: ProviderType) -> list[CachedProvider]:
"""Get providers from Redis (preferred) or local fallback."""
try:
redis = await get_redis()
data = await redis.get(CACHE_KEY)
if data:
raw_dict = json.loads(data)
if provider_type in raw_dict:
return [CachedProvider(**item) for item in raw_dict[provider_type]]
return []
except Exception as e:
logger.warning("redis_cache_read_failed", error=str(e))
# Fallback to local memory
return _local_cache.get(provider_type, [])

View File

@@ -177,7 +177,7 @@ def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
)
def _get_providers_with_config(
async def _get_providers_with_config(
provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
"""获取供应商列表及其配置。
@@ -185,7 +185,7 @@ def _get_providers_with_config(
Returns:
[(adapter_name, config, provider_or_none), ...] 按优先级排序
"""
db_providers = get_providers(provider_type)
db_providers = await get_providers(provider_type)
if db_providers:
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
@@ -265,7 +265,7 @@ async def _route_with_failover(
user_id: 用户 ID可选用于成本追踪和预算检查
**kwargs: 传递给适配器的参数
"""
providers = _get_providers_with_config(provider_type)
providers = await _get_providers_with_config(provider_type)
if not providers:
raise ValueError(f"No {provider_type} providers configured.")

View File

@@ -0,0 +1,434 @@
"""Story business logic service."""
import asyncio
import json
import uuid
from typing import Literal
from fastapi import HTTPException
from sqlalchemy import select, desc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.logging import get_logger
from app.db.models import ChildProfile, Story, StoryUniverse
from app.schemas.story_schemas import (
GenerateRequest,
StorybookRequest,
FullStoryResponse,
StorybookResponse,
StorybookPageResponse,
AchievementItem,
)
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
generate_story_content,
generate_image,
generate_storybook,
)
from app.tasks.achievements import extract_story_achievements
logger = get_logger(__name__)
async def validate_profile_and_universe(
profile_id: str | None,
universe_id: str | None,
user_id: str,
db: AsyncSession,
) -> tuple[str | None, str | None]:
"""Validate child profile and universe ownership/relationship."""
if not profile_id and not universe_id:
return None, None
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user_id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user_id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
return profile_id, universe_id
async def generate_and_save_story(
request: GenerateRequest,
user_id: str,
db: AsyncSession,
) -> Story:
"""Generate generic story content and save to DB."""
# 1. Validate
profile_id, universe_id = await validate_profile_and_universe(
request.child_profile_id, request.universe_id, user_id, db
)
# 2. Build Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# 3. Generate
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as exc:
raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc
# 4. Save
story = Story(
user_id=user_id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
# 5. Trigger Async Tasks
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return story
async def generate_full_story_service(
request: GenerateRequest,
user_id: str,
db: AsyncSession,
) -> FullStoryResponse:
"""Generate story with parallel image generation."""
# 1. Generate text part
# We can reuse logic or call generate_story_content directly if we want finer control
# reusing generate_and_save_story to ensure consistency (it handles validation + saving)
story = await generate_and_save_story(request, user_id, db)
# 2. Generate Image (Parallel/Async step in this flow)
image_url: str | None = None
errors: dict[str, str | None] = {}
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
await db.commit()
except Exception as exc:
errors["image"] = str(exc)
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
return FullStoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=image_url,
audio_ready=False,
mode=story.mode,
errors=errors,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
async def generate_storybook_service(
request: StorybookRequest,
user_id: str,
db: AsyncSession,
) -> StorybookResponse:
"""Generate storybook with parallel image generation for pages."""
# 1. Validate
profile_id, universe_id = await validate_profile_and_universe(
request.child_profile_id, request.universe_id, user_id, db
)
logger.info(
"storybook_request",
user_id=user_id,
keywords=request.keywords,
page_count=request.page_count,
profile_id=profile_id,
universe_id=universe_id,
)
# 2. Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# 3. Generate Text Structure
try:
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("storybook_generation_failed", error=str(e))
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
# 4. Parallel Image Generation
final_cover_url = storybook.cover_url
if request.generate_images:
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
tasks = []
# Cover Task
async def _gen_cover():
if storybook.cover_prompt and not storybook.cover_url:
try:
return await generate_image(storybook.cover_prompt, db=db)
except Exception as e:
logger.warning("cover_gen_failed", error=str(e))
return storybook.cover_url
tasks.append(_gen_cover())
# Page Tasks
async def _gen_page(page):
if page.image_prompt and not page.image_url:
try:
url = await generate_image(page.image_prompt, db=db)
page.image_url = url
except Exception as e:
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
for page in storybook.pages:
tasks.append(_gen_page(page))
# Execute
results = await asyncio.gather(*tasks, return_exceptions=True)
# Update cover result
cover_res = results[0]
if isinstance(cover_res, str):
final_cover_url = cover_res
logger.info("storybook_parallel_generation_complete")
# 5. Save to DB
pages_data = [
{
"page_number": p.page_number,
"text": p.text,
"image_prompt": p.image_prompt,
"image_url": p.image_url,
}
for p in storybook.pages
]
story = Story(
user_id=user_id,
child_profile_id=profile_id,
universe_id=universe_id,
title=storybook.title,
mode="storybook",
pages=pages_data,
story_text=None,
cover_prompt=storybook.cover_prompt,
image_url=final_cover_url,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# 6. Build Response
response_pages = [
StorybookPageResponse(
page_number=p["page_number"],
text=p["text"],
image_prompt=p["image_prompt"],
image_url=p.get("image_url"),
)
for p in pages_data
]
return StorybookResponse(
id=story.id,
title=storybook.title,
main_character=storybook.main_character,
art_style=storybook.art_style,
pages=response_pages,
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
)
# ==================== Missing Endpoints Logic (for Issue #5) ====================
async def list_stories(
user_id: str,
limit: int,
offset: int,
db: AsyncSession,
) -> list[Story]:
"""List stories for user."""
result = await db.execute(
select(Story)
.where(Story.user_id == user_id)
.order_by(desc(Story.created_at))
.offset(offset)
.limit(limit)
)
return result.scalars().all()
async def get_story_detail(
story_id: int,
user_id: str,
db: AsyncSession,
) -> Story:
"""Get story detail."""
result = await db.execute(
select(Story).where(Story.id == story_id, Story.user_id == user_id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
return story
async def delete_story(
story_id: int,
user_id: str,
db: AsyncSession,
) -> None:
"""Delete a story."""
story = await get_story_detail(story_id, user_id, db)
await db.delete(story)
await db.commit()
async def create_story_from_result(
result, # StoryOutput
user_id: str,
profile_id: str | None,
universe_id: str | None,
db: AsyncSession,
) -> Story:
"""Save a generated story to DB (helper for stream endpoint)."""
story = Story(
user_id=user_id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return story
async def generate_story_cover(
story_id: int,
user_id: str,
db: AsyncSession,
) -> str:
"""Generate cover image for an existing story."""
story = await get_story_detail(story_id, user_id, db)
if not story.cover_prompt:
raise HTTPException(status_code=400, detail="Story has no cover prompt")
try:
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
await db.commit()
return image_url
except Exception as e:
logger.error("cover_generation_failed", story_id=story_id, error=str(e))
raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")
async def generate_story_audio(
story_id: int,
user_id: str,
db: AsyncSession,
) -> bytes:
"""Generate audio for a story."""
story = await get_story_detail(story_id, user_id, db)
if not story.story_text:
raise HTTPException(status_code=400, detail="Story has no text")
# TODO: Check if audio is already cached/saved?
# For now, generate on the fly via provider
from app.services.provider_router import text_to_speech
try:
audio_data = await text_to_speech(story.story_text, db=db)
return audio_data
except Exception as e:
logger.error("audio_generation_failed", story_id=story_id, error=str(e))
raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
async def get_story_achievements(
story_id: int,
user_id: str,
db: AsyncSession,
) -> list[AchievementItem]:
"""Get achievements unlocked by a specific story."""
result = await db.execute(
select(Story)
.options(joinedload(Story.story_universe))
.where(Story.id == story_id, Story.user_id == user_id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
if not story.universe_id or not story.story_universe:
return []
universe = story.story_universe
if not universe.achievements:
return []
results = []
for ach in universe.achievements:
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
results.append(AchievementItem(
type=ach.get("type", "Unknown"),
description=ach.get("description", ""),
obtained_at=ach.get("obtained_at")
))
return results