Initial commit: clean project structure

- Backend: FastAPI + SQLAlchemy + Celery (Python 3.11+)
- Frontend: Vue 3 + TypeScript + Pinia + Tailwind
- Admin Frontend: separate Vue 3 app for management
- Docker Compose: 9 services orchestration
- Specs: design prototypes, memory system PRD, product roadmap

Cleanup performed:
- Removed temporary debug scripts from backend root
- Removed deprecated admin_app.py (embedded UI)
- Removed duplicate docs from admin-frontend
- Updated .gitignore for Vite cache and egg-info
This commit is contained in:
zhangtuo
2026-01-20 18:20:03 +08:00
commit e9d7f8832a
241 changed files with 33070 additions and 0 deletions

View File

View File

@@ -0,0 +1,307 @@
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.admin_auth import admin_guard
from app.db.admin_models import Provider
from app.db.database import get_db
from app.services.cost_tracker import cost_tracker
from app.services.secret_service import SecretService
router = APIRouter(dependencies=[Depends(admin_guard)])
class ProviderCreate(BaseModel):
name: str
type: str = Field(..., pattern="^(text|image|tts|storybook)$")
adapter: str
model: str | None = None
api_base: str | None = None
api_key: str | None = None # 可选,优先于 config_ref
timeout_ms: int = 60000
max_retries: int = 1
weight: int = 1
priority: int = 0
enabled: bool = True
config_json: dict | None = None
config_ref: str | None = None # 环境变量 key 名称(回退)
updated_by: str | None = None
class ProviderUpdate(ProviderCreate):
enabled: bool | None = None
api_key: str | None = None
config_json: dict | None = None
class ProviderResponse(BaseModel):
"""Provider 响应模型,隐藏敏感字段。"""
id: str
name: str
type: str
adapter: str
model: str | None = None
api_base: str | None = None
has_api_key: bool = False # 仅标识是否配置了 api_key不返回明文
timeout_ms: int = 60000
max_retries: int = 1
weight: int = 1
priority: int = 0
enabled: bool = True
config_ref: str | None = None
model_config = ConfigDict(from_attributes=True)
from app.services.adapters.registry import AdapterRegistry
from app.services.provider_router import DEFAULT_PROVIDERS
@router.get("/providers/adapters")
async def list_available_adapters():
"""获取所有可用的适配器类型 (定义的类)。"""
return AdapterRegistry.list_adapters()
@router.get("/providers/defaults")
async def get_env_defaults():
"""获取当前环境变量定义的默认策略 (Read-Only)。"""
return DEFAULT_PROVIDERS
@router.get("/providers", response_model=list[ProviderResponse])
async def list_providers(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Provider))
providers = result.scalars().all()
# 转换为响应模型,隐藏 api_key 明文
return [
ProviderResponse(
id=p.id,
name=p.name,
type=p.type,
adapter=p.adapter,
model=p.model,
api_base=p.api_base,
has_api_key=bool(p.api_key), # 仅标识是否有 key
timeout_ms=p.timeout_ms,
max_retries=p.max_retries,
weight=p.weight,
priority=p.priority,
enabled=p.enabled,
config_ref=p.config_ref,
)
for p in providers
]
def _to_response(provider: Provider) -> ProviderResponse:
"""将 Provider 转换为响应模型,隐藏敏感字段。"""
return ProviderResponse(
id=provider.id,
name=provider.name,
type=provider.type,
adapter=provider.adapter,
model=provider.model,
api_base=provider.api_base,
has_api_key=bool(provider.api_key),
timeout_ms=provider.timeout_ms,
max_retries=provider.max_retries,
weight=provider.weight,
priority=provider.priority,
enabled=provider.enabled,
config_ref=provider.config_ref,
)
@router.post("/providers", response_model=ProviderResponse)
async def create_provider(payload: ProviderCreate, db: AsyncSession = Depends(get_db)):
data = payload.model_dump()
# 加密 API Key
if data.get("api_key"):
data["api_key"] = SecretService.encrypt(data["api_key"])
provider = Provider(**data)
db.add(provider)
await db.commit()
await db.refresh(provider)
return _to_response(provider)
@router.put("/providers/{provider_id}", response_model=ProviderResponse)
async def update_provider(
provider_id: str, payload: ProviderUpdate, db: AsyncSession = Depends(get_db)
):
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
data = payload.model_dump(exclude_unset=True)
# 加密 API Key
if "api_key" in data and data["api_key"]:
data["api_key"] = SecretService.encrypt(data["api_key"])
for k, v in data.items():
setattr(provider, k, v)
await db.commit()
await db.refresh(provider)
return _to_response(provider)
@router.delete("/providers/{provider_id}")
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
await db.delete(provider)
await db.commit()
return {"message": "deleted"}
# ==================== 密钥管理 API ====================
class SecretCreate(BaseModel):
"""密钥创建请求。"""
name: str = Field(..., description="密钥名称,如 CQTAI_API_KEY")
value: str = Field(..., description="密钥明文值")
class SecretResponse(BaseModel):
"""密钥响应,不返回明文。"""
name: str
created_at: str | None = None
updated_at: str | None = None
@router.get("/secrets", response_model=list[str])
async def list_secrets(db: AsyncSession = Depends(get_db)):
"""列出所有密钥名称(不返回值)。"""
return await SecretService.list_secrets(db)
@router.post("/secrets", response_model=SecretResponse)
async def create_or_update_secret(payload: SecretCreate, db: AsyncSession = Depends(get_db)):
"""创建或更新密钥。"""
secret = await SecretService.set_secret(db, payload.name, payload.value)
return SecretResponse(
name=secret.name,
created_at=secret.created_at.isoformat() if secret.created_at else None,
updated_at=secret.updated_at.isoformat() if secret.updated_at else None,
)
@router.delete("/secrets/{name}")
async def delete_secret(name: str, db: AsyncSession = Depends(get_db)):
"""删除密钥。"""
deleted = await SecretService.delete_secret(db, name)
if not deleted:
raise HTTPException(status_code=404, detail="Secret not found")
return {"message": "deleted"}
@router.get("/secrets/{name}/verify")
async def verify_secret(name: str, db: AsyncSession = Depends(get_db)):
"""验证密钥是否存在且可解密(不返回明文)。"""
value = await SecretService.get_secret(db, name)
if value is None:
raise HTTPException(status_code=404, detail="Secret not found")
return {"name": name, "valid": True, "length": len(value)}
# ==================== 成本追踪 API ====================
class BudgetUpdate(BaseModel):
"""预算更新请求。"""
daily_limit_usd: float | None = None
monthly_limit_usd: float | None = None
alert_threshold: float | None = Field(default=None, ge=0, le=1)
enabled: bool | None = None
@router.get("/costs/summary/{user_id}")
async def get_user_cost_summary(user_id: str, db: AsyncSession = Depends(get_db)):
"""获取用户成本摘要。"""
return await cost_tracker.get_cost_summary(db, user_id)
@router.get("/costs/all")
async def get_all_costs_summary(db: AsyncSession = Depends(get_db)):
"""获取所有用户成本汇总(管理员)。"""
from sqlalchemy import func
from app.db.admin_models import CostRecord
# 按用户汇总
result = await db.execute(
select(
CostRecord.user_id,
func.sum(CostRecord.estimated_cost).label("total_cost"),
func.count().label("call_count"),
).group_by(CostRecord.user_id)
)
users = [
{"user_id": row[0], "total_cost_usd": float(row[1]), "call_count": row[2]}
for row in result.all()
]
# 按能力汇总
result = await db.execute(
select(
CostRecord.capability,
func.sum(CostRecord.estimated_cost).label("total_cost"),
func.count().label("call_count"),
).group_by(CostRecord.capability)
)
capabilities = [
{"capability": row[0], "total_cost_usd": float(row[1]), "call_count": row[2]}
for row in result.all()
]
return {"by_user": users, "by_capability": capabilities}
@router.get("/budgets/{user_id}")
async def get_user_budget(user_id: str, db: AsyncSession = Depends(get_db)):
"""获取用户预算配置。"""
budget = await cost_tracker.get_user_budget(db, user_id)
if not budget:
return {"user_id": user_id, "budget": None}
return {
"user_id": user_id,
"budget": {
"daily_limit_usd": float(budget.daily_limit_usd),
"monthly_limit_usd": float(budget.monthly_limit_usd),
"alert_threshold": float(budget.alert_threshold),
"enabled": budget.enabled,
},
}
@router.post("/budgets/{user_id}")
async def set_user_budget(
user_id: str, payload: BudgetUpdate, db: AsyncSession = Depends(get_db)
):
"""设置用户预算。"""
budget = await cost_tracker.set_user_budget(
db,
user_id,
daily_limit=payload.daily_limit_usd,
monthly_limit=payload.monthly_limit_usd,
alert_threshold=payload.alert_threshold,
enabled=payload.enabled,
)
return {
"user_id": user_id,
"budget": {
"daily_limit_usd": float(budget.daily_limit_usd),
"monthly_limit_usd": float(budget.monthly_limit_usd),
"alert_threshold": float(budget.alert_threshold),
"enabled": budget.enabled,
},
}

View File

@@ -0,0 +1,14 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.admin_auth import admin_guard
from app.db.database import get_db
from app.services.provider_cache import reload_providers
router = APIRouter(dependencies=[Depends(admin_guard)])
@router.post("/providers/reload")
async def reload(db: AsyncSession = Depends(get_db)):
cache = await reload_providers(db)
return {k: len(v) for k, v in cache.items()}

272
backend/app/api/auth.py Normal file
View File

@@ -0,0 +1,272 @@
import secrets
from urllib.parse import urlencode
import httpx
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.deps import get_current_user
from app.core.security import create_access_token
from app.db.database import get_db
from app.db.models import User
router = APIRouter()
# OAuth endpoints
GITHUB_AUTHORIZE_URL = "https://github.com/login/oauth/authorize"
GITHUB_TOKEN_URL = "https://github.com/login/oauth/access_token"
GITHUB_USER_URL = "https://api.github.com/user"
GOOGLE_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/v2/auth"
GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
GOOGLE_USER_URL = "https://www.googleapis.com/oauth2/v2/userinfo"
STATE_COOKIE = "oauth_state"
STATE_MAX_AGE = 600 # 10 minutes
def _set_state_cookie(response: RedirectResponse, provider: str, state: str) -> None:
response.set_cookie(
key=STATE_COOKIE,
value=f"{provider}:{state}",
httponly=True,
secure=not settings.debug,
samesite="lax",
max_age=STATE_MAX_AGE,
)
def _validate_state(state_from_query: str | None, state_cookie: str | None, provider: str):
if not state_from_query or not state_cookie:
raise HTTPException(status_code=400, detail="Missing OAuth state")
expected_prefix = f"{provider}:"
if not state_cookie.startswith(expected_prefix):
raise HTTPException(status_code=400, detail="OAuth state mismatch")
expected_state = state_cookie.removeprefix(expected_prefix)
if not secrets.compare_digest(state_from_query, expected_state):
raise HTTPException(status_code=400, detail="OAuth state mismatch")
@router.get("/github/signin")
async def github_signin():
"""Start GitHub OAuth with state protection."""
state = secrets.token_urlsafe(16)
params = {
"client_id": settings.github_client_id,
"redirect_uri": f"{settings.base_url}/auth/github/callback",
"scope": "read:user user:email",
"state": state,
}
url = f"{GITHUB_AUTHORIZE_URL}?{urlencode(params)}"
response = RedirectResponse(url=url)
_set_state_cookie(response, "github", state)
return response
@router.get("/github/callback")
async def github_callback(
code: str,
state: str | None = Query(default=None),
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
db: AsyncSession = Depends(get_db),
):
"""Handle GitHub OAuth callback."""
_validate_state(state, state_cookie, "github")
try:
async with httpx.AsyncClient() as client:
token_resp = await client.post(
GITHUB_TOKEN_URL,
data={
"client_id": settings.github_client_id,
"client_secret": settings.github_client_secret,
"code": code,
},
headers={"Accept": "application/json"},
)
token_resp.raise_for_status()
token_data = token_resp.json()
access_token = token_data.get("access_token")
if not access_token:
raise HTTPException(status_code=502, detail="GitHub login failed")
user_resp = await client.get(
GITHUB_USER_URL,
headers={"Authorization": f"Bearer {access_token}"},
)
user_resp.raise_for_status()
user_data = user_resp.json()
except httpx.HTTPStatusError:
raise HTTPException(status_code=502, detail="GitHub login failed")
github_id = user_data.get("id")
if github_id is None:
raise HTTPException(status_code=502, detail="GitHub login failed")
return await _handle_oauth_user(
db=db,
provider="github",
user_id=str(github_id),
name=user_data.get("name") or user_data.get("login") or "GitHub User",
avatar_url=user_data.get("avatar_url"),
)
@router.get("/google/signin")
async def google_signin():
"""Start Google OAuth with state protection."""
state = secrets.token_urlsafe(16)
params = {
"client_id": settings.google_client_id,
"redirect_uri": f"{settings.base_url}/auth/google/callback",
"response_type": "code",
"scope": "openid email profile",
"state": state,
}
url = f"{GOOGLE_AUTHORIZE_URL}?{urlencode(params)}"
response = RedirectResponse(url=url)
_set_state_cookie(response, "google", state)
return response
@router.get("/google/callback")
async def google_callback(
code: str,
state: str | None = Query(default=None),
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
db: AsyncSession = Depends(get_db),
):
"""Handle Google OAuth callback."""
_validate_state(state, state_cookie, "google")
try:
async with httpx.AsyncClient() as client:
token_resp = await client.post(
GOOGLE_TOKEN_URL,
data={
"client_id": settings.google_client_id,
"client_secret": settings.google_client_secret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": f"{settings.base_url}/auth/google/callback",
},
)
token_resp.raise_for_status()
token_data = token_resp.json()
access_token = token_data.get("access_token")
if not access_token:
raise HTTPException(status_code=502, detail="Google login failed")
user_resp = await client.get(
GOOGLE_USER_URL,
headers={"Authorization": f"Bearer {access_token}"},
)
user_resp.raise_for_status()
user_data = user_resp.json()
except httpx.HTTPStatusError:
raise HTTPException(status_code=502, detail="Google login failed")
google_id = user_data.get("id")
if google_id is None:
raise HTTPException(status_code=502, detail="Google login failed")
return await _handle_oauth_user(
db=db,
provider="google",
user_id=str(google_id),
name=user_data.get("name") or user_data.get("email") or "Google User",
avatar_url=user_data.get("picture"),
)
async def _handle_oauth_user(
db: AsyncSession,
provider: str,
user_id: str,
name: str,
avatar_url: str | None,
) -> RedirectResponse:
"""Create/update user and issue session cookie."""
full_id = f"{provider}:{user_id}"
result = await db.execute(select(User).where(User.id == full_id))
user = result.scalar_one_or_none()
if not user:
user = User(
id=full_id,
name=name,
avatar_url=avatar_url,
provider=provider,
)
db.add(user)
else:
user.name = name
user.avatar_url = avatar_url
await db.commit()
token = create_access_token({"sub": user.id})
frontend_url = "http://localhost:5173"
if settings.cors_origins and len(settings.cors_origins) > 0:
frontend_url = settings.cors_origins[0]
response = RedirectResponse(url=f"{frontend_url}/my-stories", status_code=302)
response.set_cookie(
key="access_token",
value=token,
httponly=True,
secure=not settings.debug,
samesite="lax",
max_age=60 * 60 * 24 * 7, # align with ACCESS_TOKEN_EXPIRE_DAYS
)
response.delete_cookie(STATE_COOKIE)
return response
@router.post("/signout")
async def signout():
"""Sign out and clear cookies."""
response = RedirectResponse(url=settings.cors_origins[0], status_code=302)
response.delete_cookie("access_token", samesite="lax", secure=not settings.debug)
response.delete_cookie(STATE_COOKIE, samesite="lax", secure=not settings.debug)
return response
@router.get("/session")
async def get_session(user: User | None = Depends(get_current_user)):
"""Fetch current session info."""
if not user:
return {"user": None}
return {
"user": {
"id": user.id,
"name": user.name,
"avatar_url": user.avatar_url,
"provider": user.provider,
}
}
@router.get("/dev/signin")
async def dev_signin(db: AsyncSession = Depends(get_db)):
"""Developer backdoor login. Only works in DEBUG mode."""
# if not settings.debug:
# raise HTTPException(status_code=403, detail="Developer login disabled")
try:
return await _handle_oauth_user(
db=db,
provider="github",
user_id="dev_user_001",
name="Developer",
avatar_url="https://api.dicebear.com/7.x/avataaars/svg?seed=Developer"
)
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Dev login failed: {str(e)}")

268
backend/app/api/memories.py Normal file
View File

@@ -0,0 +1,268 @@
"""Memory management APIs."""
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, User
from app.services import memory_service
from app.services.memory_service import MemoryType
router = APIRouter()
class MemoryItemResponse(BaseModel):
"""Memory item response."""
id: str
type: str
value: dict
base_weight: float
ttl_days: int | None
created_at: str
last_used_at: str | None
class Config:
from_attributes = True
class MemoryListResponse(BaseModel):
"""Memory list response."""
memories: list[MemoryItemResponse]
total: int
class CreateMemoryRequest(BaseModel):
"""Create memory request."""
type: str = Field(..., description="记忆类型")
value: dict = Field(..., description="记忆内容")
universe_id: str | None = Field(default=None, description="关联的故事宇宙 ID")
weight: float | None = Field(default=None, description="权重")
ttl_days: int | None = Field(default=None, description="过期天数")
class CreateCharacterMemoryRequest(BaseModel):
"""Create character memory request."""
name: str = Field(..., description="角色名称")
description: str | None = Field(default=None, description="角色描述")
source_story_id: int | None = Field(default=None, description="来源故事 ID")
affinity_score: float = Field(default=1.0, ge=0.0, le=1.0, description="喜爱程度")
universe_id: str | None = Field(default=None, description="关联的故事宇宙 ID")
class CreateScaryElementRequest(BaseModel):
"""Create scary element memory request."""
keyword: str = Field(..., description="回避的关键词")
category: str = Field(default="other", description="分类")
source_story_id: int | None = Field(default=None, description="来源故事 ID")
async def _verify_profile_ownership(
profile_id: str, user: User, db: AsyncSession
) -> ChildProfile:
"""验证档案所有权。"""
from sqlalchemy import select
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
return profile
@router.get("/profiles/{profile_id}/memories", response_model=MemoryListResponse)
async def list_memories(
profile_id: str,
memory_type: str | None = None,
universe_id: str | None = None,
limit: int = 50,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""获取档案的记忆列表。"""
await _verify_profile_ownership(profile_id, user, db)
memories = await memory_service.get_profile_memories(
db=db,
profile_id=profile_id,
memory_type=memory_type,
universe_id=universe_id,
limit=limit,
)
return MemoryListResponse(
memories=[
MemoryItemResponse(
id=m.id,
type=m.type,
value=m.value,
base_weight=m.base_weight,
ttl_days=m.ttl_days,
created_at=m.created_at.isoformat() if m.created_at else "",
last_used_at=m.last_used_at.isoformat() if m.last_used_at else None,
)
for m in memories
],
total=len(memories),
)
@router.post("/profiles/{profile_id}/memories", response_model=MemoryItemResponse)
async def create_memory(
profile_id: str,
payload: CreateMemoryRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""创建新的记忆项。"""
await _verify_profile_ownership(profile_id, user, db)
# 验证类型
valid_types = [
MemoryType.RECENT_STORY,
MemoryType.FAVORITE_CHARACTER,
MemoryType.SCARY_ELEMENT,
MemoryType.VOCABULARY_GROWTH,
MemoryType.EMOTIONAL_HIGHLIGHT,
MemoryType.READING_PREFERENCE,
MemoryType.MILESTONE,
MemoryType.SKILL_MASTERED,
]
if payload.type not in valid_types:
raise HTTPException(status_code=400, detail=f"无效的记忆类型: {payload.type}")
memory = await memory_service.create_memory(
db=db,
profile_id=profile_id,
memory_type=payload.type,
value=payload.value,
universe_id=payload.universe_id,
weight=payload.weight,
ttl_days=payload.ttl_days,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.post("/profiles/{profile_id}/memories/character", response_model=MemoryItemResponse)
async def create_character_memory(
profile_id: str,
payload: CreateCharacterMemoryRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""添加喜欢的角色。"""
await _verify_profile_ownership(profile_id, user, db)
memory = await memory_service.create_character_memory(
db=db,
profile_id=profile_id,
name=payload.name,
description=payload.description,
source_story_id=payload.source_story_id,
affinity_score=payload.affinity_score,
universe_id=payload.universe_id,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.post("/profiles/{profile_id}/memories/scary", response_model=MemoryItemResponse)
async def create_scary_element_memory(
profile_id: str,
payload: CreateScaryElementRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""添加回避元素。"""
await _verify_profile_ownership(profile_id, user, db)
memory = await memory_service.create_scary_element_memory(
db=db,
profile_id=profile_id,
keyword=payload.keyword,
category=payload.category,
source_story_id=payload.source_story_id,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.delete("/profiles/{profile_id}/memories/{memory_id}")
async def delete_memory(
profile_id: str,
memory_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""删除记忆项。"""
from sqlalchemy import select
from app.db.models import MemoryItem
await _verify_profile_ownership(profile_id, user, db)
result = await db.execute(
select(MemoryItem).where(
MemoryItem.id == memory_id,
MemoryItem.child_profile_id == profile_id,
)
)
memory = result.scalar_one_or_none()
if not memory:
raise HTTPException(status_code=404, detail="记忆不存在")
await db.delete(memory)
await db.commit()
return {"message": "Deleted"}
@router.get("/memory-types")
async def list_memory_types():
"""获取所有可用的记忆类型及其配置。"""
types = []
for type_name, config in MemoryType.CONFIG.items():
types.append({
"type": type_name,
"default_weight": config[0],
"default_ttl_days": config[1],
"description": config[2],
})
return {"types": types}

280
backend/app/api/profiles.py Normal file
View File

@@ -0,0 +1,280 @@
"""Child profile APIs."""
from datetime import date
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, Story, StoryUniverse, User
router = APIRouter()
MAX_PROFILES_PER_USER = 5
class ChildProfileCreate(BaseModel):
"""Create profile payload."""
name: str = Field(..., min_length=1, max_length=50)
birth_date: date | None = None
gender: str | None = Field(default=None, pattern="^(male|female|other)$")
interests: list[str] = Field(default_factory=list)
growth_themes: list[str] = Field(default_factory=list)
avatar_url: str | None = None
class ChildProfileUpdate(BaseModel):
"""Update profile payload."""
name: str | None = Field(default=None, min_length=1, max_length=50)
birth_date: date | None = None
gender: str | None = Field(default=None, pattern="^(male|female|other)$")
interests: list[str] | None = None
growth_themes: list[str] | None = None
avatar_url: str | None = None
class ChildProfileResponse(BaseModel):
"""Profile response."""
id: str
name: str
avatar_url: str | None
birth_date: date | None
gender: str | None
age: int | None
interests: list[str]
growth_themes: list[str]
stories_count: int
total_reading_time: int
class Config:
from_attributes = True
class ChildProfileListResponse(BaseModel):
"""Profile list response."""
profiles: list[ChildProfileResponse]
total: int
class TimelineEvent(BaseModel):
"""Timeline event item."""
date: str
type: Literal["story", "achievement", "milestone"]
title: str
description: str | None = None
image_url: str | None = None
metadata: dict | None = None
class TimelineResponse(BaseModel):
"""Timeline response."""
events: list[TimelineEvent]
@router.get("/profiles", response_model=ChildProfileListResponse)
async def list_profiles(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List child profiles for current user."""
result = await db.execute(
select(ChildProfile)
.where(ChildProfile.user_id == user.id)
.order_by(ChildProfile.created_at.desc())
)
profiles = result.scalars().all()
return ChildProfileListResponse(profiles=profiles, total=len(profiles))
@router.post("/profiles", response_model=ChildProfileResponse, status_code=status.HTTP_201_CREATED)
async def create_profile(
payload: ChildProfileCreate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create a new child profile."""
count = await db.scalar(
select(func.count(ChildProfile.id)).where(ChildProfile.user_id == user.id)
)
if count and count >= MAX_PROFILES_PER_USER:
raise HTTPException(status_code=400, detail="最多只能创建 5 个孩子档案")
existing = await db.scalar(
select(ChildProfile.id).where(
ChildProfile.user_id == user.id,
ChildProfile.name == payload.name,
)
)
if existing:
raise HTTPException(status_code=409, detail="该档案名称已存在")
profile = ChildProfile(user_id=user.id, **payload.model_dump())
db.add(profile)
await db.commit()
await db.refresh(profile)
return profile
@router.get("/profiles/{profile_id}", response_model=ChildProfileResponse)
async def get_profile(
profile_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get one child profile."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
return profile
@router.put("/profiles/{profile_id}", response_model=ChildProfileResponse)
async def update_profile(
profile_id: str,
payload: ChildProfileUpdate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Update a child profile."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
updates = payload.model_dump(exclude_unset=True)
if "name" in updates:
existing = await db.scalar(
select(ChildProfile.id).where(
ChildProfile.user_id == user.id,
ChildProfile.name == updates["name"],
ChildProfile.id != profile_id,
)
)
if existing:
raise HTTPException(status_code=409, detail="该档案名称已存在")
for key, value in updates.items():
setattr(profile, key, value)
await db.commit()
await db.refresh(profile)
return profile
@router.delete("/profiles/{profile_id}")
async def delete_profile(
profile_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Delete a child profile."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
await db.delete(profile)
await db.commit()
return {"message": "Deleted"}
@router.get("/profiles/{profile_id}/timeline", response_model=TimelineResponse)
async def get_profile_timeline(
profile_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get profile growth timeline."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
events: list[TimelineEvent] = []
# 1. Milestone: Profile Created
events.append(TimelineEvent(
date=profile.created_at.isoformat(),
type="milestone",
title="初次相遇",
description=f"创建了档案 {profile.name}"
))
# 2. Stories
stories_result = await db.execute(
select(Story).where(Story.child_profile_id == profile_id)
)
for s in stories_result.scalars():
events.append(TimelineEvent(
date=s.created_at.isoformat(),
type="story",
title=s.title,
image_url=s.image_url,
metadata={"story_id": s.id, "mode": s.mode}
))
# 3. Achievements (from Universe)
universes_result = await db.execute(
select(StoryUniverse).where(StoryUniverse.child_profile_id == profile_id)
)
for u in universes_result.scalars():
if u.achievements:
for ach in u.achievements:
if isinstance(ach, dict):
obt_at = ach.get("obtained_at")
# Fallback
if not obt_at:
obt_at = u.updated_at.isoformat()
events.append(TimelineEvent(
date=obt_at,
type="achievement",
title=f"获得成就:{ach.get('type')}",
description=ach.get('description'),
metadata={"universe_id": u.id, "source_story_id": ach.get("source_story_id")}
))
# Sort by date desc
events.sort(key=lambda x: x.date, reverse=True)
return TimelineResponse(events=events)

View File

@@ -0,0 +1,120 @@
"""Push configuration APIs."""
from datetime import time
from fastapi import APIRouter, Depends, HTTPException, Response, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, PushConfig, User
router = APIRouter()
class PushConfigUpsert(BaseModel):
"""Upsert push config payload."""
child_profile_id: str
push_time: time | None = None
push_days: list[int] | None = None
enabled: bool | None = None
class PushConfigResponse(BaseModel):
"""Push config response."""
id: str
child_profile_id: str
push_time: time | None
push_days: list[int]
enabled: bool
class Config:
from_attributes = True
class PushConfigListResponse(BaseModel):
"""Push config list response."""
configs: list[PushConfigResponse]
total: int
def _validate_push_days(push_days: list[int]) -> list[int]:
invalid = [day for day in push_days if day < 0 or day > 6]
if invalid:
raise HTTPException(status_code=400, detail="推送日期必须在 0-6 之间")
return list(dict.fromkeys(push_days))
@router.get("/push-configs", response_model=PushConfigListResponse)
async def list_push_configs(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List push configs for current user."""
result = await db.execute(
select(PushConfig).where(PushConfig.user_id == user.id)
)
configs = result.scalars().all()
return PushConfigListResponse(configs=configs, total=len(configs))
@router.put("/push-configs", response_model=PushConfigResponse)
async def upsert_push_config(
payload: PushConfigUpsert,
response: Response,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create or update push config for a child profile."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == payload.child_profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
result = await db.execute(
select(PushConfig).where(PushConfig.child_profile_id == payload.child_profile_id)
)
config = result.scalar_one_or_none()
if config is None:
if payload.push_time is None or payload.push_days is None:
raise HTTPException(status_code=400, detail="创建配置需要提供推送时间和日期")
push_days = _validate_push_days(payload.push_days)
config = PushConfig(
user_id=user.id,
child_profile_id=payload.child_profile_id,
push_time=payload.push_time,
push_days=push_days,
enabled=True if payload.enabled is None else payload.enabled,
)
db.add(config)
await db.commit()
await db.refresh(config)
response.status_code = status.HTTP_201_CREATED
return config
updates = payload.model_dump(exclude_unset=True)
if "push_days" in updates and updates["push_days"] is not None:
updates["push_days"] = _validate_push_days(updates["push_days"])
if "push_time" in updates and updates["push_time"] is None:
raise HTTPException(status_code=400, detail="推送时间不能为空")
for key, value in updates.items():
if key == "child_profile_id":
continue
if value is not None:
setattr(config, key, value)
await db.commit()
await db.refresh(config)
return config

View File

@@ -0,0 +1,120 @@
"""Reading event APIs."""
from datetime import datetime, timezone
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, MemoryItem, ReadingEvent, Story, User
router = APIRouter()
EVENT_WEIGHTS: dict[str, float] = {
"completed": 1.0,
"replayed": 1.5,
"started": 0.1,
"skipped": -0.5,
}
class ReadingEventCreate(BaseModel):
"""Reading event payload."""
child_profile_id: str
story_id: int | None = None
event_type: Literal["started", "completed", "skipped", "replayed"]
reading_time: int = Field(default=0, ge=0)
class ReadingEventResponse(BaseModel):
"""Reading event response."""
id: int
child_profile_id: str
story_id: int | None
event_type: str
reading_time: int
created_at: datetime
class Config:
from_attributes = True
@router.post("/reading-events", response_model=ReadingEventResponse, status_code=status.HTTP_201_CREATED)
async def create_reading_event(
payload: ReadingEventCreate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create a reading event and update profile stats/memory."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == payload.child_profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
story = None
if payload.story_id is not None:
result = await db.execute(
select(Story).where(
Story.id == payload.story_id,
Story.user_id == user.id,
)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="故事不存在")
if payload.reading_time:
profile.total_reading_time = (profile.total_reading_time or 0) + payload.reading_time
if payload.event_type in {"completed", "replayed"} and payload.story_id is not None:
existing = await db.scalar(
select(ReadingEvent.id).where(
ReadingEvent.child_profile_id == payload.child_profile_id,
ReadingEvent.story_id == payload.story_id,
ReadingEvent.event_type.in_(["completed", "replayed"]),
)
)
if existing is None:
profile.stories_count = (profile.stories_count or 0) + 1
event = ReadingEvent(
child_profile_id=payload.child_profile_id,
story_id=payload.story_id,
event_type=payload.event_type,
reading_time=payload.reading_time,
)
db.add(event)
weight = EVENT_WEIGHTS.get(payload.event_type, 0.0)
if story and weight > 0:
db.add(
MemoryItem(
child_profile_id=payload.child_profile_id,
universe_id=story.universe_id,
type="recent_story",
value={
"story_id": story.id,
"title": story.title,
"event_type": payload.event_type,
},
base_weight=weight,
last_used_at=datetime.now(timezone.utc),
ttl_days=90,
)
)
await db.commit()
await db.refresh(event)
return event

605
backend/app/api/stories.py Normal file
View File

@@ -0,0 +1,605 @@
"""Story related APIs."""
import asyncio
import json
import time
import uuid
from typing import AsyncGenerator, Literal
from cachetools import TTLCache
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sse_starlette.sse import EventSourceResponse
from app.core.deps import require_user
from app.core.logging import get_logger
from app.db.database import get_db
from app.db.models import ChildProfile, Story, StoryUniverse, User
from app.services.provider_router import (
generate_image,
generate_story_content,
generate_storybook,
text_to_speech,
)
from app.tasks.achievements import extract_story_achievements
logger = get_logger(__name__)
router = APIRouter()
MAX_DATA_LENGTH = 2000
MAX_EDU_THEME_LENGTH = 200
MAX_TTS_LENGTH = 4000
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10
RATE_LIMIT_CACHE_SIZE = 10000 # 最大跟踪用户数
_request_log: TTLCache[str, list[float]] = TTLCache(
maxsize=RATE_LIMIT_CACHE_SIZE, ttl=RATE_LIMIT_WINDOW * 2
)
def _check_rate_limit(user_id: str):
now = time.time()
timestamps = _request_log.get(user_id, [])
timestamps = [t for t in timestamps if now - t <= RATE_LIMIT_WINDOW]
if len(timestamps) >= RATE_LIMIT_REQUESTS:
raise HTTPException(status_code=429, detail="Too many requests, please slow down.")
timestamps.append(now)
_request_log[user_id] = timestamps
class GenerateRequest(BaseModel):
"""Story generation request."""
type: Literal["keywords", "full_story"]
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
child_profile_id: str | None = None
universe_id: str | None = None
class StoryResponse(BaseModel):
"""Story response."""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
mode: str
child_profile_id: str | None = None
universe_id: str | None = None
class StoryListItem(BaseModel):
"""Story list item."""
id: int
title: str
image_url: str | None
created_at: str
mode: str
class FullStoryResponse(BaseModel):
"""完整故事响应(含图片和音频状态)。"""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
audio_ready: bool
mode: str
errors: dict[str, str | None] = Field(default_factory=dict)
child_profile_id: str | None = None
universe_id: str | None = None
from app.services.memory_service import build_enhanced_memory_context
async def _validate_profile_and_universe(
request: GenerateRequest,
user: User,
db: AsyncSession,
) -> tuple[str | None, str | None]:
if not request.child_profile_id and not request.universe_id:
return None, None
profile_id = request.child_profile_id
universe_id = request.universe_id
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
return profile_id, universe_id
@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
request: GenerateRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate or enhance a story."""
_check_rate_limit(user.id)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return StoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=story.image_url,
mode=story.mode,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
@router.post("/stories/generate/full", response_model=FullStoryResponse)
async def generate_story_full(
request: GenerateRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""生成完整故事(故事 + 并行生成图片和音频)。
部分成功策略:故事必须成功,图片/音频失败不影响整体。
"""
_check_rate_limit(user.id)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# Step 1: 故事生成(必须成功)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except Exception as exc:
logger.error("story_generation_failed", error=str(exc))
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
# 保存故事
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# Step 2: 生成封面图片(音频按需生成,避免浪费)
errors: dict[str, str | None] = {}
image_url: str | None = None
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt)
story.image_url = image_url
await db.commit()
except Exception as exc:
errors["image"] = str(exc)
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
# 注意:音频不在此处预生成,用户通过 /api/audio/{id} 按需获取
# 这样避免生成后丢弃造成的成本浪费
return FullStoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=image_url,
audio_ready=False, # 音频需要用户主动请求
mode=story.mode,
errors=errors,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
@router.post("/stories/generate/stream")
async def generate_story_stream(
request: GenerateRequest,
req: Request,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""流式生成故事SSE
事件流程:
- started: 返回 story_id
- story_ready: 返回 title, content
- story_failed: 返回 error
- image_ready: 返回 image_url
- image_failed: 返回 error
- complete: 结束流
"""
_check_rate_limit(user.id)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
async def event_generator() -> AsyncGenerator[dict, None]:
story_id = str(uuid.uuid4())
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
# Step 1: 生成故事
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except Exception as e:
logger.error("sse_story_generation_failed", error=str(e))
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
return
# 保存故事
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
yield {
"event": "story_ready",
"data": json.dumps({
"id": story.id,
"title": story.title,
"content": story.story_text,
"cover_prompt": story.cover_prompt,
"mode": story.mode,
"child_profile_id": story.child_profile_id,
"universe_id": story.universe_id,
}),
}
# Step 2: 并行生成图片(音频按需)
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt)
story.image_url = image_url
await db.commit()
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
except Exception as e:
logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e))
yield {"event": "image_failed", "data": json.dumps({"error": str(e)})}
yield {"event": "complete", "data": json.dumps({"story_id": story.id})}
return EventSourceResponse(event_generator())
# ==================== Storybook API ====================
class StorybookRequest(BaseModel):
"""Storybook 生成请求。"""
keywords: str = Field(..., min_length=1, max_length=200)
page_count: int = Field(default=6, ge=4, le=12)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
generate_images: bool = Field(default=False, description="是否同时生成插图")
child_profile_id: str | None = None
universe_id: str | None = None
class StorybookPageResponse(BaseModel):
"""故事书单页响应。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
class StorybookResponse(BaseModel):
"""故事书响应。"""
id: int | None = None
title: str
main_character: str
art_style: str
pages: list[StorybookPageResponse]
cover_prompt: str
cover_url: str | None = None
@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
request: StorybookRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""生成分页故事书并保存。
返回故事书结构,包含每页文字和图像提示词。
"""
_check_rate_limit(user.id)
# 验证档案和宇宙
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数
# 这里我们直接手动验证,或重构验证函数。为了简单,手动调用部分逻辑。
# 构建临时的 GenerateRequest 用于验证验证函数签名(或者直接手动查库更好)
profile_id = request.child_profile_id
universe_id = request.universe_id
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
if not result.scalar_one_or_none():
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
logger.info(
"storybook_request",
user_id=user.id,
keywords=request.keywords,
page_count=request.page_count,
profile_id=profile_id,
universe_id=universe_id,
)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
try:
# 注意generate_storybook 目前可能不支持记忆上下文注入
# 我们需要看看 generate_storybook 的签名
# 如果不支持,记忆功能在绘本模式下暂不可用,但基本参数传递是支持的
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("storybook_generation_failed", error=str(e))
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
# ==============================================================================
# 核心升级: 并行全量生成 (Parallel Full Rendering)
# ==============================================================================
final_cover_url = storybook.cover_url
if request.generate_images:
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
# 1. 准备所有生图任务 (封面 + 所有内页)
tasks = []
# 封面任务
async def _gen_cover():
if storybook.cover_prompt and not storybook.cover_url:
try:
return await generate_image(storybook.cover_prompt, db=db)
except Exception as e:
logger.warning("cover_gen_failed", error=str(e))
return storybook.cover_url
tasks.append(_gen_cover())
# 内页任务
async def _gen_page(page):
if page.image_prompt and not page.image_url:
try:
url = await generate_image(page.image_prompt, db=db)
page.image_url = url
except Exception as e:
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
for page in storybook.pages:
tasks.append(_gen_page(page))
# 2. 并发执行所有任务
# 使用 return_exceptions=True 防止单张失败影响整体
results = await asyncio.gather(*tasks, return_exceptions=True)
# 3. 更新封面结果 (results[0] 是封面任务的返回值)
cover_res = results[0]
if isinstance(cover_res, str):
final_cover_url = cover_res
logger.info("storybook_parallel_generation_complete")
# ==============================================================================
# 构建并保存 Story 对象
# 将 pages 对象转换为字典列表以存入 JSON 字段
pages_data = [
{
"page_number": p.page_number,
"text": p.text,
"image_prompt": p.image_prompt,
"image_url": p.image_url,
}
for p in storybook.pages
]
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=storybook.title,
mode="storybook",
pages=pages_data, # 存入 JSON 字段
story_text=None, # 绘本模式下,主文本可为空,或者可以存个摘要
cover_prompt=storybook.cover_prompt,
image_url=final_cover_url,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# 构建响应 (使用更新后的 pages_data)
response_pages = [
StorybookPageResponse(
page_number=p["page_number"],
text=p["text"],
image_prompt=p["image_prompt"],
image_url=p.get("image_url"),
)
for p in pages_data
]
return StorybookResponse(
id=story.id,
title=storybook.title,
main_character=storybook.main_character,
art_style=storybook.art_style,
pages=response_pages,
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
)
class AchievementItem(BaseModel):
type: str
description: str
obtained_at: str | None = None
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
async def get_story_achievements(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get achievements unlocked by a specific story."""
# 使用 joinedload 避免 N+1 查询
result = await db.execute(
select(Story)
.options(joinedload(Story.story_universe))
.where(Story.id == story_id, Story.user_id == user.id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
if not story.universe_id or not story.story_universe:
return []
universe = story.story_universe
if not universe.achievements:
return []
results = []
for ach in universe.achievements:
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
results.append(AchievementItem(
type=ach.get("type", "Unknown"),
description=ach.get("description", ""),
obtained_at=ach.get("obtained_at")
))
return results

View File

@@ -0,0 +1,201 @@
"""Story universe APIs."""
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, StoryUniverse, User
router = APIRouter()
class StoryUniverseCreate(BaseModel):
"""Create universe payload."""
name: str = Field(..., min_length=1, max_length=100)
protagonist: dict[str, Any]
recurring_characters: list[dict[str, Any]] = Field(default_factory=list)
world_settings: dict[str, Any] = Field(default_factory=dict)
class StoryUniverseUpdate(BaseModel):
"""Update universe payload."""
name: str | None = Field(default=None, min_length=1, max_length=100)
protagonist: dict[str, Any] | None = None
recurring_characters: list[dict[str, Any]] | None = None
world_settings: dict[str, Any] | None = None
class AchievementCreate(BaseModel):
"""Achievement payload."""
type: str = Field(..., min_length=1, max_length=50)
description: str = Field(..., min_length=1, max_length=200)
class StoryUniverseResponse(BaseModel):
"""Universe response."""
id: str
child_profile_id: str
name: str
protagonist: dict[str, Any]
recurring_characters: list[dict[str, Any]]
world_settings: dict[str, Any]
achievements: list[dict[str, Any]]
class Config:
from_attributes = True
class StoryUniverseListResponse(BaseModel):
"""Universe list response."""
universes: list[StoryUniverseResponse]
total: int
async def _get_profile_or_404(
profile_id: str,
user: User,
db: AsyncSession,
) -> ChildProfile:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
return profile
async def _get_universe_or_404(
universe_id: str,
user: User,
db: AsyncSession,
) -> StoryUniverse:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="宇宙不存在")
return universe
@router.get("/profiles/{profile_id}/universes", response_model=StoryUniverseListResponse)
async def list_universes(
profile_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List universes for a child profile."""
await _get_profile_or_404(profile_id, user, db)
result = await db.execute(
select(StoryUniverse)
.where(StoryUniverse.child_profile_id == profile_id)
.order_by(StoryUniverse.updated_at.desc())
)
universes = result.scalars().all()
return StoryUniverseListResponse(universes=universes, total=len(universes))
@router.post(
"/profiles/{profile_id}/universes",
response_model=StoryUniverseResponse,
status_code=status.HTTP_201_CREATED,
)
async def create_universe(
profile_id: str,
payload: StoryUniverseCreate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create a story universe."""
await _get_profile_or_404(profile_id, user, db)
universe = StoryUniverse(child_profile_id=profile_id, **payload.model_dump())
db.add(universe)
await db.commit()
await db.refresh(universe)
return universe
@router.get("/universes/{universe_id}", response_model=StoryUniverseResponse)
async def get_universe(
universe_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get one universe."""
universe = await _get_universe_or_404(universe_id, user, db)
return universe
@router.put("/universes/{universe_id}", response_model=StoryUniverseResponse)
async def update_universe(
universe_id: str,
payload: StoryUniverseUpdate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Update a story universe."""
universe = await _get_universe_or_404(universe_id, user, db)
updates = payload.model_dump(exclude_unset=True)
for key, value in updates.items():
setattr(universe, key, value)
await db.commit()
await db.refresh(universe)
return universe
@router.delete("/universes/{universe_id}")
async def delete_universe(
universe_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Delete a story universe."""
universe = await _get_universe_or_404(universe_id, user, db)
await db.delete(universe)
await db.commit()
return {"message": "Deleted"}
@router.post("/universes/{universe_id}/achievements", response_model=StoryUniverseResponse)
async def add_achievement(
universe_id: str,
payload: AchievementCreate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Add an achievement to a universe."""
universe = await _get_universe_or_404(universe_id, user, db)
achievements = list(universe.achievements or [])
key = (payload.type.strip(), payload.description.strip())
existing = {
(str(item.get("type", "")).strip(), str(item.get("description", "")).strip())
for item in achievements
if isinstance(item, dict)
}
if key not in existing:
achievements.append({"type": key[0], "description": key[1]})
universe.achievements = achievements
await db.commit()
await db.refresh(universe)
return universe