Initial commit: clean project structure

- Backend: FastAPI + SQLAlchemy + Celery (Python 3.11+)
- Frontend: Vue 3 + TypeScript + Pinia + Tailwind
- Admin Frontend: separate Vue 3 app for management
- Docker Compose: 9 services orchestration
- Specs: design prototypes, memory system PRD, product roadmap

Cleanup performed:
- Removed temporary debug scripts from backend root
- Removed deprecated admin_app.py (embedded UI)
- Removed duplicate docs from admin-frontend
- Updated .gitignore for Vite cache and egg-info
This commit is contained in:
zhangtuo
2026-01-20 18:20:03 +08:00
commit e9d7f8832a
241 changed files with 33070 additions and 0 deletions

0
backend/app/__init__.py Normal file
View File

61
backend/app/admin_main.py Normal file
View File

@@ -0,0 +1,61 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api import admin_providers, admin_reload
from app.core.config import settings
from app.core.logging import get_logger, setup_logging
from app.db.database import init_db
setup_logging()
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Admin App lifespan manager."""
logger.info("admin_app_starting")
await init_db()
# 可以在这里加载特定的 Admin 缓存或预热
yield
logger.info("admin_app_shutdown")
app = FastAPI(
title=f"{settings.app_name} Admin Console",
description="Administrative Control Plane for DreamWeaver.",
version="0.1.0",
lifespan=lifespan,
)
# Admin 后台通常允许更宽松的 CORS或者特定的管理域名
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins, # 或者专门的 ADMIN_CORS_ORIGINS
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 根据配置开关挂载路由
if settings.enable_admin_console:
app.include_router(admin_providers.router, prefix="/admin", tags=["admin-providers"])
app.include_router(admin_reload.router, prefix="/admin", tags=["admin-reload"])
else:
@app.get("/admin/{path:path}")
@app.post("/admin/{path:path}")
@app.put("/admin/{path:path}")
@app.delete("/admin/{path:path}")
async def admin_disabled(path: str):
from fastapi import HTTPException
raise HTTPException(
status_code=403,
detail="Admin console is disabled in environment configuration."
)
@app.get("/health")
async def health_check():
return {"status": "ok", "service": "admin-backend"}

View File

View File

@@ -0,0 +1,307 @@
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.admin_auth import admin_guard
from app.db.admin_models import Provider
from app.db.database import get_db
from app.services.cost_tracker import cost_tracker
from app.services.secret_service import SecretService
router = APIRouter(dependencies=[Depends(admin_guard)])
class ProviderCreate(BaseModel):
name: str
type: str = Field(..., pattern="^(text|image|tts|storybook)$")
adapter: str
model: str | None = None
api_base: str | None = None
api_key: str | None = None # 可选,优先于 config_ref
timeout_ms: int = 60000
max_retries: int = 1
weight: int = 1
priority: int = 0
enabled: bool = True
config_json: dict | None = None
config_ref: str | None = None # 环境变量 key 名称(回退)
updated_by: str | None = None
class ProviderUpdate(ProviderCreate):
enabled: bool | None = None
api_key: str | None = None
config_json: dict | None = None
class ProviderResponse(BaseModel):
"""Provider 响应模型,隐藏敏感字段。"""
id: str
name: str
type: str
adapter: str
model: str | None = None
api_base: str | None = None
has_api_key: bool = False # 仅标识是否配置了 api_key不返回明文
timeout_ms: int = 60000
max_retries: int = 1
weight: int = 1
priority: int = 0
enabled: bool = True
config_ref: str | None = None
model_config = ConfigDict(from_attributes=True)
from app.services.adapters.registry import AdapterRegistry
from app.services.provider_router import DEFAULT_PROVIDERS
@router.get("/providers/adapters")
async def list_available_adapters():
"""获取所有可用的适配器类型 (定义的类)。"""
return AdapterRegistry.list_adapters()
@router.get("/providers/defaults")
async def get_env_defaults():
"""获取当前环境变量定义的默认策略 (Read-Only)。"""
return DEFAULT_PROVIDERS
@router.get("/providers", response_model=list[ProviderResponse])
async def list_providers(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Provider))
providers = result.scalars().all()
# 转换为响应模型,隐藏 api_key 明文
return [
ProviderResponse(
id=p.id,
name=p.name,
type=p.type,
adapter=p.adapter,
model=p.model,
api_base=p.api_base,
has_api_key=bool(p.api_key), # 仅标识是否有 key
timeout_ms=p.timeout_ms,
max_retries=p.max_retries,
weight=p.weight,
priority=p.priority,
enabled=p.enabled,
config_ref=p.config_ref,
)
for p in providers
]
def _to_response(provider: Provider) -> ProviderResponse:
"""将 Provider 转换为响应模型,隐藏敏感字段。"""
return ProviderResponse(
id=provider.id,
name=provider.name,
type=provider.type,
adapter=provider.adapter,
model=provider.model,
api_base=provider.api_base,
has_api_key=bool(provider.api_key),
timeout_ms=provider.timeout_ms,
max_retries=provider.max_retries,
weight=provider.weight,
priority=provider.priority,
enabled=provider.enabled,
config_ref=provider.config_ref,
)
@router.post("/providers", response_model=ProviderResponse)
async def create_provider(payload: ProviderCreate, db: AsyncSession = Depends(get_db)):
data = payload.model_dump()
# 加密 API Key
if data.get("api_key"):
data["api_key"] = SecretService.encrypt(data["api_key"])
provider = Provider(**data)
db.add(provider)
await db.commit()
await db.refresh(provider)
return _to_response(provider)
@router.put("/providers/{provider_id}", response_model=ProviderResponse)
async def update_provider(
provider_id: str, payload: ProviderUpdate, db: AsyncSession = Depends(get_db)
):
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
data = payload.model_dump(exclude_unset=True)
# 加密 API Key
if "api_key" in data and data["api_key"]:
data["api_key"] = SecretService.encrypt(data["api_key"])
for k, v in data.items():
setattr(provider, k, v)
await db.commit()
await db.refresh(provider)
return _to_response(provider)
@router.delete("/providers/{provider_id}")
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
await db.delete(provider)
await db.commit()
return {"message": "deleted"}
# ==================== 密钥管理 API ====================
class SecretCreate(BaseModel):
"""密钥创建请求。"""
name: str = Field(..., description="密钥名称,如 CQTAI_API_KEY")
value: str = Field(..., description="密钥明文值")
class SecretResponse(BaseModel):
"""密钥响应,不返回明文。"""
name: str
created_at: str | None = None
updated_at: str | None = None
@router.get("/secrets", response_model=list[str])
async def list_secrets(db: AsyncSession = Depends(get_db)):
"""列出所有密钥名称(不返回值)。"""
return await SecretService.list_secrets(db)
@router.post("/secrets", response_model=SecretResponse)
async def create_or_update_secret(payload: SecretCreate, db: AsyncSession = Depends(get_db)):
"""创建或更新密钥。"""
secret = await SecretService.set_secret(db, payload.name, payload.value)
return SecretResponse(
name=secret.name,
created_at=secret.created_at.isoformat() if secret.created_at else None,
updated_at=secret.updated_at.isoformat() if secret.updated_at else None,
)
@router.delete("/secrets/{name}")
async def delete_secret(name: str, db: AsyncSession = Depends(get_db)):
"""删除密钥。"""
deleted = await SecretService.delete_secret(db, name)
if not deleted:
raise HTTPException(status_code=404, detail="Secret not found")
return {"message": "deleted"}
@router.get("/secrets/{name}/verify")
async def verify_secret(name: str, db: AsyncSession = Depends(get_db)):
"""验证密钥是否存在且可解密(不返回明文)。"""
value = await SecretService.get_secret(db, name)
if value is None:
raise HTTPException(status_code=404, detail="Secret not found")
return {"name": name, "valid": True, "length": len(value)}
# ==================== 成本追踪 API ====================
class BudgetUpdate(BaseModel):
"""预算更新请求。"""
daily_limit_usd: float | None = None
monthly_limit_usd: float | None = None
alert_threshold: float | None = Field(default=None, ge=0, le=1)
enabled: bool | None = None
@router.get("/costs/summary/{user_id}")
async def get_user_cost_summary(user_id: str, db: AsyncSession = Depends(get_db)):
"""获取用户成本摘要。"""
return await cost_tracker.get_cost_summary(db, user_id)
@router.get("/costs/all")
async def get_all_costs_summary(db: AsyncSession = Depends(get_db)):
"""获取所有用户成本汇总(管理员)。"""
from sqlalchemy import func
from app.db.admin_models import CostRecord
# 按用户汇总
result = await db.execute(
select(
CostRecord.user_id,
func.sum(CostRecord.estimated_cost).label("total_cost"),
func.count().label("call_count"),
).group_by(CostRecord.user_id)
)
users = [
{"user_id": row[0], "total_cost_usd": float(row[1]), "call_count": row[2]}
for row in result.all()
]
# 按能力汇总
result = await db.execute(
select(
CostRecord.capability,
func.sum(CostRecord.estimated_cost).label("total_cost"),
func.count().label("call_count"),
).group_by(CostRecord.capability)
)
capabilities = [
{"capability": row[0], "total_cost_usd": float(row[1]), "call_count": row[2]}
for row in result.all()
]
return {"by_user": users, "by_capability": capabilities}
@router.get("/budgets/{user_id}")
async def get_user_budget(user_id: str, db: AsyncSession = Depends(get_db)):
"""获取用户预算配置。"""
budget = await cost_tracker.get_user_budget(db, user_id)
if not budget:
return {"user_id": user_id, "budget": None}
return {
"user_id": user_id,
"budget": {
"daily_limit_usd": float(budget.daily_limit_usd),
"monthly_limit_usd": float(budget.monthly_limit_usd),
"alert_threshold": float(budget.alert_threshold),
"enabled": budget.enabled,
},
}
@router.post("/budgets/{user_id}")
async def set_user_budget(
user_id: str, payload: BudgetUpdate, db: AsyncSession = Depends(get_db)
):
"""设置用户预算。"""
budget = await cost_tracker.set_user_budget(
db,
user_id,
daily_limit=payload.daily_limit_usd,
monthly_limit=payload.monthly_limit_usd,
alert_threshold=payload.alert_threshold,
enabled=payload.enabled,
)
return {
"user_id": user_id,
"budget": {
"daily_limit_usd": float(budget.daily_limit_usd),
"monthly_limit_usd": float(budget.monthly_limit_usd),
"alert_threshold": float(budget.alert_threshold),
"enabled": budget.enabled,
},
}

View File

@@ -0,0 +1,14 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.admin_auth import admin_guard
from app.db.database import get_db
from app.services.provider_cache import reload_providers
router = APIRouter(dependencies=[Depends(admin_guard)])
@router.post("/providers/reload")
async def reload(db: AsyncSession = Depends(get_db)):
cache = await reload_providers(db)
return {k: len(v) for k, v in cache.items()}

272
backend/app/api/auth.py Normal file
View File

@@ -0,0 +1,272 @@
import secrets
from urllib.parse import urlencode
import httpx
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.deps import get_current_user
from app.core.security import create_access_token
from app.db.database import get_db
from app.db.models import User
router = APIRouter()
# OAuth endpoints
GITHUB_AUTHORIZE_URL = "https://github.com/login/oauth/authorize"
GITHUB_TOKEN_URL = "https://github.com/login/oauth/access_token"
GITHUB_USER_URL = "https://api.github.com/user"
GOOGLE_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/v2/auth"
GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
GOOGLE_USER_URL = "https://www.googleapis.com/oauth2/v2/userinfo"
STATE_COOKIE = "oauth_state"
STATE_MAX_AGE = 600 # 10 minutes
def _set_state_cookie(response: RedirectResponse, provider: str, state: str) -> None:
response.set_cookie(
key=STATE_COOKIE,
value=f"{provider}:{state}",
httponly=True,
secure=not settings.debug,
samesite="lax",
max_age=STATE_MAX_AGE,
)
def _validate_state(state_from_query: str | None, state_cookie: str | None, provider: str):
if not state_from_query or not state_cookie:
raise HTTPException(status_code=400, detail="Missing OAuth state")
expected_prefix = f"{provider}:"
if not state_cookie.startswith(expected_prefix):
raise HTTPException(status_code=400, detail="OAuth state mismatch")
expected_state = state_cookie.removeprefix(expected_prefix)
if not secrets.compare_digest(state_from_query, expected_state):
raise HTTPException(status_code=400, detail="OAuth state mismatch")
@router.get("/github/signin")
async def github_signin():
"""Start GitHub OAuth with state protection."""
state = secrets.token_urlsafe(16)
params = {
"client_id": settings.github_client_id,
"redirect_uri": f"{settings.base_url}/auth/github/callback",
"scope": "read:user user:email",
"state": state,
}
url = f"{GITHUB_AUTHORIZE_URL}?{urlencode(params)}"
response = RedirectResponse(url=url)
_set_state_cookie(response, "github", state)
return response
@router.get("/github/callback")
async def github_callback(
code: str,
state: str | None = Query(default=None),
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
db: AsyncSession = Depends(get_db),
):
"""Handle GitHub OAuth callback."""
_validate_state(state, state_cookie, "github")
try:
async with httpx.AsyncClient() as client:
token_resp = await client.post(
GITHUB_TOKEN_URL,
data={
"client_id": settings.github_client_id,
"client_secret": settings.github_client_secret,
"code": code,
},
headers={"Accept": "application/json"},
)
token_resp.raise_for_status()
token_data = token_resp.json()
access_token = token_data.get("access_token")
if not access_token:
raise HTTPException(status_code=502, detail="GitHub login failed")
user_resp = await client.get(
GITHUB_USER_URL,
headers={"Authorization": f"Bearer {access_token}"},
)
user_resp.raise_for_status()
user_data = user_resp.json()
except httpx.HTTPStatusError:
raise HTTPException(status_code=502, detail="GitHub login failed")
github_id = user_data.get("id")
if github_id is None:
raise HTTPException(status_code=502, detail="GitHub login failed")
return await _handle_oauth_user(
db=db,
provider="github",
user_id=str(github_id),
name=user_data.get("name") or user_data.get("login") or "GitHub User",
avatar_url=user_data.get("avatar_url"),
)
@router.get("/google/signin")
async def google_signin():
"""Start Google OAuth with state protection."""
state = secrets.token_urlsafe(16)
params = {
"client_id": settings.google_client_id,
"redirect_uri": f"{settings.base_url}/auth/google/callback",
"response_type": "code",
"scope": "openid email profile",
"state": state,
}
url = f"{GOOGLE_AUTHORIZE_URL}?{urlencode(params)}"
response = RedirectResponse(url=url)
_set_state_cookie(response, "google", state)
return response
@router.get("/google/callback")
async def google_callback(
code: str,
state: str | None = Query(default=None),
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
db: AsyncSession = Depends(get_db),
):
"""Handle Google OAuth callback."""
_validate_state(state, state_cookie, "google")
try:
async with httpx.AsyncClient() as client:
token_resp = await client.post(
GOOGLE_TOKEN_URL,
data={
"client_id": settings.google_client_id,
"client_secret": settings.google_client_secret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": f"{settings.base_url}/auth/google/callback",
},
)
token_resp.raise_for_status()
token_data = token_resp.json()
access_token = token_data.get("access_token")
if not access_token:
raise HTTPException(status_code=502, detail="Google login failed")
user_resp = await client.get(
GOOGLE_USER_URL,
headers={"Authorization": f"Bearer {access_token}"},
)
user_resp.raise_for_status()
user_data = user_resp.json()
except httpx.HTTPStatusError:
raise HTTPException(status_code=502, detail="Google login failed")
google_id = user_data.get("id")
if google_id is None:
raise HTTPException(status_code=502, detail="Google login failed")
return await _handle_oauth_user(
db=db,
provider="google",
user_id=str(google_id),
name=user_data.get("name") or user_data.get("email") or "Google User",
avatar_url=user_data.get("picture"),
)
async def _handle_oauth_user(
db: AsyncSession,
provider: str,
user_id: str,
name: str,
avatar_url: str | None,
) -> RedirectResponse:
"""Create/update user and issue session cookie."""
full_id = f"{provider}:{user_id}"
result = await db.execute(select(User).where(User.id == full_id))
user = result.scalar_one_or_none()
if not user:
user = User(
id=full_id,
name=name,
avatar_url=avatar_url,
provider=provider,
)
db.add(user)
else:
user.name = name
user.avatar_url = avatar_url
await db.commit()
token = create_access_token({"sub": user.id})
frontend_url = "http://localhost:5173"
if settings.cors_origins and len(settings.cors_origins) > 0:
frontend_url = settings.cors_origins[0]
response = RedirectResponse(url=f"{frontend_url}/my-stories", status_code=302)
response.set_cookie(
key="access_token",
value=token,
httponly=True,
secure=not settings.debug,
samesite="lax",
max_age=60 * 60 * 24 * 7, # align with ACCESS_TOKEN_EXPIRE_DAYS
)
response.delete_cookie(STATE_COOKIE)
return response
@router.post("/signout")
async def signout():
"""Sign out and clear cookies."""
response = RedirectResponse(url=settings.cors_origins[0], status_code=302)
response.delete_cookie("access_token", samesite="lax", secure=not settings.debug)
response.delete_cookie(STATE_COOKIE, samesite="lax", secure=not settings.debug)
return response
@router.get("/session")
async def get_session(user: User | None = Depends(get_current_user)):
"""Fetch current session info."""
if not user:
return {"user": None}
return {
"user": {
"id": user.id,
"name": user.name,
"avatar_url": user.avatar_url,
"provider": user.provider,
}
}
@router.get("/dev/signin")
async def dev_signin(db: AsyncSession = Depends(get_db)):
"""Developer backdoor login. Only works in DEBUG mode."""
# if not settings.debug:
# raise HTTPException(status_code=403, detail="Developer login disabled")
try:
return await _handle_oauth_user(
db=db,
provider="github",
user_id="dev_user_001",
name="Developer",
avatar_url="https://api.dicebear.com/7.x/avataaars/svg?seed=Developer"
)
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Dev login failed: {str(e)}")

268
backend/app/api/memories.py Normal file
View File

@@ -0,0 +1,268 @@
"""Memory management APIs."""
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, User
from app.services import memory_service
from app.services.memory_service import MemoryType
router = APIRouter()
class MemoryItemResponse(BaseModel):
"""Memory item response."""
id: str
type: str
value: dict
base_weight: float
ttl_days: int | None
created_at: str
last_used_at: str | None
class Config:
from_attributes = True
class MemoryListResponse(BaseModel):
"""Memory list response."""
memories: list[MemoryItemResponse]
total: int
class CreateMemoryRequest(BaseModel):
"""Create memory request."""
type: str = Field(..., description="记忆类型")
value: dict = Field(..., description="记忆内容")
universe_id: str | None = Field(default=None, description="关联的故事宇宙 ID")
weight: float | None = Field(default=None, description="权重")
ttl_days: int | None = Field(default=None, description="过期天数")
class CreateCharacterMemoryRequest(BaseModel):
"""Create character memory request."""
name: str = Field(..., description="角色名称")
description: str | None = Field(default=None, description="角色描述")
source_story_id: int | None = Field(default=None, description="来源故事 ID")
affinity_score: float = Field(default=1.0, ge=0.0, le=1.0, description="喜爱程度")
universe_id: str | None = Field(default=None, description="关联的故事宇宙 ID")
class CreateScaryElementRequest(BaseModel):
"""Create scary element memory request."""
keyword: str = Field(..., description="回避的关键词")
category: str = Field(default="other", description="分类")
source_story_id: int | None = Field(default=None, description="来源故事 ID")
async def _verify_profile_ownership(
profile_id: str, user: User, db: AsyncSession
) -> ChildProfile:
"""验证档案所有权。"""
from sqlalchemy import select
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
return profile
@router.get("/profiles/{profile_id}/memories", response_model=MemoryListResponse)
async def list_memories(
profile_id: str,
memory_type: str | None = None,
universe_id: str | None = None,
limit: int = 50,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""获取档案的记忆列表。"""
await _verify_profile_ownership(profile_id, user, db)
memories = await memory_service.get_profile_memories(
db=db,
profile_id=profile_id,
memory_type=memory_type,
universe_id=universe_id,
limit=limit,
)
return MemoryListResponse(
memories=[
MemoryItemResponse(
id=m.id,
type=m.type,
value=m.value,
base_weight=m.base_weight,
ttl_days=m.ttl_days,
created_at=m.created_at.isoformat() if m.created_at else "",
last_used_at=m.last_used_at.isoformat() if m.last_used_at else None,
)
for m in memories
],
total=len(memories),
)
@router.post("/profiles/{profile_id}/memories", response_model=MemoryItemResponse)
async def create_memory(
profile_id: str,
payload: CreateMemoryRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""创建新的记忆项。"""
await _verify_profile_ownership(profile_id, user, db)
# 验证类型
valid_types = [
MemoryType.RECENT_STORY,
MemoryType.FAVORITE_CHARACTER,
MemoryType.SCARY_ELEMENT,
MemoryType.VOCABULARY_GROWTH,
MemoryType.EMOTIONAL_HIGHLIGHT,
MemoryType.READING_PREFERENCE,
MemoryType.MILESTONE,
MemoryType.SKILL_MASTERED,
]
if payload.type not in valid_types:
raise HTTPException(status_code=400, detail=f"无效的记忆类型: {payload.type}")
memory = await memory_service.create_memory(
db=db,
profile_id=profile_id,
memory_type=payload.type,
value=payload.value,
universe_id=payload.universe_id,
weight=payload.weight,
ttl_days=payload.ttl_days,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.post("/profiles/{profile_id}/memories/character", response_model=MemoryItemResponse)
async def create_character_memory(
profile_id: str,
payload: CreateCharacterMemoryRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""添加喜欢的角色。"""
await _verify_profile_ownership(profile_id, user, db)
memory = await memory_service.create_character_memory(
db=db,
profile_id=profile_id,
name=payload.name,
description=payload.description,
source_story_id=payload.source_story_id,
affinity_score=payload.affinity_score,
universe_id=payload.universe_id,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.post("/profiles/{profile_id}/memories/scary", response_model=MemoryItemResponse)
async def create_scary_element_memory(
profile_id: str,
payload: CreateScaryElementRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""添加回避元素。"""
await _verify_profile_ownership(profile_id, user, db)
memory = await memory_service.create_scary_element_memory(
db=db,
profile_id=profile_id,
keyword=payload.keyword,
category=payload.category,
source_story_id=payload.source_story_id,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.delete("/profiles/{profile_id}/memories/{memory_id}")
async def delete_memory(
profile_id: str,
memory_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""删除记忆项。"""
from sqlalchemy import select
from app.db.models import MemoryItem
await _verify_profile_ownership(profile_id, user, db)
result = await db.execute(
select(MemoryItem).where(
MemoryItem.id == memory_id,
MemoryItem.child_profile_id == profile_id,
)
)
memory = result.scalar_one_or_none()
if not memory:
raise HTTPException(status_code=404, detail="记忆不存在")
await db.delete(memory)
await db.commit()
return {"message": "Deleted"}
@router.get("/memory-types")
async def list_memory_types():
"""获取所有可用的记忆类型及其配置。"""
types = []
for type_name, config in MemoryType.CONFIG.items():
types.append({
"type": type_name,
"default_weight": config[0],
"default_ttl_days": config[1],
"description": config[2],
})
return {"types": types}

280
backend/app/api/profiles.py Normal file
View File

@@ -0,0 +1,280 @@
"""Child profile APIs."""
from datetime import date
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, Story, StoryUniverse, User
router = APIRouter()
MAX_PROFILES_PER_USER = 5
class ChildProfileCreate(BaseModel):
"""Create profile payload."""
name: str = Field(..., min_length=1, max_length=50)
birth_date: date | None = None
gender: str | None = Field(default=None, pattern="^(male|female|other)$")
interests: list[str] = Field(default_factory=list)
growth_themes: list[str] = Field(default_factory=list)
avatar_url: str | None = None
class ChildProfileUpdate(BaseModel):
"""Update profile payload."""
name: str | None = Field(default=None, min_length=1, max_length=50)
birth_date: date | None = None
gender: str | None = Field(default=None, pattern="^(male|female|other)$")
interests: list[str] | None = None
growth_themes: list[str] | None = None
avatar_url: str | None = None
class ChildProfileResponse(BaseModel):
"""Profile response."""
id: str
name: str
avatar_url: str | None
birth_date: date | None
gender: str | None
age: int | None
interests: list[str]
growth_themes: list[str]
stories_count: int
total_reading_time: int
class Config:
from_attributes = True
class ChildProfileListResponse(BaseModel):
"""Profile list response."""
profiles: list[ChildProfileResponse]
total: int
class TimelineEvent(BaseModel):
"""Timeline event item."""
date: str
type: Literal["story", "achievement", "milestone"]
title: str
description: str | None = None
image_url: str | None = None
metadata: dict | None = None
class TimelineResponse(BaseModel):
"""Timeline response."""
events: list[TimelineEvent]
@router.get("/profiles", response_model=ChildProfileListResponse)
async def list_profiles(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List child profiles for current user."""
result = await db.execute(
select(ChildProfile)
.where(ChildProfile.user_id == user.id)
.order_by(ChildProfile.created_at.desc())
)
profiles = result.scalars().all()
return ChildProfileListResponse(profiles=profiles, total=len(profiles))
@router.post("/profiles", response_model=ChildProfileResponse, status_code=status.HTTP_201_CREATED)
async def create_profile(
payload: ChildProfileCreate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create a new child profile."""
count = await db.scalar(
select(func.count(ChildProfile.id)).where(ChildProfile.user_id == user.id)
)
if count and count >= MAX_PROFILES_PER_USER:
raise HTTPException(status_code=400, detail="最多只能创建 5 个孩子档案")
existing = await db.scalar(
select(ChildProfile.id).where(
ChildProfile.user_id == user.id,
ChildProfile.name == payload.name,
)
)
if existing:
raise HTTPException(status_code=409, detail="该档案名称已存在")
profile = ChildProfile(user_id=user.id, **payload.model_dump())
db.add(profile)
await db.commit()
await db.refresh(profile)
return profile
@router.get("/profiles/{profile_id}", response_model=ChildProfileResponse)
async def get_profile(
profile_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get one child profile."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
return profile
@router.put("/profiles/{profile_id}", response_model=ChildProfileResponse)
async def update_profile(
profile_id: str,
payload: ChildProfileUpdate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Update a child profile."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
updates = payload.model_dump(exclude_unset=True)
if "name" in updates:
existing = await db.scalar(
select(ChildProfile.id).where(
ChildProfile.user_id == user.id,
ChildProfile.name == updates["name"],
ChildProfile.id != profile_id,
)
)
if existing:
raise HTTPException(status_code=409, detail="该档案名称已存在")
for key, value in updates.items():
setattr(profile, key, value)
await db.commit()
await db.refresh(profile)
return profile
@router.delete("/profiles/{profile_id}")
async def delete_profile(
profile_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Delete a child profile."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
await db.delete(profile)
await db.commit()
return {"message": "Deleted"}
@router.get("/profiles/{profile_id}/timeline", response_model=TimelineResponse)
async def get_profile_timeline(
profile_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get profile growth timeline."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
events: list[TimelineEvent] = []
# 1. Milestone: Profile Created
events.append(TimelineEvent(
date=profile.created_at.isoformat(),
type="milestone",
title="初次相遇",
description=f"创建了档案 {profile.name}"
))
# 2. Stories
stories_result = await db.execute(
select(Story).where(Story.child_profile_id == profile_id)
)
for s in stories_result.scalars():
events.append(TimelineEvent(
date=s.created_at.isoformat(),
type="story",
title=s.title,
image_url=s.image_url,
metadata={"story_id": s.id, "mode": s.mode}
))
# 3. Achievements (from Universe)
universes_result = await db.execute(
select(StoryUniverse).where(StoryUniverse.child_profile_id == profile_id)
)
for u in universes_result.scalars():
if u.achievements:
for ach in u.achievements:
if isinstance(ach, dict):
obt_at = ach.get("obtained_at")
# Fallback
if not obt_at:
obt_at = u.updated_at.isoformat()
events.append(TimelineEvent(
date=obt_at,
type="achievement",
title=f"获得成就:{ach.get('type')}",
description=ach.get('description'),
metadata={"universe_id": u.id, "source_story_id": ach.get("source_story_id")}
))
# Sort by date desc
events.sort(key=lambda x: x.date, reverse=True)
return TimelineResponse(events=events)

View File

@@ -0,0 +1,120 @@
"""Push configuration APIs."""
from datetime import time
from fastapi import APIRouter, Depends, HTTPException, Response, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, PushConfig, User
router = APIRouter()
class PushConfigUpsert(BaseModel):
"""Upsert push config payload."""
child_profile_id: str
push_time: time | None = None
push_days: list[int] | None = None
enabled: bool | None = None
class PushConfigResponse(BaseModel):
"""Push config response."""
id: str
child_profile_id: str
push_time: time | None
push_days: list[int]
enabled: bool
class Config:
from_attributes = True
class PushConfigListResponse(BaseModel):
"""Push config list response."""
configs: list[PushConfigResponse]
total: int
def _validate_push_days(push_days: list[int]) -> list[int]:
invalid = [day for day in push_days if day < 0 or day > 6]
if invalid:
raise HTTPException(status_code=400, detail="推送日期必须在 0-6 之间")
return list(dict.fromkeys(push_days))
@router.get("/push-configs", response_model=PushConfigListResponse)
async def list_push_configs(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List push configs for current user."""
result = await db.execute(
select(PushConfig).where(PushConfig.user_id == user.id)
)
configs = result.scalars().all()
return PushConfigListResponse(configs=configs, total=len(configs))
@router.put("/push-configs", response_model=PushConfigResponse)
async def upsert_push_config(
payload: PushConfigUpsert,
response: Response,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create or update push config for a child profile."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == payload.child_profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
result = await db.execute(
select(PushConfig).where(PushConfig.child_profile_id == payload.child_profile_id)
)
config = result.scalar_one_or_none()
if config is None:
if payload.push_time is None or payload.push_days is None:
raise HTTPException(status_code=400, detail="创建配置需要提供推送时间和日期")
push_days = _validate_push_days(payload.push_days)
config = PushConfig(
user_id=user.id,
child_profile_id=payload.child_profile_id,
push_time=payload.push_time,
push_days=push_days,
enabled=True if payload.enabled is None else payload.enabled,
)
db.add(config)
await db.commit()
await db.refresh(config)
response.status_code = status.HTTP_201_CREATED
return config
updates = payload.model_dump(exclude_unset=True)
if "push_days" in updates and updates["push_days"] is not None:
updates["push_days"] = _validate_push_days(updates["push_days"])
if "push_time" in updates and updates["push_time"] is None:
raise HTTPException(status_code=400, detail="推送时间不能为空")
for key, value in updates.items():
if key == "child_profile_id":
continue
if value is not None:
setattr(config, key, value)
await db.commit()
await db.refresh(config)
return config

View File

@@ -0,0 +1,120 @@
"""Reading event APIs."""
from datetime import datetime, timezone
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, MemoryItem, ReadingEvent, Story, User
router = APIRouter()
EVENT_WEIGHTS: dict[str, float] = {
"completed": 1.0,
"replayed": 1.5,
"started": 0.1,
"skipped": -0.5,
}
class ReadingEventCreate(BaseModel):
"""Reading event payload."""
child_profile_id: str
story_id: int | None = None
event_type: Literal["started", "completed", "skipped", "replayed"]
reading_time: int = Field(default=0, ge=0)
class ReadingEventResponse(BaseModel):
"""Reading event response."""
id: int
child_profile_id: str
story_id: int | None
event_type: str
reading_time: int
created_at: datetime
class Config:
from_attributes = True
@router.post("/reading-events", response_model=ReadingEventResponse, status_code=status.HTTP_201_CREATED)
async def create_reading_event(
payload: ReadingEventCreate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create a reading event and update profile stats/memory."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == payload.child_profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
story = None
if payload.story_id is not None:
result = await db.execute(
select(Story).where(
Story.id == payload.story_id,
Story.user_id == user.id,
)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="故事不存在")
if payload.reading_time:
profile.total_reading_time = (profile.total_reading_time or 0) + payload.reading_time
if payload.event_type in {"completed", "replayed"} and payload.story_id is not None:
existing = await db.scalar(
select(ReadingEvent.id).where(
ReadingEvent.child_profile_id == payload.child_profile_id,
ReadingEvent.story_id == payload.story_id,
ReadingEvent.event_type.in_(["completed", "replayed"]),
)
)
if existing is None:
profile.stories_count = (profile.stories_count or 0) + 1
event = ReadingEvent(
child_profile_id=payload.child_profile_id,
story_id=payload.story_id,
event_type=payload.event_type,
reading_time=payload.reading_time,
)
db.add(event)
weight = EVENT_WEIGHTS.get(payload.event_type, 0.0)
if story and weight > 0:
db.add(
MemoryItem(
child_profile_id=payload.child_profile_id,
universe_id=story.universe_id,
type="recent_story",
value={
"story_id": story.id,
"title": story.title,
"event_type": payload.event_type,
},
base_weight=weight,
last_used_at=datetime.now(timezone.utc),
ttl_days=90,
)
)
await db.commit()
await db.refresh(event)
return event

605
backend/app/api/stories.py Normal file
View File

@@ -0,0 +1,605 @@
"""Story related APIs."""
import asyncio
import json
import time
import uuid
from typing import AsyncGenerator, Literal
from cachetools import TTLCache
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sse_starlette.sse import EventSourceResponse
from app.core.deps import require_user
from app.core.logging import get_logger
from app.db.database import get_db
from app.db.models import ChildProfile, Story, StoryUniverse, User
from app.services.provider_router import (
generate_image,
generate_story_content,
generate_storybook,
text_to_speech,
)
from app.tasks.achievements import extract_story_achievements
logger = get_logger(__name__)
router = APIRouter()
MAX_DATA_LENGTH = 2000
MAX_EDU_THEME_LENGTH = 200
MAX_TTS_LENGTH = 4000
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10
RATE_LIMIT_CACHE_SIZE = 10000 # 最大跟踪用户数
_request_log: TTLCache[str, list[float]] = TTLCache(
maxsize=RATE_LIMIT_CACHE_SIZE, ttl=RATE_LIMIT_WINDOW * 2
)
def _check_rate_limit(user_id: str):
now = time.time()
timestamps = _request_log.get(user_id, [])
timestamps = [t for t in timestamps if now - t <= RATE_LIMIT_WINDOW]
if len(timestamps) >= RATE_LIMIT_REQUESTS:
raise HTTPException(status_code=429, detail="Too many requests, please slow down.")
timestamps.append(now)
_request_log[user_id] = timestamps
class GenerateRequest(BaseModel):
"""Story generation request."""
type: Literal["keywords", "full_story"]
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
child_profile_id: str | None = None
universe_id: str | None = None
class StoryResponse(BaseModel):
"""Story response."""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
mode: str
child_profile_id: str | None = None
universe_id: str | None = None
class StoryListItem(BaseModel):
"""Story list item."""
id: int
title: str
image_url: str | None
created_at: str
mode: str
class FullStoryResponse(BaseModel):
"""完整故事响应(含图片和音频状态)。"""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
audio_ready: bool
mode: str
errors: dict[str, str | None] = Field(default_factory=dict)
child_profile_id: str | None = None
universe_id: str | None = None
from app.services.memory_service import build_enhanced_memory_context
async def _validate_profile_and_universe(
request: GenerateRequest,
user: User,
db: AsyncSession,
) -> tuple[str | None, str | None]:
if not request.child_profile_id and not request.universe_id:
return None, None
profile_id = request.child_profile_id
universe_id = request.universe_id
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
return profile_id, universe_id
@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
request: GenerateRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate or enhance a story."""
_check_rate_limit(user.id)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return StoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=story.image_url,
mode=story.mode,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
@router.post("/stories/generate/full", response_model=FullStoryResponse)
async def generate_story_full(
request: GenerateRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""生成完整故事(故事 + 并行生成图片和音频)。
部分成功策略:故事必须成功,图片/音频失败不影响整体。
"""
_check_rate_limit(user.id)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# Step 1: 故事生成(必须成功)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except Exception as exc:
logger.error("story_generation_failed", error=str(exc))
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
# 保存故事
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# Step 2: 生成封面图片(音频按需生成,避免浪费)
errors: dict[str, str | None] = {}
image_url: str | None = None
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt)
story.image_url = image_url
await db.commit()
except Exception as exc:
errors["image"] = str(exc)
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
# 注意:音频不在此处预生成,用户通过 /api/audio/{id} 按需获取
# 这样避免生成后丢弃造成的成本浪费
return FullStoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=image_url,
audio_ready=False, # 音频需要用户主动请求
mode=story.mode,
errors=errors,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
@router.post("/stories/generate/stream")
async def generate_story_stream(
request: GenerateRequest,
req: Request,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""流式生成故事SSE
事件流程:
- started: 返回 story_id
- story_ready: 返回 title, content
- story_failed: 返回 error
- image_ready: 返回 image_url
- image_failed: 返回 error
- complete: 结束流
"""
_check_rate_limit(user.id)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
async def event_generator() -> AsyncGenerator[dict, None]:
story_id = str(uuid.uuid4())
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
# Step 1: 生成故事
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except Exception as e:
logger.error("sse_story_generation_failed", error=str(e))
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
return
# 保存故事
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
yield {
"event": "story_ready",
"data": json.dumps({
"id": story.id,
"title": story.title,
"content": story.story_text,
"cover_prompt": story.cover_prompt,
"mode": story.mode,
"child_profile_id": story.child_profile_id,
"universe_id": story.universe_id,
}),
}
# Step 2: 并行生成图片(音频按需)
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt)
story.image_url = image_url
await db.commit()
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
except Exception as e:
logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e))
yield {"event": "image_failed", "data": json.dumps({"error": str(e)})}
yield {"event": "complete", "data": json.dumps({"story_id": story.id})}
return EventSourceResponse(event_generator())
# ==================== Storybook API ====================
class StorybookRequest(BaseModel):
"""Storybook 生成请求。"""
keywords: str = Field(..., min_length=1, max_length=200)
page_count: int = Field(default=6, ge=4, le=12)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
generate_images: bool = Field(default=False, description="是否同时生成插图")
child_profile_id: str | None = None
universe_id: str | None = None
class StorybookPageResponse(BaseModel):
"""故事书单页响应。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
class StorybookResponse(BaseModel):
"""故事书响应。"""
id: int | None = None
title: str
main_character: str
art_style: str
pages: list[StorybookPageResponse]
cover_prompt: str
cover_url: str | None = None
@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
request: StorybookRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""生成分页故事书并保存。
返回故事书结构,包含每页文字和图像提示词。
"""
_check_rate_limit(user.id)
# 验证档案和宇宙
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数
# 这里我们直接手动验证,或重构验证函数。为了简单,手动调用部分逻辑。
# 构建临时的 GenerateRequest 用于验证验证函数签名(或者直接手动查库更好)
profile_id = request.child_profile_id
universe_id = request.universe_id
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
if not result.scalar_one_or_none():
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
logger.info(
"storybook_request",
user_id=user.id,
keywords=request.keywords,
page_count=request.page_count,
profile_id=profile_id,
universe_id=universe_id,
)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
try:
# 注意generate_storybook 目前可能不支持记忆上下文注入
# 我们需要看看 generate_storybook 的签名
# 如果不支持,记忆功能在绘本模式下暂不可用,但基本参数传递是支持的
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("storybook_generation_failed", error=str(e))
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
# ==============================================================================
# 核心升级: 并行全量生成 (Parallel Full Rendering)
# ==============================================================================
final_cover_url = storybook.cover_url
if request.generate_images:
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
# 1. 准备所有生图任务 (封面 + 所有内页)
tasks = []
# 封面任务
async def _gen_cover():
if storybook.cover_prompt and not storybook.cover_url:
try:
return await generate_image(storybook.cover_prompt, db=db)
except Exception as e:
logger.warning("cover_gen_failed", error=str(e))
return storybook.cover_url
tasks.append(_gen_cover())
# 内页任务
async def _gen_page(page):
if page.image_prompt and not page.image_url:
try:
url = await generate_image(page.image_prompt, db=db)
page.image_url = url
except Exception as e:
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
for page in storybook.pages:
tasks.append(_gen_page(page))
# 2. 并发执行所有任务
# 使用 return_exceptions=True 防止单张失败影响整体
results = await asyncio.gather(*tasks, return_exceptions=True)
# 3. 更新封面结果 (results[0] 是封面任务的返回值)
cover_res = results[0]
if isinstance(cover_res, str):
final_cover_url = cover_res
logger.info("storybook_parallel_generation_complete")
# ==============================================================================
# 构建并保存 Story 对象
# 将 pages 对象转换为字典列表以存入 JSON 字段
pages_data = [
{
"page_number": p.page_number,
"text": p.text,
"image_prompt": p.image_prompt,
"image_url": p.image_url,
}
for p in storybook.pages
]
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=storybook.title,
mode="storybook",
pages=pages_data, # 存入 JSON 字段
story_text=None, # 绘本模式下,主文本可为空,或者可以存个摘要
cover_prompt=storybook.cover_prompt,
image_url=final_cover_url,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# 构建响应 (使用更新后的 pages_data)
response_pages = [
StorybookPageResponse(
page_number=p["page_number"],
text=p["text"],
image_prompt=p["image_prompt"],
image_url=p.get("image_url"),
)
for p in pages_data
]
return StorybookResponse(
id=story.id,
title=storybook.title,
main_character=storybook.main_character,
art_style=storybook.art_style,
pages=response_pages,
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
)
class AchievementItem(BaseModel):
type: str
description: str
obtained_at: str | None = None
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
async def get_story_achievements(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get achievements unlocked by a specific story."""
# 使用 joinedload 避免 N+1 查询
result = await db.execute(
select(Story)
.options(joinedload(Story.story_universe))
.where(Story.id == story_id, Story.user_id == user.id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
if not story.universe_id or not story.story_universe:
return []
universe = story.story_universe
if not universe.achievements:
return []
results = []
for ach in universe.achievements:
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
results.append(AchievementItem(
type=ach.get("type", "Unknown"),
description=ach.get("description", ""),
obtained_at=ach.get("obtained_at")
))
return results

View File

@@ -0,0 +1,201 @@
"""Story universe APIs."""
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, StoryUniverse, User
router = APIRouter()
class StoryUniverseCreate(BaseModel):
"""Create universe payload."""
name: str = Field(..., min_length=1, max_length=100)
protagonist: dict[str, Any]
recurring_characters: list[dict[str, Any]] = Field(default_factory=list)
world_settings: dict[str, Any] = Field(default_factory=dict)
class StoryUniverseUpdate(BaseModel):
"""Update universe payload."""
name: str | None = Field(default=None, min_length=1, max_length=100)
protagonist: dict[str, Any] | None = None
recurring_characters: list[dict[str, Any]] | None = None
world_settings: dict[str, Any] | None = None
class AchievementCreate(BaseModel):
"""Achievement payload."""
type: str = Field(..., min_length=1, max_length=50)
description: str = Field(..., min_length=1, max_length=200)
class StoryUniverseResponse(BaseModel):
"""Universe response."""
id: str
child_profile_id: str
name: str
protagonist: dict[str, Any]
recurring_characters: list[dict[str, Any]]
world_settings: dict[str, Any]
achievements: list[dict[str, Any]]
class Config:
from_attributes = True
class StoryUniverseListResponse(BaseModel):
"""Universe list response."""
universes: list[StoryUniverseResponse]
total: int
async def _get_profile_or_404(
profile_id: str,
user: User,
db: AsyncSession,
) -> ChildProfile:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
return profile
async def _get_universe_or_404(
universe_id: str,
user: User,
db: AsyncSession,
) -> StoryUniverse:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="宇宙不存在")
return universe
@router.get("/profiles/{profile_id}/universes", response_model=StoryUniverseListResponse)
async def list_universes(
profile_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List universes for a child profile."""
await _get_profile_or_404(profile_id, user, db)
result = await db.execute(
select(StoryUniverse)
.where(StoryUniverse.child_profile_id == profile_id)
.order_by(StoryUniverse.updated_at.desc())
)
universes = result.scalars().all()
return StoryUniverseListResponse(universes=universes, total=len(universes))
@router.post(
"/profiles/{profile_id}/universes",
response_model=StoryUniverseResponse,
status_code=status.HTTP_201_CREATED,
)
async def create_universe(
profile_id: str,
payload: StoryUniverseCreate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create a story universe."""
await _get_profile_or_404(profile_id, user, db)
universe = StoryUniverse(child_profile_id=profile_id, **payload.model_dump())
db.add(universe)
await db.commit()
await db.refresh(universe)
return universe
@router.get("/universes/{universe_id}", response_model=StoryUniverseResponse)
async def get_universe(
universe_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get one universe."""
universe = await _get_universe_or_404(universe_id, user, db)
return universe
@router.put("/universes/{universe_id}", response_model=StoryUniverseResponse)
async def update_universe(
universe_id: str,
payload: StoryUniverseUpdate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Update a story universe."""
universe = await _get_universe_or_404(universe_id, user, db)
updates = payload.model_dump(exclude_unset=True)
for key, value in updates.items():
setattr(universe, key, value)
await db.commit()
await db.refresh(universe)
return universe
@router.delete("/universes/{universe_id}")
async def delete_universe(
universe_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Delete a story universe."""
universe = await _get_universe_or_404(universe_id, user, db)
await db.delete(universe)
await db.commit()
return {"message": "Deleted"}
@router.post("/universes/{universe_id}/achievements", response_model=StoryUniverseResponse)
async def add_achievement(
universe_id: str,
payload: AchievementCreate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Add an achievement to a universe."""
universe = await _get_universe_or_404(universe_id, user, db)
achievements = list(universe.achievements or [])
key = (payload.type.strip(), payload.description.strip())
existing = {
(str(item.get("type", "")).strip(), str(item.get("description", "")).strip())
for item in achievements
if isinstance(item, dict)
}
if key not in existing:
achievements.append({"type": key[0], "description": key[1]})
universe.achievements = achievements
await db.commit()
await db.refresh(universe)
return universe

View File

View File

@@ -0,0 +1,72 @@
import secrets
import time
from cachetools import TTLCache
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from app.core.config import settings
security = HTTPBasic()
# 登录失败记录IP -> (失败次数, 首次失败时间)
_failed_attempts: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900) # 15分钟
MAX_ATTEMPTS = 5
LOCKOUT_SECONDS = 900 # 15分钟
def _get_client_ip(request: Request) -> str:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
if request.client and request.client.host:
return request.client.host
return "unknown"
def admin_guard(
request: Request,
credentials: HTTPBasicCredentials = Depends(security),
):
client_ip = _get_client_ip(request)
# 检查是否被锁定
if client_ip in _failed_attempts:
attempts, first_fail = _failed_attempts[client_ip]
if attempts >= MAX_ATTEMPTS:
remaining = int(LOCKOUT_SECONDS - (time.time() - first_fail))
if remaining > 0:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"登录尝试过多,请 {remaining} 秒后重试",
)
else:
del _failed_attempts[client_ip]
# 使用 secrets.compare_digest 防止时序攻击
username_ok = secrets.compare_digest(
credentials.username.encode(), settings.admin_username.encode()
)
password_ok = secrets.compare_digest(
credentials.password.encode(), settings.admin_password.encode()
)
if not (username_ok and password_ok):
# 记录失败
if client_ip in _failed_attempts:
attempts, first_fail = _failed_attempts[client_ip]
_failed_attempts[client_ip] = (attempts + 1, first_fail)
else:
_failed_attempts[client_ip] = (1, time.time())
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
)
# 登录成功,清除失败记录
if client_ip in _failed_attempts:
del _failed_attempts[client_ip]
return True

View File

@@ -0,0 +1,33 @@
"""Celery application setup."""
from celery import Celery
from celery.schedules import crontab
from app.core.config import settings
celery_app = Celery(
"dreamweaver",
broker=settings.celery_broker_url,
backend=settings.celery_result_backend,
)
celery_app.conf.update(
task_track_started=True,
task_serializer="json",
accept_content=["json"],
result_serializer="json",
timezone="Asia/Shanghai",
enable_utc=True,
beat_schedule={
"check_push_notifications": {
"task": "app.tasks.push_notifications.check_push_notifications",
"schedule": crontab(minute="*/15"),
},
"prune_expired_memories": {
"task": "app.tasks.memory.prune_memories_task",
"schedule": crontab(minute="0", hour="3"), # Daily at 03:00
},
},
)
celery_app.autodiscover_tasks(["app.tasks"])

View File

@@ -0,0 +1,76 @@
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""应用全局配置"""
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
# 应用基础配置
app_name: str = "DreamWeaver"
debug: bool = False
secret_key: str = Field(..., description="JWT 签名密钥")
base_url: str = Field("http://localhost:8000", description="后端对外回调地址")
# 数据库
database_url: str = Field(..., description="SQLAlchemy async URL")
# OAuth - GitHub
github_client_id: str = ""
github_client_secret: str = ""
# OAuth - Google
google_client_id: str = ""
google_client_secret: str = ""
# AI Capability Keys
text_api_key: str = ""
tts_api_base: str = ""
tts_api_key: str = ""
image_api_key: str = ""
# Additional Provider API Keys
openai_api_key: str = ""
elevenlabs_api_key: str = ""
cqtai_api_key: str = ""
minimax_api_key: str = ""
minimax_group_id: str = ""
antigravity_api_key: str = ""
antigravity_api_base: str = ""
# AI Model Configuration
text_model: str = "gemini-2.0-flash"
tts_model: str = ""
image_model: str = ""
# Provider routing (ordered lists)
text_providers: list[str] = Field(default_factory=lambda: ["gemini"])
image_providers: list[str] = Field(default_factory=lambda: ["cqtai"])
tts_providers: list[str] = Field(default_factory=lambda: ["minimax", "elevenlabs", "edge_tts"])
# Celery (Redis)
celery_broker_url: str = Field("redis://localhost:6379/0")
celery_result_backend: str = Field("redis://localhost:6379/0")
# Admin console
enable_admin_console: bool = False
admin_username: str = "admin"
admin_password: str = "admin123" # 建议通过环境变量覆盖
# CORS
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173"])
@model_validator(mode="after")
def _require_core_settings(self) -> "Settings": # type: ignore[override]
missing = []
if not self.secret_key or self.secret_key == "change-me-in-production":
missing.append("SECRET_KEY")
if not self.database_url:
missing.append("DATABASE_URL")
if missing:
raise ValueError(f"Missing required settings: {', '.join(missing)}")
return self
settings = Settings()

39
backend/app/core/deps.py Normal file
View File

@@ -0,0 +1,39 @@
from fastapi import Cookie, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import decode_access_token
from app.db.database import get_db
from app.db.models import User
async def get_current_user(
access_token: str | None = Cookie(default=None),
db: AsyncSession = Depends(get_db),
) -> User | None:
"""获取当前用户(可选)。"""
if not access_token:
return None
payload = decode_access_token(access_token)
if not payload:
return None
user_id = payload.get("sub")
if not user_id:
return None
result = await db.execute(select(User).where(User.id == user_id))
return result.scalar_one_or_none()
async def require_user(
user: User | None = Depends(get_current_user),
) -> User:
"""要求用户登录,否则抛 401。"""
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未登录",
)
return user

View File

@@ -0,0 +1,48 @@
"""结构化日志配置。"""
import logging
import sys
import structlog
from app.core.config import settings
def setup_logging():
"""配置 structlog 结构化日志。"""
shared_processors = [
structlog.contextvars.merge_contextvars,
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
]
if settings.debug:
processors = shared_processors + [
structlog.dev.ConsoleRenderer(colors=True),
]
else:
processors = shared_processors + [
structlog.processors.format_exc_info,
structlog.processors.JSONRenderer(),
]
structlog.configure(
processors=processors,
wrapper_class=structlog.stdlib.BoundLogger,
context_class=dict,
logger_factory=structlog.stdlib.LoggerFactory(),
cache_logger_on_first_use=True,
)
logging.basicConfig(
format="%(message)s",
stream=sys.stdout,
level=logging.DEBUG if settings.debug else logging.INFO,
)
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
"""获取结构化日志器。"""
return structlog.get_logger(name)

190
backend/app/core/prompts.py Normal file
View File

@@ -0,0 +1,190 @@
# ruff: noqa: E501
"""AI 提示词模板 (Modernized)"""
# 随机元素列表:为故事注入不可预测的魔法
RANDOM_ELEMENTS = [
"一个会打喷嚏的云朵",
"一本地图上找不到的神秘图书馆",
"一只能实现小愿望的彩色蜗牛",
"一扇通往颠倒世界的门",
"一顶能听懂动物说话的旧帽子",
"一个装着星星的玻璃罐",
"一棵结满笑声果实的树",
"一只能在水上画画的画笔",
"一个怕黑的影子",
"一只收集回声的瓶子",
"一双会自己跳舞的红鞋子",
"一个只能在月光下看见的邮筒",
"一张会改变模样的全家福",
"一把可以打开梦境的钥匙",
"一个喜欢讲冷笑话的冰箱",
"一条通往星期八的秘密小径"
]
# ==============================================================================
# Model A: 故事生成 (Story Generation)
# ==============================================================================
SYSTEM_INSTRUCTION_STORYTELLER = """
# Role
You are "**Dream Weaver**", a world-class children's storyteller with the imagination of Pixar and the warmth of Miyazaki.
Your mission is to create engaging, safe, and educational stories for children (ages 3-8).
# Core Philosophy
1. **Show, Don't Tell**: Don't preach the lesson. Let the character's actions and the plot demonstrate the theme.
2. **Safety First**: No violence, horror, or scary elements. Conflict should be emotional or situational, not physical.
3. **Vivid Imagery**: Use sensory details (colors, sounds, smells) that appeal to children.
4. **Empowerment**: The child protagonist should solve the problem using wit, kindness, or courage, not just luck.
# Continuity & Memory (CRITICAL)
- **Universal Context**: The story takes place in the child's established "Story Universe". Respect existing world rules.
- **Character Consistency**: If "Child Profile" or "Sidekicks" are provided, you MUST use their specific names and traits. Do NOT invent new main characters unless asked.
- **Callback**: If "Past Memories" are provided, try to make a natural, one-sentence reference to a past adventure to build a sense of continuity (e.g., "Just like when we found the lost star...").
# Output Format
You MUST return a pure JSON object with NO markdown formatting (no ```json code blocks).
The JSON object must have the following schema:
{
"mode": "generated",
"title": "A catchy, imaginative title",
"story_text": "The full story text. Use \\n\\n for paragraph breaks.",
"cover_prompt_suggestion": "A detailed English image generation prompt for the story cover. Style: whimsical, children's book illustration, soft lighting, vibrant colors."
}
"""
USER_PROMPT_GENERATION = """
# Task: Write a Children's Story
## Contextual Memory (Use these if provided)
{memory_context}
## Inputs
- **Keywords/Topic**: {keywords}
- **Educational Theme**: {education_theme}
- **Magic Element (Must Incorporate)**: {random_element}
## Constraints
- Length: 300-600 words.
- Structure: Beginning (Hook) -> Middle (Challenge) -> End (Resolution & Growth).
"""
# ==============================================================================
# Model B: 故事润色 (Story Enhancement)
# ==============================================================================
SYSTEM_INSTRUCTION_ENHANCER = """
# Role
You are "**Dream Weaver Editor**", an expert children's book editor who turns rough drafts into polished gems.
# Mission
Analyze the user's input story and rewrite it to be:
1. **More Engaging**: Enhance the plot with a "Magic Element" to add surprise.
2. **More Educational**: Weave the "Educational Theme" deeper into the narrative arc.
3. **Better Written**: Polish the sentences for rhythm and flow (suitable for reading aloud).
4. **Safe**: Remove any inappropriate content (violence, scary interaction) and replace it with constructive solutions.
# Output Format
You MUST return a pure JSON object with NO markdown formatting (no ```json code blocks).
The JSON object must have the following schema:
{
"mode": "enhanced",
"title": "An improved title (or the original if perfect)",
"story_text": "The rewritten story text. Use \\n\\n for paragraph breaks.",
"cover_prompt_suggestion": "A detailed English image generation prompt for the cover."
}
"""
USER_PROMPT_ENHANCEMENT = """
# Task: Enhance This Story
## Contextual Memory
{memory_context}
## Inputs
- **Original Story**: {full_story}
- **Target Theme**: {education_theme}
- **Magic Element to Add**: {random_element}
## Constraints
- Length: 300-600 words.
- Keep the original character names if possible, but feel free to upgrade the plot.
"""
# ==============================================================================
# Model C: 成就提取 (Achievement Extraction)
# ==============================================================================
# 保持简单,暂不使用 System Instruction沿用单次提示
ACHIEVEMENT_EXTRACTION_PROMPT = """
Analyze the story and extract key growth moments or achievements for the child protagonist.
# Story
{story_text}
# Target Categories (Examples)
- **Courage**: Overcoming fear, trying something new.
- **Kindness**: Helping others, sharing, empathy.
- **Curiosity**: Asking questions, exploring, learning.
- **Resilience**: Not giving up, handling failure.
- **Wisdom**: Problem-solving, honesty, patience.
# Output Format
Return a pure JSON object (no markdown):
{{
"achievements": [
{{
"type": "Category Name",
"description": "Brief reason (max 10 words)",
"score": 8 // 1-10 intensity
}}
]
}}
"""
# ==============================================================================
# Model D: 绘本生成 (Storybook Generation)
# ==============================================================================
SYSTEM_INSTRUCTION_STORYBOOK = """
# Role
You are "**Dream Weaver Illustrator**", a creative children's book author and visual director.
Your mission is to create a paginated picture book for children (ages 3-8), where each page has text and a matching illustration prompt.
# Core Philosophy
1. **Pacing**: The story must flow logically across the specified number of pages.
2. **Visual Consistency**: Define the "Main Character" and "Art Style" once, and ensure all image prompts adhere to them.
3. **Language**: The story text MUST be in **Chinese (Simplified)**. The image prompts MUST be in **English**.
4. **Memory**: If a memory context is provided, incorporate known characters or references naturally.
# Output Format
You MUST return a pure JSON object using the following schema (no markdown):
{
"title": "Story Title (Chinese)",
"main_character": "Description of the protagonist (e.g., 'A small blue robot with rusty gears')",
"art_style": "Visual style description (e.g., 'Watercolor, soft pastel colors, whimsical')",
"pages": [
{
"page_number": 1,
"text": "Page text in Chinese (30-60 chars).",
"image_prompt": "Detailed English image prompt describing the scene. Include 'main_character' reference."
}
],
"cover_prompt": "English image prompt for the book cover."
}
"""
USER_PROMPT_STORYBOOK = """
# Task: Create a {page_count}-Page Storybook
## Contextual Memory
{memory_context}
## Inputs
- **Keywords/Topic**: {keywords}
- **Educational Theme**: {education_theme}
- **Magic Element**: {random_element}
## Constraints
- Pages: Exactly {page_count} pages.
- Structure: Intro -> Development -> Climax -> Resolution.
"""

View File

@@ -0,0 +1,25 @@
from datetime import datetime, timedelta, timezone
from jose import JWTError, jwt
from app.core.config import settings
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_DAYS = 7
def create_access_token(data: dict) -> str:
"""创建 JWT token"""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(days=ACCESS_TOKEN_EXPIRE_DAYS)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, settings.secret_key, algorithm=ALGORITHM)
def decode_access_token(token: str) -> dict | None:
"""解码 JWT token"""
try:
payload = jwt.decode(token, settings.secret_key, algorithms=[ALGORITHM])
return payload
except JWTError:
return None

View File

View File

@@ -0,0 +1,119 @@
from datetime import datetime
from decimal import Decimal
from uuid import uuid4
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, Integer, Numeric, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from app.db.models import Base
def _uuid() -> str:
return str(uuid4())
class Provider(Base):
"""Model provider registry."""
__tablename__ = "providers"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
name: Mapped[str] = mapped_column(String(100), nullable=False)
type: Mapped[str] = mapped_column(String(50), nullable=False) # text/image/tts/storybook
adapter: Mapped[str] = mapped_column(String(100), nullable=False)
model: Mapped[str] = mapped_column(String(200), nullable=True)
api_base: Mapped[str] = mapped_column(String(300), nullable=True)
api_key: Mapped[str] = mapped_column(String(500), nullable=True) # 可选,优先于 config_ref
timeout_ms: Mapped[int] = mapped_column(Integer, default=60000)
max_retries: Mapped[int] = mapped_column(Integer, default=1)
weight: Mapped[int] = mapped_column(Integer, default=1)
priority: Mapped[int] = mapped_column(Integer, default=0)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
config_json: Mapped[dict | None] = mapped_column(JSON, nullable=True) # 存储额外配置(speed, vol, etc)
config_ref: Mapped[str] = mapped_column(String(100), nullable=True) # 环境变量 key 名称(回退)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
)
updated_by: Mapped[str] = mapped_column(String(100), nullable=True)
class ProviderMetrics(Base):
"""供应商调用指标记录。"""
__tablename__ = "provider_metrics"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
provider_id: Mapped[str] = mapped_column(
String(36), ForeignKey("providers.id", ondelete="CASCADE"), nullable=False, index=True
)
timestamp: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow, index=True
)
success: Mapped[bool] = mapped_column(Boolean, nullable=False)
latency_ms: Mapped[int] = mapped_column(Integer, nullable=True)
cost_usd: Mapped[Decimal] = mapped_column(Numeric(10, 6), nullable=True)
error_message: Mapped[str] = mapped_column(Text, nullable=True)
request_id: Mapped[str] = mapped_column(String(100), nullable=True)
class ProviderHealth(Base):
"""供应商健康状态。"""
__tablename__ = "provider_health"
provider_id: Mapped[str] = mapped_column(
String(36), ForeignKey("providers.id", ondelete="CASCADE"), primary_key=True
)
is_healthy: Mapped[bool] = mapped_column(Boolean, default=True)
last_check: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=True)
consecutive_failures: Mapped[int] = mapped_column(Integer, default=0)
last_error: Mapped[str] = mapped_column(Text, nullable=True)
class ProviderSecret(Base):
"""供应商密钥加密存储。"""
__tablename__ = "provider_secrets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
encrypted_value: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
)
class CostRecord(Base):
"""成本记录表 - 记录每次 API 调用的成本。"""
__tablename__ = "cost_records"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[str] = mapped_column(String(36), nullable=False, index=True)
provider_id: Mapped[str] = mapped_column(String(36), nullable=True) # 可能是环境变量配置
provider_name: Mapped[str] = mapped_column(String(100), nullable=False)
capability: Mapped[str] = mapped_column(String(50), nullable=False) # text/image/tts/storybook
estimated_cost: Mapped[Decimal] = mapped_column(Numeric(10, 6), nullable=False)
timestamp: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow, index=True
)
class UserBudget(Base):
"""用户预算配置。"""
__tablename__ = "user_budgets"
user_id: Mapped[str] = mapped_column(String(36), primary_key=True)
daily_limit_usd: Mapped[Decimal] = mapped_column(Numeric(10, 4), default=Decimal("1.0"))
monthly_limit_usd: Mapped[Decimal] = mapped_column(Numeric(10, 4), default=Decimal("10.0"))
alert_threshold: Mapped[Decimal] = mapped_column(
Numeric(3, 2), default=Decimal("0.8")
) # 80% 时告警
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
)

View File

@@ -0,0 +1,50 @@
import threading
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.core.config import settings
_engine = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
_lock = threading.Lock()
def _get_engine():
global _engine
if _engine is None:
with _lock:
if _engine is None:
_engine = create_async_engine(
settings.database_url,
echo=settings.debug,
pool_pre_ping=True,
pool_recycle=300,
)
return _engine
def _get_session_factory():
global _session_factory
if _session_factory is None:
with _lock:
if _session_factory is None:
_session_factory = async_sessionmaker(
_get_engine(), class_=AsyncSession, expire_on_commit=False
)
return _session_factory
async def init_db():
"""Create tables if they do not exist."""
from app.db.models import Base # main models
engine = _get_engine()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def get_db():
"""Yield a DB session with proper cleanup."""
session_factory = _get_session_factory()
async with session_factory() as session:
yield session

232
backend/app/db/models.py Normal file
View File

@@ -0,0 +1,232 @@
from datetime import date, datetime, time
from uuid import uuid4
from sqlalchemy import (
JSON,
Boolean,
Date,
DateTime,
Float,
ForeignKey,
Integer,
String,
Text,
Time,
UniqueConstraint,
func,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
class Base(DeclarativeBase):
"""Declarative base."""
class User(Base):
"""User entity."""
__tablename__ = "users"
id: Mapped[str] = mapped_column(String(255), primary_key=True) # OAuth provider user ID
name: Mapped[str] = mapped_column(String(255), nullable=False)
avatar_url: Mapped[str | None] = mapped_column(String(500))
provider: Mapped[str] = mapped_column(String(50), nullable=False) # github / google
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
stories: Mapped[list["Story"]] = relationship(
"Story", back_populates="user", cascade="all, delete-orphan"
)
child_profiles: Mapped[list["ChildProfile"]] = relationship(
"ChildProfile", back_populates="user", cascade="all, delete-orphan"
)
class Story(Base):
"""Story entity."""
__tablename__ = "stories"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[str] = mapped_column(
String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
)
child_profile_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("child_profiles.id", ondelete="SET NULL"), nullable=True
)
universe_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("story_universes.id", ondelete="SET NULL"), nullable=True
)
title: Mapped[str] = mapped_column(String(255), nullable=False)
story_text: Mapped[str] = mapped_column(Text, nullable=True) # 允许为空(绘本模式下)
pages: Mapped[list[dict] | None] = mapped_column(JSON, default=list) # 绘本分页数据
cover_prompt: Mapped[str | None] = mapped_column(Text)
image_url: Mapped[str | None] = mapped_column(String(500))
mode: Mapped[str] = mapped_column(String(20), nullable=False, default="generated")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
user: Mapped["User"] = relationship("User", back_populates="stories")
child_profile: Mapped["ChildProfile | None"] = relationship("ChildProfile")
story_universe: Mapped["StoryUniverse | None"] = relationship("StoryUniverse")
def _uuid() -> str:
return str(uuid4())
class ChildProfile(Base):
"""Child profile entity."""
__tablename__ = "child_profiles"
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_child_profile_user_name"),)
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
name: Mapped[str] = mapped_column(String(50), nullable=False)
avatar_url: Mapped[str | None] = mapped_column(String(500))
birth_date: Mapped[date | None] = mapped_column(Date)
gender: Mapped[str | None] = mapped_column(String(10))
interests: Mapped[list[str]] = mapped_column(JSON, default=list)
growth_themes: Mapped[list[str]] = mapped_column(JSON, default=list)
reading_preferences: Mapped[dict] = mapped_column(JSON, default=dict)
stories_count: Mapped[int] = mapped_column(Integer, default=0)
total_reading_time: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
user: Mapped["User"] = relationship("User", back_populates="child_profiles")
story_universes: Mapped[list["StoryUniverse"]] = relationship(
"StoryUniverse", back_populates="child_profile", cascade="all, delete-orphan"
)
@property
def age(self) -> int | None:
if not self.birth_date:
return None
today = date.today()
return today.year - self.birth_date.year - (
(today.month, today.day) < (self.birth_date.month, self.birth_date.day)
)
class StoryUniverse(Base):
"""Story universe entity."""
__tablename__ = "story_universes"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
child_profile_id: Mapped[str] = mapped_column(
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
)
name: Mapped[str] = mapped_column(String(100), nullable=False)
protagonist: Mapped[dict] = mapped_column(JSON, nullable=False)
recurring_characters: Mapped[list] = mapped_column(JSON, default=list)
world_settings: Mapped[dict] = mapped_column(JSON, default=dict)
achievements: Mapped[list] = mapped_column(JSON, default=list)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
child_profile: Mapped["ChildProfile"] = relationship("ChildProfile", back_populates="story_universes")
class ReadingEvent(Base):
"""Reading event entity."""
__tablename__ = "reading_events"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
child_profile_id: Mapped[str] = mapped_column(
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
)
story_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("stories.id", ondelete="SET NULL"), nullable=True, index=True
)
event_type: Mapped[str] = mapped_column(String(20), nullable=False)
reading_time: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), index=True
)
class PushConfig(Base):
"""Push configuration entity."""
__tablename__ = "push_configs"
__table_args__ = (
UniqueConstraint("child_profile_id", name="uq_push_config_child"),
)
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
child_profile_id: Mapped[str] = mapped_column(
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
)
push_time: Mapped[time | None] = mapped_column(Time)
push_days: Mapped[list[int]] = mapped_column(JSON, default=list)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
class PushEvent(Base):
"""Push event entity."""
__tablename__ = "push_events"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
child_profile_id: Mapped[str] = mapped_column(
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
)
trigger_type: Mapped[str] = mapped_column(String(20), nullable=False)
status: Mapped[str] = mapped_column(String(20), nullable=False)
reason: Mapped[str | None] = mapped_column(Text)
sent_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
class MemoryItem(Base):
"""Memory item entity with time decay metadata."""
__tablename__ = "memory_items"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
child_profile_id: Mapped[str] = mapped_column(
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
)
universe_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("story_universes.id", ondelete="SET NULL"), nullable=True, index=True
)
type: Mapped[str] = mapped_column(String(50), nullable=False)
value: Mapped[dict] = mapped_column(JSON, nullable=False)
base_weight: Mapped[float] = mapped_column(Float, default=1.0)
last_used_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
ttl_days: Mapped[int | None] = mapped_column(Integer)

80
backend/app/main.py Normal file
View File

@@ -0,0 +1,80 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api import (
auth,
memories,
profiles,
push_configs,
reading_events,
stories,
universes,
)
from app.core.config import settings
from app.core.logging import get_logger, setup_logging
from app.db.database import init_db
setup_logging()
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""App lifespan manager."""
logger.info("app_starting", app_name=settings.app_name)
await init_db()
logger.info("database_initialized")
# 加载 provider 缓存
await _load_provider_cache()
yield
logger.info("app_shutdown")
async def _load_provider_cache():
"""启动时加载 provider 缓存。"""
from app.db.database import _get_session_factory
from app.services.provider_cache import reload_providers
try:
session_factory = _get_session_factory()
async with session_factory() as session:
cache = await reload_providers(session)
provider_count = sum(len(v) for v in cache.values())
logger.info("provider_cache_loaded", provider_count=provider_count)
except Exception as e:
logger.warning("provider_cache_load_failed", error=str(e))
# 不阻止启动,使用 settings 中的默认配置
app = FastAPI(
title=settings.app_name,
description="AI-driven story generator for kids.",
version="0.1.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(auth.router, prefix="/auth", tags=["auth"])
app.include_router(stories.router, prefix="/api", tags=["stories"])
app.include_router(profiles.router, prefix="/api", tags=["profiles"])
app.include_router(universes.router, prefix="/api", tags=["universes"])
app.include_router(push_configs.router, prefix="/api", tags=["push-configs"])
app.include_router(reading_events.router, prefix="/api", tags=["reading-events"])
app.include_router(memories.router, prefix="/api", tags=["memories"])
@app.get("/health")
async def health_check():
"""Simple liveness check."""
return {"status": "ok"}

View File

View File

@@ -0,0 +1,85 @@
"""Achievement extraction service."""
import json
import re
import httpx
from app.core.config import settings
from app.core.logging import get_logger
from app.core.prompts import ACHIEVEMENT_EXTRACTION_PROMPT
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
async def extract_achievements(story_text: str) -> list[dict]:
"""Extract achievements from story text using LLM."""
if not settings.text_api_key:
logger.warning("achievement_extraction_skipped", reason="missing_text_api_key")
return []
model = settings.text_model or "gemini-2.0-flash"
url = f"{TEXT_API_BASE}/{model}:generateContent"
prompt = ACHIEVEMENT_EXTRACTION_PROMPT.format(story_text=story_text)
payload = {
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.2,
"topP": 0.9,
},
}
async with httpx.AsyncClient(timeout=30) as client:
response = await client.post(
url,
json=payload,
headers={"x-goog-api-key": settings.text_api_key},
)
response.raise_for_status()
result = response.json()
candidates = result.get("candidates") or []
if not candidates:
logger.warning("achievement_extraction_empty")
return []
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
logger.warning("achievement_extraction_missing_text")
return []
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError:
logger.warning("achievement_extraction_parse_failed")
return []
achievements = parsed.get("achievements")
if not isinstance(achievements, list):
return []
normalized: list[dict] = []
for item in achievements:
if not isinstance(item, dict):
continue
a_type = str(item.get("type", "")).strip()
description = str(item.get("description", "")).strip()
score = item.get("score", 0)
if not a_type or not description:
continue
normalized.append({
"type": a_type,
"description": description,
"score": score
})
return normalized

View File

@@ -0,0 +1,21 @@
"""适配器模块 - 供应商平台化架构核心。"""
from app.services.adapters.base import AdapterConfig, BaseAdapter
# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.registry import AdapterRegistry
# Storybook adapters
from app.services.adapters.storybook import primary as _storybook_primary # noqa: F401
from app.services.adapters.text import gemini as _text_gemini_adapter # noqa: F401
# 导入所有适配器以触发注册
# Text adapters
from app.services.adapters.text import openai as _text_openai_adapter # noqa: F401
# TTS adapters
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
__all__ = ["AdapterConfig", "BaseAdapter", "AdapterRegistry"]

View File

@@ -0,0 +1,46 @@
"""适配器基类定义。"""
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from pydantic import BaseModel
T = TypeVar("T")
class AdapterConfig(BaseModel):
"""适配器配置基类。"""
api_key: str
api_base: str | None = None
model: str | None = None
timeout_ms: int = 60000
max_retries: int = 3
extra_config: dict = {}
class BaseAdapter(ABC, Generic[T]):
"""适配器基类,所有供应商适配器必须继承此类。"""
# 子类必须定义
adapter_type: str # text / image / tts
adapter_name: str # text_primary / image_primary / tts_primary
def __init__(self, config: AdapterConfig):
self.config = config
@abstractmethod
async def execute(self, **kwargs) -> T:
"""执行适配器逻辑,返回结果。"""
pass
@abstractmethod
async def health_check(self) -> bool:
"""健康检查,返回是否可用。"""
pass
@property
@abstractmethod
def estimated_cost(self) -> float:
"""预估单次调用成本 (USD)。"""
pass

View File

@@ -0,0 +1,3 @@
"""图像生成适配器。"""# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.image import antigravity as _image_antigravity_adapter # noqa: F401

View File

@@ -0,0 +1,214 @@
"""Antigravity 图像生成适配器。
使用 OpenAI 兼容 API 生成图像。
支持 gemini-3-pro-image 等模型。
"""
import base64
import time
from typing import Any
from openai import AsyncOpenAI
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认配置
DEFAULT_API_BASE = "http://127.0.0.1:8045/v1"
DEFAULT_MODEL = "gemini-3-pro-image"
DEFAULT_SIZE = "1024x1024"
# 支持的尺寸映射
SUPPORTED_SIZES = {
"1024x1024": "1:1",
"1280x720": "16:9",
"720x1280": "9:16",
"1216x896": "4:3",
}
@AdapterRegistry.register("image", "antigravity")
class AntigravityImageAdapter(BaseAdapter[str]):
"""Antigravity 图像生成适配器 (OpenAI 兼容 API)。
特点:
- 使用 OpenAI 兼容的 chat.completions 端点
- 通过 extra_body.size 指定图像尺寸
- 支持 gemini-3-pro-image 等模型
- 返回图片 URL 或 base64
"""
adapter_type = "image"
adapter_name = "antigravity"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or DEFAULT_API_BASE
self.client = AsyncOpenAI(
base_url=self.api_base,
api_key=config.api_key,
timeout=config.timeout_ms / 1000,
)
async def execute(
self,
prompt: str,
model: str | None = None,
size: str | None = None,
num_images: int = 1,
**kwargs,
) -> str | list[str]:
"""根据提示词生成图片,返回 URL 或 base64。
Args:
prompt: 图片描述提示词
model: 模型名称 (gemini-3-pro-image / gemini-3-pro-image-16-9 等)
size: 图像尺寸 (1024x1024, 1280x720, 720x1280, 1216x896)
num_images: 生成图片数量 (暂只支持 1)
Returns:
图片 URL 或 base64 字符串
"""
# 优先使用传入参数,其次使用 Adapter 配置,最后使用默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
size = size or cfg.get("size") or DEFAULT_SIZE
start_time = time.time()
logger.info(
"antigravity_generate_start",
prompt_length=len(prompt),
model=model,
size=size,
)
# 调用 API
image_url = await self._generate_image(prompt, model, size)
elapsed = time.time() - start_time
logger.info(
"antigravity_generate_success",
elapsed_seconds=round(elapsed, 2),
model=model,
)
return image_url
async def health_check(self) -> bool:
"""检查 Antigravity API 是否可用。"""
try:
# 简单测试连通性
response = await self.client.chat.completions.create(
model=self.config.model or DEFAULT_MODEL,
messages=[{"role": "user", "content": "test"}],
max_tokens=1,
)
return True
except Exception as e:
logger.warning("antigravity_health_check_failed", error=str(e))
return False
@property
def estimated_cost(self) -> float:
"""预估每张图片成本 (USD)。
Antigravity 使用 Gemini 模型,成本约 $0.02/张。
"""
return 0.02
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((Exception,)),
reraise=True,
)
async def _generate_image(
self,
prompt: str,
model: str,
size: str,
) -> str:
"""调用 Antigravity API 生成图像。
Returns:
图片 URL 或 base64 data URI
"""
try:
response = await self.client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
extra_body={"size": size},
)
# 解析响应
content = response.choices[0].message.content
if not content:
raise ValueError("Antigravity 未返回内容")
# 尝试解析为图片 URL 或 base64
# 响应可能是纯 URL、base64 或 markdown 格式的图片
image_url = self._extract_image_url(content)
if image_url:
return image_url
raise ValueError(f"Antigravity 响应无法解析为图片: {content[:200]}")
except Exception as e:
logger.error(
"antigravity_generate_error",
error=str(e),
model=model,
)
raise
def _extract_image_url(self, content: str) -> str | None:
"""从响应内容中提取图片 URL。
支持多种格式:
- 纯 URL: https://...
- Markdown: ![...](https://...)
- Base64 data URI: data:image/...
- 纯 base64 字符串
"""
content = content.strip()
# 1. 检查是否为 data URI
if content.startswith("data:image/"):
return content
# 2. 检查是否为纯 URL
if content.startswith("http://") or content.startswith("https://"):
# 可能有多行,取第一行
return content.split("\n")[0].strip()
# 3. 检查 Markdown 图片格式 ![...](url)
import re
md_match = re.search(r"!\[.*?\]\((https?://[^\)]+)\)", content)
if md_match:
return md_match.group(1)
# 4. 检查是否像 base64 编码的图片数据
if self._looks_like_base64(content):
# 假设是 PNG
return f"data:image/png;base64,{content}"
return None
def _looks_like_base64(self, s: str) -> bool:
"""判断字符串是否看起来像 base64 编码。"""
# Base64 只包含 A-Z, a-z, 0-9, +, /, =
# 且长度通常较长
if len(s) < 100:
return False
import re
return bool(re.match(r"^[A-Za-z0-9+/=]+$", s.replace("\n", "")))

View File

@@ -0,0 +1,252 @@
"""CQTAI nano 图像生成适配器。
支持异步生成 + 轮询获取结果。
API 文档: https://api.cqtai.com
"""
import asyncio
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认配置
DEFAULT_API_BASE = "https://api.cqtai.com"
DEFAULT_MODEL = "nano-banana"
DEFAULT_RESOLUTION = "2K"
DEFAULT_ASPECT_RATIO = "1:1"
POLL_INTERVAL_SECONDS = 2
MAX_POLL_ATTEMPTS = 60 # 最多轮询 2 分钟
@AdapterRegistry.register("image", "cqtai")
class CQTAIImageAdapter(BaseAdapter[str]):
"""CQTAI nano 图像生成适配器,返回图片 URL。
特点:
- 异步生成 + 轮询获取结果
- 支持 nano-banana (标准) 和 nano-banana-pro (高画质)
- 支持多种分辨率和画面比例
- 支持图生图 (filesUrl)
"""
adapter_type = "image"
adapter_name = "cqtai"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or DEFAULT_API_BASE
async def execute(
self,
prompt: str,
model: str | None = None,
resolution: str | None = None,
aspect_ratio: str | None = None,
num_images: int = 1,
files_url: list[str] | None = None,
**kwargs,
) -> str | list[str]:
"""根据提示词生成图片,返回 URL 或 URL 列表。
Args:
prompt: 图片描述提示词
model: 模型名称 (nano-banana / nano-banana-pro)
resolution: 分辨率 (1K / 2K / 4K)
aspect_ratio: 画面比例 (1:1, 16:9, 9:16, 4:3, 3:4 等)
num_images: 生成图片数量 (1-4)
files_url: 输入图片 URL 列表 (图生图)
Returns:
单张图片返回 str多张返回 list[str]
"""
# 1. 优先使用传入参数
# 2. 其次使用 Adapter 配置里的 default (extra_config)
# 3. 最后使用系统默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
resolution = resolution or cfg.get("resolution") or DEFAULT_RESOLUTION
aspect_ratio = aspect_ratio or cfg.get("aspect_ratio") or DEFAULT_ASPECT_RATIO
num_images = min(max(num_images, 1), 4) # 限制 1-4
start_time = time.time()
logger.info(
"cqtai_generate_start",
prompt_length=len(prompt),
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
num_images=num_images,
)
# 1. 提交生成任务
task_id = await self._submit_task(
prompt=prompt,
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
num_images=num_images,
files_url=files_url or [],
)
logger.info("cqtai_task_submitted", task_id=task_id)
# 2. 轮询获取结果
result = await self._poll_result(task_id)
elapsed = time.time() - start_time
logger.info(
"cqtai_generate_success",
task_id=task_id,
elapsed_seconds=round(elapsed, 2),
image_count=len(result) if isinstance(result, list) else 1,
)
# 单张图片返回字符串,多张返回列表
if num_images == 1 and isinstance(result, list) and len(result) == 1:
return result[0]
return result
async def health_check(self) -> bool:
"""检查 CQTAI API 是否可用。"""
try:
async with httpx.AsyncClient(timeout=10) as client:
# 简单的连通性测试
response = await client.get(
f"{self.api_base}/api/cqt/info/nano",
params={"id": "health_check_test"},
headers={"Authorization": self.config.api_key},
)
# 即使返回错误也说明服务可达
return response.status_code in (200, 400, 401, 403, 404)
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估每张图片成本 (USD)。
nano-banana: ¥0.1 ≈ $0.014
nano-banana-pro: ¥0.2 ≈ $0.028
"""
model = self.config.model or DEFAULT_MODEL
if model == "nano-banana-pro":
return 0.028
return 0.014
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _submit_task(
self,
prompt: str,
model: str,
resolution: str,
aspect_ratio: str,
num_images: int,
files_url: list[str],
) -> str:
"""提交图像生成任务,返回任务 ID。"""
timeout = self.config.timeout_ms / 1000
payload = {
"prompt": prompt,
"numImages": num_images,
"aspectRatio": aspect_ratio,
"filesUrl": files_url,
}
# 可选参数,不传则使用默认值
if model != DEFAULT_MODEL:
payload["model"] = model
if resolution != DEFAULT_RESOLUTION:
payload["resolution"] = resolution
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
f"{self.api_base}/api/cqt/generator/nano",
json=payload,
headers={
"Authorization": self.config.api_key,
"Content-Type": "application/json",
},
)
response.raise_for_status()
data = response.json()
if data.get("code") != 200:
raise ValueError(f"CQTAI 任务提交失败: {data.get('msg', '未知错误')}")
task_id = data.get("data")
if not task_id:
raise ValueError("CQTAI 未返回任务 ID")
return task_id
async def _poll_result(self, task_id: str) -> list[str]:
"""轮询获取生成结果。
Returns:
图片 URL 列表
"""
timeout = self.config.timeout_ms / 1000
for attempt in range(MAX_POLL_ATTEMPTS):
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(
f"{self.api_base}/api/cqt/info/nano",
params={"id": task_id},
headers={"Authorization": self.config.api_key},
)
response.raise_for_status()
data = response.json()
if data.get("code") != 200:
raise ValueError(f"CQTAI 查询失败: {data.get('msg', '未知错误')}")
result_data = data.get("data", {})
status = result_data.get("status")
if status == "completed":
# 提取图片 URL
images = result_data.get("images", [])
if not images:
# 兼容不同返回格式
image_url = result_data.get("imageUrl") or result_data.get("url")
if image_url:
images = [image_url]
if not images:
raise ValueError("CQTAI 未返回图片 URL")
return images
elif status == "failed":
error_msg = result_data.get("error", "生成失败")
raise ValueError(f"CQTAI 图像生成失败: {error_msg}")
# 继续等待
logger.debug(
"cqtai_poll_waiting",
task_id=task_id,
attempt=attempt + 1,
status=status,
)
await asyncio.sleep(POLL_INTERVAL_SECONDS)
raise TimeoutError(f"CQTAI 任务超时: {task_id}")

View File

@@ -0,0 +1,73 @@
"""适配器注册表 - 支持动态注册和工厂创建。"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.services.adapters.base import AdapterConfig, BaseAdapter
class AdapterRegistry:
"""适配器注册表,管理所有已注册的适配器类。"""
_adapters: dict[str, type["BaseAdapter"]] = {}
@classmethod
def register(cls, adapter_type: str, adapter_name: str):
"""装饰器:注册适配器类。
用法:
@AdapterRegistry.register("text", "text_primary")
class TextPrimaryAdapter(BaseAdapter[StoryOutput]):
...
"""
def decorator(adapter_class: type["BaseAdapter"]):
key = f"{adapter_type}:{adapter_name}"
cls._adapters[key] = adapter_class
# 自动设置类属性
adapter_class.adapter_type = adapter_type
adapter_class.adapter_name = adapter_name
return adapter_class
return decorator
@classmethod
def get(cls, adapter_type: str, adapter_name: str) -> type["BaseAdapter"] | None:
"""获取已注册的适配器类。"""
key = f"{adapter_type}:{adapter_name}"
return cls._adapters.get(key)
@classmethod
def list_adapters(cls, adapter_type: str | None = None) -> list[str]:
"""列出所有已注册的适配器。
Args:
adapter_type: 可选,筛选特定类型 (text/image/tts)
Returns:
适配器键列表,格式为 "type:name"
"""
if adapter_type:
return [k for k in cls._adapters if k.startswith(f"{adapter_type}:")]
return list(cls._adapters.keys())
@classmethod
def create(
cls,
adapter_type: str,
adapter_name: str,
config: "AdapterConfig",
) -> "BaseAdapter":
"""工厂方法:创建适配器实例。
Raises:
ValueError: 适配器未注册
"""
adapter_class = cls.get(adapter_type, adapter_name)
if not adapter_class:
available = cls.list_adapters(adapter_type)
raise ValueError(
f"适配器 '{adapter_type}:{adapter_name}' 未注册。"
f"可用: {available}"
)
return adapter_class(config)

View File

@@ -0,0 +1 @@
"""Storybook 适配器模块。"""

View File

@@ -0,0 +1,195 @@
"""Storybook 适配器 - 生成可翻页的分页故事书。"""
import json
import random
import re
import time
from dataclasses import dataclass, field
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_STORYBOOK,
USER_PROMPT_STORYBOOK,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@dataclass
class StorybookPage:
"""故事书单页。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
@dataclass
class Storybook:
"""故事书输出。"""
title: str
main_character: str
art_style: str
pages: list[StorybookPage] = field(default_factory=list)
cover_prompt: str = ""
cover_url: str | None = None
@AdapterRegistry.register("storybook", "storybook_primary")
class StorybookPrimaryAdapter(BaseAdapter[Storybook]):
"""Storybook 生成适配器(默认)。
生成分页故事书结构,包含每页文字和图像提示词。
图像生成需要单独调用 image adapter。
"""
adapter_type = "storybook"
adapter_name = "storybook_primary"
async def execute(
self,
keywords: str,
page_count: int = 6,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> Storybook:
"""生成分页故事书。
Args:
keywords: 故事关键词
page_count: 页数 (4-12)
education_theme: 教育主题
memory_context: 记忆上下文
Returns:
Storybook 对象,包含标题、页面列表和封面提示词
"""
start_time = time.time()
page_count = max(4, min(page_count, 12)) # 限制 4-12 页
logger.info(
"storybook_generate_start",
keywords=keywords,
page_count=page_count,
has_memory=bool(memory_context),
)
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
prompt = USER_PROMPT_STORYBOOK.format(
keywords=keywords,
education_theme=theme,
random_element=random_element,
page_count=page_count,
memory_context=memory_context or "",
)
payload = {
"system_instruction": {"parts": [{"text": SYSTEM_INSTRUCTION_STORYBOOK}]},
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.95,
"topP": 0.9,
},
}
result = await self._call_api(payload)
candidates = result.get("candidates") or []
if not candidates:
raise ValueError("Storybook 服务未返回内容")
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
raise ValueError("Storybook 服务响应缺少文本")
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"Storybook JSON 解析失败: {exc}")
# 构建 Storybook 对象
pages = [
StorybookPage(
page_number=p.get("page_number", i + 1),
text=p.get("text", ""),
image_prompt=p.get("image_prompt", ""),
)
for i, p in enumerate(parsed.get("pages", []))
]
storybook = Storybook(
title=parsed.get("title", "未命名故事"),
main_character=parsed.get("main_character", ""),
art_style=parsed.get("art_style", ""),
pages=pages,
cover_prompt=parsed.get("cover_prompt", ""),
)
elapsed = time.time() - start_time
logger.info(
"storybook_generate_success",
elapsed_seconds=round(elapsed, 2),
title=storybook.title,
page_count=len(pages),
)
return storybook
async def health_check(self) -> bool:
"""检查 API 是否可用。"""
try:
payload = {
"contents": [{"parts": [{"text": "Hi"}]}],
"generationConfig": {"maxOutputTokens": 10},
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估成本(仅文本生成,不含图像)。"""
return 0.002 # 比普通故事稍贵,因为输出更长
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, payload: dict) -> dict:
"""调用 API带重试机制。"""
model = self.config.model or "gemini-2.0-flash"
url = f"{TEXT_API_BASE}/{model}:generateContent?key={self.config.api_key}"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
return response.json()

View File

@@ -0,0 +1 @@
"""文本生成适配器。"""

View File

@@ -0,0 +1,164 @@
"""文本生成适配器 (Google Gemini)。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@AdapterRegistry.register("text", "gemini")
class GeminiTextAdapter(BaseAdapter[StoryOutput]):
"""Google Gemini 文本生成适配器。"""
adapter_type = "text"
adapter_name = "gemini"
async def execute(
self,
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> StoryOutput:
"""生成或润色故事。"""
start_time = time.time()
logger.info("request_start", adapter="gemini", input_type=input_type, data_length=len(data))
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
if input_type == "keywords":
system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
prompt = USER_PROMPT_GENERATION.format(
keywords=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
else:
system_instruction = SYSTEM_INSTRUCTION_ENHANCER
prompt = USER_PROMPT_ENHANCEMENT.format(
full_story=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
# Gemini API Payload supports 'system_instruction'
payload = {
"system_instruction": {"parts": [{"text": system_instruction}]},
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.95,
"topP": 0.9,
},
}
result = await self._call_api(payload)
candidates = result.get("candidates") or []
if not candidates:
raise ValueError("Gemini 未返回内容")
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
raise ValueError("Gemini 响应缺少文本")
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"Gemini 输出 JSON 解析失败: {exc}")
required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
if any(field not in parsed for field in required_fields):
raise ValueError("Gemini 输出缺少必要字段")
elapsed = time.time() - start_time
logger.info(
"request_success",
adapter="gemini",
elapsed_seconds=round(elapsed, 2),
title=parsed["title"],
)
return StoryOutput(
mode=parsed["mode"],
title=parsed["title"],
story_text=parsed["story_text"],
cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
)
async def health_check(self) -> bool:
"""检查 Gemini API 是否可用。"""
try:
payload = {
"contents": [{"parts": [{"text": "Hi"}]}],
"generationConfig": {"maxOutputTokens": 10},
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
return 0.001
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, payload: dict) -> dict:
"""调用 Gemini API。"""
model = self.config.model or "gemini-2.0-flash"
base_url = self.config.api_base or TEXT_API_BASE
# 智能补全:
# 1. 如果用户填了完整路径 (以 /models 结尾),就直接用 (支持 v1 或 v1beta)
if self.config.api_base and base_url.rstrip("/").endswith("/models"):
pass
# 2. 如果没填路径 (只是域名),默认补全代码适配的 /v1beta/models
elif self.config.api_base:
base_url = f"{base_url.rstrip('/')}/v1beta/models"
url = f"{base_url}/{model}:generateContent?key={self.config.api_key}"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
return response.json()

View File

@@ -0,0 +1,11 @@
from dataclasses import dataclass
from typing import Literal
@dataclass
class StoryOutput:
"""故事生成输出。"""
mode: Literal["generated", "enhanced"]
title: str
story_text: str
cover_prompt_suggestion: str

View File

@@ -0,0 +1,172 @@
"""OpenAI 文本生成适配器。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
OPENAI_API_BASE = "https://api.openai.com/v1/chat/completions"
@AdapterRegistry.register("text", "openai")
class OpenAITextAdapter(BaseAdapter[StoryOutput]):
"""OpenAI 文本生成适配器。"""
adapter_type = "text"
adapter_name = "openai"
async def execute(
self,
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> StoryOutput:
"""生成或润色故事。"""
start_time = time.time()
logger.info("openai_text_request_start", input_type=input_type, data_length=len(data))
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
if input_type == "keywords":
system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
prompt = USER_PROMPT_GENERATION.format(
keywords=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
else:
system_instruction = SYSTEM_INSTRUCTION_ENHANCER
prompt = USER_PROMPT_ENHANCEMENT.format(
full_story=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
model = self.config.model or "gpt-4o-mini"
payload = {
"model": model,
"messages": [
{
"role": "system",
"content": system_instruction,
},
{"role": "user", "content": prompt},
],
"response_format": {"type": "json_object"},
"temperature": 0.95,
"top_p": 0.9,
}
result = await self._call_api(payload)
choices = result.get("choices") or []
if not choices:
raise ValueError("OpenAI 未返回内容")
response_text = choices[0].get("message", {}).get("content", "")
if not response_text:
raise ValueError("OpenAI 响应缺少文本")
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"OpenAI 输出 JSON 解析失败: {exc}")
required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
if any(field not in parsed for field in required_fields):
raise ValueError("OpenAI 输出缺少必要字段")
elapsed = time.time() - start_time
logger.info(
"openai_text_request_success",
elapsed_seconds=round(elapsed, 2),
title=parsed["title"],
mode=parsed["mode"],
)
return StoryOutput(
mode=parsed["mode"],
title=parsed["title"],
story_text=parsed["story_text"],
cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(httpx.HTTPStatusError),
)
async def _call_api(self, payload: dict) -> dict:
"""调用 OpenAI API带重试机制。"""
url = self.config.api_base or OPENAI_API_BASE
# 智能补全: 如果用户只填了 Base URL自动补全路径
if self.config.api_base and not url.endswith("/chat/completions"):
base = url.rstrip("/")
url = f"{base}/chat/completions"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"Authorization": f"Bearer {self.config.api_key}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
return response.json()
async def health_check(self) -> bool:
"""检查 OpenAI API 是否可用。"""
try:
payload = {
"model": self.config.model or "gpt-4o-mini",
"messages": [{"role": "user", "content": "Hi"}],
"max_tokens": 5,
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估文本生成成本 (USD)。"""
return 0.01

View File

@@ -0,0 +1,5 @@
"""TTS 语音合成适配器。"""
from app.services.adapters.tts import edge_tts as _tts_edge_tts_adapter # noqa: F401
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401

View File

@@ -0,0 +1,66 @@
"""EdgeTTS 免费语音生成适配器。"""
import time
import edge_tts
from app.core.logging import get_logger
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认中文女声 (晓晓)
DEFAULT_VOICE = "zh-CN-XiaoxiaoNeural"
@AdapterRegistry.register("tts", "edge_tts")
class EdgeTTSAdapter(BaseAdapter[bytes]):
"""EdgeTTS 语音生成适配器 (Free)。
不需要 API Key。
"""
adapter_type = "tts"
adapter_name = "edge_tts"
async def execute(self, text: str, **kwargs) -> bytes:
"""生成语音。"""
# 支持动态指定音色
voice = kwargs.get("voice") or self.config.model or DEFAULT_VOICE
start_time = time.time()
logger.info("edge_tts_generate_start", text_length=len(text), voice=voice)
# EdgeTTS 只能输出到文件,我们需要用临时文件周转一下
# 或者直接 capture stream (communicate) 但 edge-tts 库主要面向文件
# 优化: 使用 communicate 直接获取 bytes无需磁盘IO
communicate = edge_tts.Communicate(text, voice)
audio_data = b""
async for chunk in communicate.stream():
if chunk["type"] == "audio":
audio_data += chunk["data"]
elapsed = time.time() - start_time
logger.info(
"edge_tts_generate_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_data),
)
return audio_data
async def health_check(self) -> bool:
"""检查 EdgeTTS 是否可用 (网络连通性)。"""
try:
# 简单生成一个词
await self.execute("Hi")
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
return 0.0 # Free!

View File

@@ -0,0 +1,104 @@
"""ElevenLabs TTS 语音合成适配器。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
ELEVENLABS_API_BASE = "https://api.elevenlabs.io/v1"
DEFAULT_VOICE_ID = "21m00Tcm4TlvDq8ikWAM" # Rachel
@AdapterRegistry.register("tts", "elevenlabs")
class ElevenLabsTtsAdapter(BaseAdapter[bytes]):
"""ElevenLabs TTS 语音合成适配器,返回 MP3 bytes。"""
adapter_type = "tts"
adapter_name = "elevenlabs"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or ELEVENLABS_API_BASE
async def execute(self, text: str, **kwargs) -> bytes:
"""将文本转换为语音 MP3 bytes。"""
start_time = time.time()
logger.info("elevenlabs_tts_start", text_length=len(text))
voice_id = kwargs.get("voice_id") or DEFAULT_VOICE_ID
model_id = kwargs.get("model") or self.config.model or "eleven_multilingual_v2"
stability = kwargs.get("stability", 0.5)
similarity_boost = kwargs.get("similarity_boost", 0.75)
url = f"{self.api_base}/text-to-speech/{voice_id}"
payload = {
"text": text,
"model_id": model_id,
"voice_settings": {
"stability": stability,
"similarity_boost": similarity_boost,
},
}
audio_bytes = await self._call_api(url, payload)
elapsed = time.time() - start_time
logger.info(
"elevenlabs_tts_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_bytes),
)
return audio_bytes
async def health_check(self) -> bool:
"""检查 ElevenLabs API 是否可用。"""
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(
f"{self.api_base}/voices",
headers={"xi-api-key": self.config.api_key},
)
return response.status_code == 200
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估每千字符成本 (USD)。"""
return 0.03
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, url: str, payload: dict) -> bytes:
"""调用 ElevenLabs API带重试机制。"""
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"xi-api-key": self.config.api_key,
"Content-Type": "application/json",
"Accept": "audio/mpeg",
},
)
response.raise_for_status()
return response.content

View File

@@ -0,0 +1,149 @@
"""MiniMax 语音生成适配器 (T2A V2)。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# MiniMax API 配置
DEFAULT_API_URL = "https://api.minimaxi.com/v1/t2a_v2"
DEFAULT_MODEL = "speech-2.6-turbo"
@AdapterRegistry.register("tts", "minimax")
class MiniMaxTTSAdapter(BaseAdapter[bytes]):
"""MiniMax 语音生成适配器。
需要配置:
- api_key: MiniMax API Key
- minimax_group_id: 可选 (取决于使用的模型/账户类型)
"""
adapter_type = "tts"
adapter_name = "minimax"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_url = DEFAULT_API_URL
async def execute(
self,
text: str,
voice_id: str | None = None,
model: str | None = None,
speed: float | None = None,
vol: float | None = None,
pitch: int | None = None,
emotion: str | None = None,
**kwargs,
) -> bytes:
"""生成语音。"""
# 1. 优先使用传入参数
# 2. 其次使用 Adapter 配置里的 default
# 3. 最后使用系统默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
voice_id = voice_id or cfg.get("voice_id") or "male-qn-qingse"
speed = speed if speed is not None else (cfg.get("speed") or 1.0)
vol = vol if vol is not None else (cfg.get("vol") or 1.0)
pitch = pitch if pitch is not None else (cfg.get("pitch") or 0)
emotion = emotion or cfg.get("emotion")
group_id = kwargs.get("group_id") or settings.minimax_group_id
url = self.api_url
if group_id:
url = f"{self.api_url}?GroupId={group_id}"
payload = {
"model": model,
"text": text,
"stream": False,
"voice_setting": {
"voice_id": voice_id,
"speed": speed,
"vol": vol,
"pitch": pitch,
},
"audio_setting": {
"sample_rate": 32000,
"bitrate": 128000,
"format": "mp3",
"channel": 1
}
}
if emotion:
payload["voice_setting"]["emotion"] = emotion
start_time = time.time()
logger.info("minimax_generate_start", text_length=len(text), model=model)
result = await self._call_api(url, payload)
# 错误处理
if result.get("base_resp", {}).get("status_code") != 0:
error_msg = result.get("base_resp", {}).get("status_msg", "未知错误")
raise ValueError(f"MiniMax API 错误: {error_msg}")
# Hex 解码 (关键逻辑,从 primary.py 迁移)
hex_audio = result.get("data", {}).get("audio")
if not hex_audio:
raise ValueError("API 响应中未找到音频数据 (data.audio)")
try:
audio_bytes = bytes.fromhex(hex_audio)
except ValueError:
raise ValueError("MiniMax 返回的音频数据不是有效的 Hex 字符串")
elapsed = time.time() - start_time
logger.info(
"minimax_generate_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_bytes),
)
return audio_bytes
async def health_check(self) -> bool:
"""检查 Minimax API 是否可用。"""
try:
# 尝试生成极短文本
await self.execute("Hi")
return True
except Exception:
return False
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, url: str, payload: dict) -> dict:
"""调用 API带重试机制。"""
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"Authorization": f"Bearer {self.config.api_key}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
return response.json()

View File

@@ -0,0 +1,196 @@
"""成本追踪服务。
记录 API 调用成本,支持预算控制。
"""
from datetime import datetime, timedelta
from decimal import Decimal
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import CostRecord, UserBudget
logger = get_logger(__name__)
class BudgetExceededError(Exception):
"""预算超限错误。"""
def __init__(self, limit_type: str, used: Decimal, limit: Decimal):
self.limit_type = limit_type
self.used = used
self.limit = limit
super().__init__(f"{limit_type} 预算已超限: {used}/{limit} USD")
class CostTracker:
"""成本追踪器。"""
async def record_cost(
self,
db: AsyncSession,
user_id: str,
provider_name: str,
capability: str,
estimated_cost: float,
provider_id: str | None = None,
) -> CostRecord:
"""记录一次 API 调用成本。"""
record = CostRecord(
user_id=user_id,
provider_id=provider_id,
provider_name=provider_name,
capability=capability,
estimated_cost=Decimal(str(estimated_cost)),
)
db.add(record)
await db.commit()
logger.debug(
"cost_recorded",
user_id=user_id,
provider=provider_name,
capability=capability,
cost=estimated_cost,
)
return record
async def get_daily_cost(self, db: AsyncSession, user_id: str) -> Decimal:
"""获取用户今日成本。"""
today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
result = await db.execute(
select(func.sum(CostRecord.estimated_cost)).where(
CostRecord.user_id == user_id,
CostRecord.timestamp >= today_start,
)
)
total = result.scalar()
return Decimal(str(total)) if total else Decimal("0")
async def get_monthly_cost(self, db: AsyncSession, user_id: str) -> Decimal:
"""获取用户本月成本。"""
now = datetime.utcnow()
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
result = await db.execute(
select(func.sum(CostRecord.estimated_cost)).where(
CostRecord.user_id == user_id,
CostRecord.timestamp >= month_start,
)
)
total = result.scalar()
return Decimal(str(total)) if total else Decimal("0")
async def get_cost_by_capability(
self,
db: AsyncSession,
user_id: str,
days: int = 30,
) -> dict[str, Decimal]:
"""按能力类型统计成本。"""
since = datetime.utcnow() - timedelta(days=days)
result = await db.execute(
select(CostRecord.capability, func.sum(CostRecord.estimated_cost))
.where(CostRecord.user_id == user_id, CostRecord.timestamp >= since)
.group_by(CostRecord.capability)
)
return {row[0]: Decimal(str(row[1])) for row in result.all()}
async def check_budget(
self,
db: AsyncSession,
user_id: str,
estimated_cost: float,
) -> bool:
"""检查预算是否允许此次调用。
Returns:
True 如果允许,否则抛出 BudgetExceededError
"""
budget = await self.get_user_budget(db, user_id)
if not budget or not budget.enabled:
return True
# 检查日预算
daily_cost = await self.get_daily_cost(db, user_id)
if daily_cost + Decimal(str(estimated_cost)) > budget.daily_limit_usd:
raise BudgetExceededError("", daily_cost, budget.daily_limit_usd)
# 检查月预算
monthly_cost = await self.get_monthly_cost(db, user_id)
if monthly_cost + Decimal(str(estimated_cost)) > budget.monthly_limit_usd:
raise BudgetExceededError("", monthly_cost, budget.monthly_limit_usd)
return True
async def get_user_budget(self, db: AsyncSession, user_id: str) -> UserBudget | None:
"""获取用户预算配置。"""
result = await db.execute(
select(UserBudget).where(UserBudget.user_id == user_id)
)
return result.scalar_one_or_none()
async def set_user_budget(
self,
db: AsyncSession,
user_id: str,
daily_limit: float | None = None,
monthly_limit: float | None = None,
alert_threshold: float | None = None,
enabled: bool | None = None,
) -> UserBudget:
"""设置用户预算。"""
budget = await self.get_user_budget(db, user_id)
if budget is None:
budget = UserBudget(user_id=user_id)
db.add(budget)
if daily_limit is not None:
budget.daily_limit_usd = Decimal(str(daily_limit))
if monthly_limit is not None:
budget.monthly_limit_usd = Decimal(str(monthly_limit))
if alert_threshold is not None:
budget.alert_threshold = Decimal(str(alert_threshold))
if enabled is not None:
budget.enabled = enabled
await db.commit()
await db.refresh(budget)
return budget
async def get_cost_summary(
self,
db: AsyncSession,
user_id: str,
) -> dict:
"""获取用户成本摘要。"""
daily = await self.get_daily_cost(db, user_id)
monthly = await self.get_monthly_cost(db, user_id)
by_capability = await self.get_cost_by_capability(db, user_id)
budget = await self.get_user_budget(db, user_id)
return {
"daily_cost_usd": float(daily),
"monthly_cost_usd": float(monthly),
"by_capability": {k: float(v) for k, v in by_capability.items()},
"budget": {
"daily_limit_usd": float(budget.daily_limit_usd) if budget else None,
"monthly_limit_usd": float(budget.monthly_limit_usd) if budget else None,
"daily_usage_percent": float(daily / budget.daily_limit_usd * 100)
if budget and budget.daily_limit_usd
else None,
"monthly_usage_percent": float(monthly / budget.monthly_limit_usd * 100)
if budget and budget.monthly_limit_usd
else None,
"enabled": budget.enabled if budget else False,
},
}
# 全局单例
cost_tracker = CostTracker()

View File

@@ -0,0 +1,471 @@
"""Memory service handles memory retrieval, scoring, and prompt injection."""
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models import ChildProfile, MemoryItem, StoryUniverse
logger = get_logger(__name__)
class MemoryType:
"""记忆类型常量及配置。"""
# 基础类型
RECENT_STORY = "recent_story"
FAVORITE_CHARACTER = "favorite_character"
SCARY_ELEMENT = "scary_element"
VOCABULARY_GROWTH = "vocabulary_growth"
EMOTIONAL_HIGHLIGHT = "emotional_highlight"
# Phase 1 新增类型
READING_PREFERENCE = "reading_preference" # 阅读偏好
MILESTONE = "milestone" # 里程碑事件
SKILL_MASTERED = "skill_mastered" # 掌握的技能
# 类型配置: (默认权重, 默认TTL天数, 描述)
CONFIG = {
RECENT_STORY: (1.0, 30, "最近阅读的故事"),
FAVORITE_CHARACTER: (1.5, None, "喜欢的角色"), # None = 永久
SCARY_ELEMENT: (2.0, None, "回避的元素"), # 高权重,永久有效
VOCABULARY_GROWTH: (0.8, 90, "词汇积累"),
EMOTIONAL_HIGHLIGHT: (1.2, 60, "情感高光"),
READING_PREFERENCE: (1.0, None, "阅读偏好"),
MILESTONE: (1.5, None, "里程碑事件"),
SKILL_MASTERED: (1.0, 180, "掌握的技能"),
}
@classmethod
def get_default_weight(cls, memory_type: str) -> float:
"""获取类型的默认权重。"""
config = cls.CONFIG.get(memory_type)
return config[0] if config else 1.0
@classmethod
def get_default_ttl(cls, memory_type: str) -> int | None:
"""获取类型的默认 TTL 天数。"""
config = cls.CONFIG.get(memory_type)
return config[1] if config else None
def _decay_factor(days: float) -> float:
"""计算时间衰减因子。"""
if days <= 7:
return 1.0
if days <= 30:
return 0.7
if days <= 90:
return 0.4
return 0.2
async def build_enhanced_memory_context(
profile_id: str | None,
universe_id: str | None,
db: AsyncSession,
) -> str | None:
"""构建增强版记忆上下文(自然语言 Prompt"""
if not profile_id and not universe_id:
return None
context_parts: list[str] = []
# 1. 基础档案 (Identity Layer)
if profile_id:
profile = await db.scalar(select(ChildProfile).where(ChildProfile.id == profile_id))
if profile:
context_parts.append(f"【目标读者】\n姓名:{profile.name}")
if profile.age:
context_parts.append(f"年龄:{profile.age}")
if profile.interests:
context_parts.append(f"兴趣爱好:{''.join(profile.interests)}")
if profile.growth_themes:
context_parts.append(f"当前成长关注点:{''.join(profile.growth_themes)}")
context_parts.append("") # 空行
# 2. 故事宇宙 (Universe Layer)
if universe_id:
universe = await db.scalar(select(StoryUniverse).where(StoryUniverse.id == universe_id))
if universe:
context_parts.append("【故事宇宙设定】")
context_parts.append(f"世界观:{universe.name}")
# 主角
protagonist = universe.protagonist or {}
p_desc = f"{protagonist.get('name', '主角')} ({protagonist.get('personality', '')})"
context_parts.append(f"主角设定:{p_desc}")
# 常驻角色
if universe.recurring_characters:
chars = [f"{c.get('name')} ({c.get('type')})" for c in universe.recurring_characters if isinstance(c, dict)]
context_parts.append(f"已知伙伴:{''.join(chars)}")
# 成就
if universe.achievements:
badges = [str(a.get('type')) for a in universe.achievements if isinstance(a, dict)]
if badges:
context_parts.append(f"已获荣誉:{''.join(badges[:5])}")
context_parts.append("")
# 3. 动态记忆 (Working Memory)
if profile_id:
memories = await _fetch_scored_memories(profile_id, universe_id, db)
if memories:
memory_text = _format_memories_to_prompt(memories)
if memory_text:
context_parts.append("【关键记忆回忆】(请在故事中自然地融入或致敬以下元素)")
context_parts.append(memory_text)
return "\n".join(context_parts)
async def _fetch_scored_memories(
profile_id: str,
universe_id: str | None,
db: AsyncSession,
limit: int = 8
) -> list[MemoryItem]:
"""获取并评分记忆项,返回 Top N。"""
query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
if universe_id:
query = query.where(
(MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
)
# 取最近 50 条进行评分
query = query.order_by(MemoryItem.last_used_at.desc(), MemoryItem.created_at.desc()).limit(50)
result = await db.execute(query)
items = result.scalars().all()
scored: list[tuple[float, MemoryItem]] = []
now = datetime.now(timezone.utc)
for item in items:
reference = item.last_used_at or item.created_at or now
delta_days = max((now - reference).total_seconds() / 86400, 0)
if item.ttl_days and delta_days > item.ttl_days:
continue
score = (item.base_weight or 1.0) * _decay_factor(delta_days)
if score <= 0.1: # 忽略低权重
continue
scored.append((score, item))
scored.sort(key=lambda x: x[0], reverse=True)
return [item for _, item in scored[:limit]]
def _format_memories_to_prompt(memories: list[MemoryItem]) -> str:
"""将记忆项转换为自然语言指令。"""
lines = []
# 分类处理
recent_stories = []
favorites = []
scary = []
vocab = []
for m in memories:
if m.type == MemoryType.RECENT_STORY:
recent_stories.append(m)
elif m.type == MemoryType.FAVORITE_CHARACTER:
favorites.append(m)
elif m.type == MemoryType.SCARY_ELEMENT:
scary.append(m)
elif m.type == MemoryType.VOCABULARY_GROWTH:
vocab.append(m)
# 1. 喜欢的角色
if favorites:
names = []
for m in favorites:
val = m.value
if isinstance(val, dict):
names.append(f"{val.get('name')} ({val.get('description', '')})")
if names:
lines.append(f"- 孩子特别喜欢这些角色,可以让他们客串出场:{', '.join(names)}")
# 2. 避雷区
if scary:
items = []
for m in scary:
val = m.value
if isinstance(val, dict):
items.append(val.get('keyword', ''))
elif isinstance(val, str):
items.append(val)
if items:
lines.append(f"- 【注意禁止】不要出现以下让孩子害怕的元素:{', '.join(items)}")
# 3. 近期故事 (取最近 2 个)
if recent_stories:
lines.append("- 近期经历(可作为彩蛋提及):")
for m in recent_stories[:2]:
val = m.value
if isinstance(val, dict):
title = val.get('title', '未知故事')
lines.append(f" * 之前读过《{title}")
# 4. 词汇积累
if vocab:
words = []
for m in vocab:
val = m.value
if isinstance(val, dict):
words.append(val.get('word'))
if words:
lines.append(f"- 已掌握词汇(可适当复现以巩固):{', '.join([w for w in words if w])}")
return "\n".join(lines)
async def prune_expired_memories(db: AsyncSession) -> int:
"""清理过期的记忆项。
Returns:
删除的记录数量
"""
from sqlalchemy import delete
now = datetime.now(timezone.utc)
# 查找所有设置了 TTL 的项目
stmt = select(MemoryItem).where(MemoryItem.ttl_days.is_not(None))
result = await db.execute(stmt)
candidates = result.scalars().all()
to_delete_ids = []
for item in candidates:
if not item.ttl_days:
continue
reference = item.last_used_at or item.created_at or now
delta_days = (now - reference).total_seconds() / 86400
if delta_days > item.ttl_days:
to_delete_ids.append(item.id)
if not to_delete_ids:
return 0
delete_stmt = delete(MemoryItem).where(MemoryItem.id.in_(to_delete_ids))
await db.execute(delete_stmt)
await db.commit()
logger.info("memory_pruned", count=len(to_delete_ids))
return len(to_delete_ids)
async def create_memory(
db: AsyncSession,
profile_id: str,
memory_type: str,
value: dict,
universe_id: str | None = None,
weight: float | None = None,
ttl_days: int | None = None,
) -> MemoryItem:
"""创建新的记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
memory_type: 记忆类型 (使用 MemoryType 常量)
value: 记忆内容 (JSON 格式)
universe_id: 可选,关联的故事宇宙 ID
weight: 可选,权重 (默认使用类型配置)
ttl_days: 可选,过期天数 (默认使用类型配置)
Returns:
创建的 MemoryItem
"""
memory = MemoryItem(
child_profile_id=profile_id,
universe_id=universe_id,
type=memory_type,
value=value,
base_weight=weight or MemoryType.get_default_weight(memory_type),
ttl_days=ttl_days if ttl_days is not None else MemoryType.get_default_ttl(memory_type),
)
db.add(memory)
await db.commit()
await db.refresh(memory)
logger.info(
"memory_created",
memory_id=memory.id,
profile_id=profile_id,
type=memory_type,
)
return memory
async def update_memory_usage(db: AsyncSession, memory_id: str) -> None:
"""更新记忆的最后使用时间。
Args:
db: 数据库会话
memory_id: 记忆项 ID
"""
result = await db.execute(select(MemoryItem).where(MemoryItem.id == memory_id))
memory = result.scalar_one_or_none()
if memory:
memory.last_used_at = datetime.now(timezone.utc)
await db.commit()
logger.debug("memory_usage_updated", memory_id=memory_id)
async def get_profile_memories(
db: AsyncSession,
profile_id: str,
memory_type: str | None = None,
universe_id: str | None = None,
limit: int = 50,
) -> list[MemoryItem]:
"""获取档案的记忆列表。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
memory_type: 可选,按类型筛选
universe_id: 可选,按宇宙筛选
limit: 返回数量限制
Returns:
MemoryItem 列表
"""
query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
if memory_type:
query = query.where(MemoryItem.type == memory_type)
if universe_id:
query = query.where(
(MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
)
query = query.order_by(MemoryItem.created_at.desc()).limit(limit)
result = await db.execute(query)
return list(result.scalars().all())
async def create_story_memory(
db: AsyncSession,
profile_id: str,
story_id: int,
title: str,
summary: str | None = None,
keywords: list[str] | None = None,
universe_id: str | None = None,
) -> MemoryItem:
"""为故事创建记忆项。
这是一个便捷函数,专门用于在故事阅读后创建 recent_story 类型的记忆。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
story_id: 故事 ID
title: 故事标题
summary: 故事梗概
keywords: 关键词列表
universe_id: 可选,关联的故事宇宙 ID
Returns:
创建的 MemoryItem
"""
value = {
"story_id": story_id,
"title": title,
"summary": summary or "",
"keywords": keywords or [],
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.RECENT_STORY,
value=value,
universe_id=universe_id,
)
async def create_character_memory(
db: AsyncSession,
profile_id: str,
name: str,
description: str | None = None,
source_story_id: int | None = None,
affinity_score: float = 1.0,
universe_id: str | None = None,
) -> MemoryItem:
"""为喜欢的角色创建记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
name: 角色名称
description: 角色描述
source_story_id: 来源故事 ID
affinity_score: 喜爱程度 (0.0-1.0)
universe_id: 可选,关联的故事宇宙 ID
Returns:
创建的 MemoryItem
"""
value = {
"name": name,
"description": description or "",
"source_story_id": source_story_id,
"affinity_score": min(1.0, max(0.0, affinity_score)),
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.FAVORITE_CHARACTER,
value=value,
universe_id=universe_id,
)
async def create_scary_element_memory(
db: AsyncSession,
profile_id: str,
keyword: str,
category: str = "other",
source_story_id: int | None = None,
) -> MemoryItem:
"""为回避元素创建记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
keyword: 回避的关键词
category: 分类 (creature/scene/action/other)
source_story_id: 来源故事 ID
Returns:
创建的 MemoryItem
"""
value = {
"keyword": keyword,
"category": category,
"source_story_id": source_story_id,
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.SCARY_ELEMENT,
value=value,
)

View File

@@ -0,0 +1,31 @@
"""In-memory cache for providers loaded from DB."""
from collections import defaultdict
from typing import Literal
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.admin_models import Provider
ProviderType = Literal["text", "image", "tts", "storybook"]
_cache: dict[ProviderType, list[Provider]] = defaultdict(list)
async def reload_providers(db: AsyncSession):
result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712
providers = result.scalars().all()
grouped: dict[ProviderType, list[Provider]] = defaultdict(list)
for p in providers:
grouped[p.type].append(p)
# sort by priority desc, then weight desc
for k in grouped:
grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
_cache.clear()
_cache.update(grouped)
return _cache
def get_providers(provider_type: ProviderType) -> list[Provider]:
return _cache.get(provider_type, [])

View File

@@ -0,0 +1,248 @@
"""供应商指标收集和健康检查服务。"""
from datetime import datetime, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import ProviderHealth, ProviderMetrics
if TYPE_CHECKING:
from app.services.adapters.base import BaseAdapter
logger = get_logger(__name__)
# 熔断阈值:连续失败次数
CIRCUIT_BREAKER_THRESHOLD = 3
# 熔断恢复时间(秒)
CIRCUIT_BREAKER_RECOVERY_SECONDS = 60
class MetricsCollector:
"""供应商调用指标收集器。"""
async def record_call(
self,
db: AsyncSession,
provider_id: str,
success: bool,
latency_ms: int | None = None,
cost_usd: float | None = None,
error_message: str | None = None,
request_id: str | None = None,
) -> None:
"""记录一次 API 调用。"""
metric = ProviderMetrics(
provider_id=provider_id,
success=success,
latency_ms=latency_ms,
cost_usd=Decimal(str(cost_usd)) if cost_usd else None,
error_message=error_message,
request_id=request_id,
)
db.add(metric)
await db.commit()
logger.debug(
"metrics_recorded",
provider_id=provider_id,
success=success,
latency_ms=latency_ms,
)
async def get_success_rate(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的成功率。"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(
func.count().filter(ProviderMetrics.success.is_(True)).label("success_count"),
func.count().label("total_count"),
).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
)
)
row = result.one()
success_count, total_count = row.success_count, row.total_count
if total_count == 0:
return 1.0 # 无数据时假设健康
return success_count / total_count
async def get_avg_latency(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的平均延迟(毫秒)。"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(func.avg(ProviderMetrics.latency_ms)).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
ProviderMetrics.latency_ms.isnot(None),
)
)
avg = result.scalar()
return float(avg) if avg else 0.0
async def get_total_cost(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的总成本USD"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(func.sum(ProviderMetrics.cost_usd)).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
)
)
total = result.scalar()
return float(total) if total else 0.0
class HealthChecker:
"""供应商健康检查器。"""
async def check_provider(
self,
db: AsyncSession,
provider_id: str,
adapter: "BaseAdapter",
) -> bool:
"""执行健康检查并更新状态。"""
try:
is_healthy = await adapter.health_check()
except Exception as e:
logger.warning("health_check_failed", provider_id=provider_id, error=str(e))
is_healthy = False
await self.update_health_status(
db,
provider_id,
is_healthy,
error=None if is_healthy else "Health check failed",
)
return is_healthy
async def update_health_status(
self,
db: AsyncSession,
provider_id: str,
is_healthy: bool,
error: str | None = None,
) -> None:
"""更新供应商健康状态(含熔断逻辑)。"""
result = await db.execute(
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
)
health = result.scalar_one_or_none()
now = datetime.utcnow()
if health is None:
health = ProviderHealth(
provider_id=provider_id,
is_healthy=is_healthy,
last_check=now,
consecutive_failures=0 if is_healthy else 1,
last_error=error,
)
db.add(health)
else:
health.last_check = now
if is_healthy:
health.is_healthy = True
health.consecutive_failures = 0
health.last_error = None
else:
health.consecutive_failures += 1
health.last_error = error
# 熔断逻辑
if health.consecutive_failures >= CIRCUIT_BREAKER_THRESHOLD:
health.is_healthy = False
logger.warning(
"circuit_breaker_triggered",
provider_id=provider_id,
consecutive_failures=health.consecutive_failures,
)
await db.commit()
async def record_call_result(
self,
db: AsyncSession,
provider_id: str,
success: bool,
error: str | None = None,
) -> None:
"""根据调用结果更新健康状态。"""
await self.update_health_status(db, provider_id, success, error)
async def get_healthy_providers(
self,
db: AsyncSession,
provider_ids: list[str],
) -> list[str]:
"""获取健康的供应商列表。"""
if not provider_ids:
return []
# 查询所有已记录的健康状态
result = await db.execute(
select(ProviderHealth.provider_id, ProviderHealth.is_healthy).where(
ProviderHealth.provider_id.in_(provider_ids),
)
)
health_map = {row[0]: row[1] for row in result.all()}
# 未记录的供应商默认健康,已记录但不健康的排除
return [
pid for pid in provider_ids
if pid not in health_map or health_map[pid]
]
async def is_healthy(
self,
db: AsyncSession,
provider_id: str,
) -> bool:
"""检查供应商是否健康。"""
result = await db.execute(
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
)
health = result.scalar_one_or_none()
if health is None:
return True # 未记录默认健康
# 检查是否可以恢复
if not health.is_healthy and health.last_check:
recovery_time = health.last_check + timedelta(seconds=CIRCUIT_BREAKER_RECOVERY_SECONDS)
if datetime.utcnow() >= recovery_time:
return True # 允许重试
return health.is_healthy
# 全局单例
metrics_collector = MetricsCollector()
health_checker = HealthChecker()

View File

@@ -0,0 +1,432 @@
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
import time
from enum import Enum
from typing import TYPE_CHECKING, Literal, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters import AdapterConfig, AdapterRegistry
from app.services.adapters.text.models import StoryOutput
from app.services.cost_tracker import cost_tracker
from app.services.provider_cache import get_providers
from app.services.provider_metrics import health_checker, metrics_collector
if TYPE_CHECKING:
from app.db.admin_models import Provider
logger = get_logger(__name__)
T = TypeVar("T")
ProviderType = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
"""路由策略枚举。"""
PRIORITY = "priority" # 按优先级排序(默认)
COST = "cost" # 按成本排序
LATENCY = "latency" # 按延迟排序
ROUND_ROBIN = "round_robin" # 轮询
# 默认配置映射(当 DB 无配置时使用)
# 默认配置映射(当 DB 无配置时使用)
# 这是“代码级”的默认策略,对应 .env 为空的情况
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
"text": ["gemini", "openai"],
"image": ["cqtai"],
"tts": ["minimax", "elevenlabs", "edge_tts"],
"storybook": ["gemini"],
}
# API Key 映射adapter_name -> settings 属性名
API_KEY_MAP: dict[str, str] = {
# Text
"gemini": "text_api_key", # Gemini 还是复用 text_api_key 字段
"text_primary": "text_api_key", # 兼容旧别名
"openai": "openai_api_key",
# Image
"cqtai": "cqtai_api_key",
"image_primary": "image_api_key", # 兼容旧别名
# TTS
"minimax": "minimax_api_key",
"elevenlabs": "elevenlabs_api_key",
"edge_tts": "tts_api_key", # EdgeTTS 复用 tts_api_key (通常为空)
"tts_primary": "tts_api_key", # 兼容旧别名
}
# 轮询计数器
_round_robin_counters: dict[ProviderType, int] = {
"text": 0,
"image": 0,
"tts": 0,
}
# 延迟缓存(内存中,简化实现)
_latency_cache: dict[str, float] = {}
def _get_api_key(config_ref: str | None, adapter_name: str) -> str:
"""根据 config_ref 或适配器名称获取 API Key。"""
# 优先使用 config_ref
key_attr = API_KEY_MAP.get(config_ref or adapter_name, None)
if key_attr:
return getattr(settings, key_attr, "")
# 回退到适配器名称
key_attr = API_KEY_MAP.get(adapter_name, None)
if key_attr:
return getattr(settings, key_attr, "")
return ""
def _get_default_config(adapter_name: str) -> AdapterConfig | None:
"""获取适配器的默认配置(无 DB 记录时使用)。返回 None 表示未知适配器。"""
# --- Text Defaults ---
if adapter_name in ("gemini", "text_primary"):
return AdapterConfig(
api_key=settings.text_api_key,
model=settings.text_model or "gemini-2.0-flash",
timeout_ms=60000,
)
if adapter_name == "openai":
return AdapterConfig(
api_key=getattr(settings, "openai_api_key", ""),
model="gpt-4o-mini", # 这里可以从 settings 读取,看需求
timeout_ms=60000,
)
# --- Image Defaults ---
if adapter_name in ("cqtai"):
return AdapterConfig(
api_key=getattr(settings, "cqtai_api_key", ""),
model="nano-banana-pro", # 默认使用 Pro
timeout_ms=120000,
)
if adapter_name == "image_primary":
# 如果还有地方在用 image_primary暂时映射到快或者其他
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
return AdapterConfig(
api_key=settings.image_api_key,
timeout_ms=120000
)
# --- TTS Defaults ---
if adapter_name == "minimax":
# 传递 group_id 到 Adapter
# 目前 AdapterConfig 没有 group_id 字段,我们暂时不改 Base
# 而是假设 Adapter 会从 config (通过 kwargs 或其他方式) 拿。
# 实际上我们的 MiniMaxTTSAdapter 还没有处理 group_id。
# 最简单的方法:把 group_id 藏在 api_base 里或者让 Adapter 自己去 settings 拿。
# 鉴于 _build_config_from_provider 里我们无法传递额外参数给 Adapter.__init__
# 我们这里暂时返回基础配置。
return AdapterConfig(
api_key=getattr(settings, "minimax_api_key", ""),
model="speech-2.6-turbo",
timeout_ms=60000,
)
if adapter_name == "elevenlabs":
return AdapterConfig(
api_key=getattr(settings, "elevenlabs_api_key", ""),
timeout_ms=120000,
)
if adapter_name in ("edge_tts", "tts_primary"):
return AdapterConfig(
api_key=settings.tts_api_key,
api_base=settings.tts_api_base,
model=settings.tts_model or "zh-CN-XiaoxiaoNeural",
timeout_ms=120000,
)
# --- Others ---
if adapter_name in ("storybook_primary", "storybook_gemini"):
return AdapterConfig(
api_key=settings.text_api_key, # 复用 Gemini key
model=settings.text_model,
timeout_ms=120000,
)
# 未知适配器返回 None
return None
def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
"""从 DB Provider 记录构建 AdapterConfig。"""
api_key = getattr(provider, "api_key", None) or ""
if not api_key:
api_key = _get_api_key(provider.config_ref, provider.adapter)
default = _get_default_config(provider.adapter)
if default is None:
default = AdapterConfig(api_key="", timeout_ms=60000)
return AdapterConfig(
api_key=api_key or default.api_key,
api_base=provider.api_base or default.api_base,
model=provider.model or default.model,
timeout_ms=provider.timeout_ms or default.timeout_ms,
max_retries=provider.max_retries or default.max_retries,
extra_config=provider.config_json or {},
)
def _get_providers_with_config(
provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
"""获取供应商列表及其配置。
Returns:
[(adapter_name, config, provider_or_none), ...] 按优先级排序
"""
db_providers = get_providers(provider_type)
if db_providers:
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
settings_map = {
"text": settings.text_providers,
"image": settings.image_providers,
"tts": settings.tts_providers,
}
names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]
result = []
for name in names:
config = _get_default_config(name)
if config is None:
logger.warning("unknown_adapter_skipped", adapter=name, provider_type=provider_type)
continue
result.append((name, config, None))
return result
def _sort_by_strategy(
providers: list[tuple[str, AdapterConfig, "Provider | None"]],
strategy: RoutingStrategy,
provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
"""按策略排序供应商列表。"""
if strategy == RoutingStrategy.PRIORITY:
# 按 priority 降序, weight 降序
return sorted(
providers,
key=lambda x: (-(x[2].priority if x[2] else 0), -(x[2].weight if x[2] else 1)),
)
if strategy == RoutingStrategy.COST:
# 按预估成本升序
def get_cost(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
adapter_class = AdapterRegistry.get(provider_type, item[0])
if adapter_class:
try:
adapter = adapter_class(item[1])
return adapter.estimated_cost
except Exception:
pass
return float("inf")
return sorted(providers, key=get_cost)
if strategy == RoutingStrategy.LATENCY:
# 按历史延迟升序
def get_latency(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
return _latency_cache.get(item[0], float("inf"))
return sorted(providers, key=get_latency)
if strategy == RoutingStrategy.ROUND_ROBIN:
# 轮询:旋转列表
counter = _round_robin_counters[provider_type]
_round_robin_counters[provider_type] = (counter + 1) % max(len(providers), 1)
return providers[counter:] + providers[:counter]
return providers
async def _route_with_failover(
provider_type: ProviderType,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
user_id: str | None = None,
**kwargs,
) -> T:
"""通用 provider failover 路由。
Args:
provider_type: 供应商类型 (text/image/tts/storybook)
strategy: 路由策略
db: 数据库会话(可选,用于指标收集和熔断检查)
user_id: 用户 ID可选用于成本追踪和预算检查
**kwargs: 传递给适配器的参数
"""
providers = _get_providers_with_config(provider_type)
if not providers:
raise ValueError(f"No {provider_type} providers configured.")
# 按策略排序
sorted_providers = _sort_by_strategy(providers, strategy, provider_type)
# 如果有 db 会话,过滤掉熔断的供应商
if db:
healthy_providers = []
for item in sorted_providers:
name, config, db_provider = item
provider_id = db_provider.id if db_provider else name
if await health_checker.is_healthy(db, provider_id):
healthy_providers.append(item)
else:
logger.debug("provider_circuit_open", adapter=name, provider_id=provider_id)
# 如果所有供应商都熔断,仍然尝试第一个(允许恢复)
if not healthy_providers:
healthy_providers = sorted_providers[:1]
sorted_providers = healthy_providers
errors: list[str] = []
for name, config, db_provider in sorted_providers:
adapter_class = AdapterRegistry.get(provider_type, name)
if not adapter_class:
errors.append(f"{name}: 适配器未注册")
continue
provider_id = db_provider.id if db_provider else name
try:
logger.debug(
"provider_attempt",
provider_type=provider_type,
adapter=name,
strategy=strategy.value,
)
adapter = adapter_class(config)
# 执行并计时
start_time = time.time()
result = await adapter.execute(**kwargs)
latency_ms = int((time.time() - start_time) * 1000)
# 更新延迟缓存
_latency_cache[name] = latency_ms
# 记录成功指标
if db:
await metrics_collector.record_call(
db,
provider_id=provider_id,
success=True,
latency_ms=latency_ms,
cost_usd=adapter.estimated_cost,
)
await health_checker.record_call_result(db, provider_id, success=True)
# 记录用户成本
if user_id:
await cost_tracker.record_cost(
db,
user_id=user_id,
provider_name=name,
capability=provider_type,
estimated_cost=adapter.estimated_cost,
provider_id=provider_id if db_provider else None,
)
logger.info(
"provider_success",
provider_type=provider_type,
adapter=name,
latency_ms=latency_ms,
)
return result
except Exception as exc:
error_msg = str(exc)
logger.warning(
"provider_failed",
provider_type=provider_type,
adapter=name,
error=error_msg,
)
errors.append(f"{name}: {exc}")
# 记录失败指标
if db:
await metrics_collector.record_call(
db,
provider_id=provider_id,
success=False,
error_message=error_msg,
)
await health_checker.record_call_result(
db, provider_id, success=False, error=error_msg
)
raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
async def generate_story_content(
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
) -> StoryOutput:
"""生成或润色故事,支持 failover。"""
return await _route_with_failover(
"text",
strategy=strategy,
db=db,
input_type=input_type,
data=data,
education_theme=education_theme,
memory_context=memory_context,
)
async def generate_image(
prompt: str,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
**kwargs,
) -> str:
"""生成图片,返回 URL支持 failover。"""
return await _route_with_failover("image", strategy=strategy, db=db, prompt=prompt, **kwargs)
async def text_to_speech(
text: str,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
) -> bytes:
"""文本转语音,返回 MP3 bytes支持 failover。"""
return await _route_with_failover("tts", strategy=strategy, db=db, text=text)
async def generate_storybook(
keywords: str,
page_count: int = 6,
education_theme: str | None = None,
memory_context: str | None = None,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
):
"""生成分页故事书,支持 failover。"""
from app.services.adapters.storybook.primary import Storybook
result: Storybook = await _route_with_failover(
"storybook",
strategy=strategy,
db=db,
keywords=keywords,
page_count=page_count,
education_theme=education_theme,
memory_context=memory_context,
)
return result

View File

@@ -0,0 +1,207 @@
"""供应商密钥加密存储服务。
使用 Fernet 对称加密,密钥从 SECRET_KEY 派生。
"""
import base64
import hashlib
from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.admin_models import ProviderSecret
if TYPE_CHECKING:
pass
logger = get_logger(__name__)
class SecretEncryptionError(Exception):
"""密钥加密/解密错误。"""
pass
class SecretService:
"""供应商密钥加密存储服务。"""
_fernet: Fernet | None = None
@classmethod
def _get_fernet(cls) -> Fernet:
"""获取 Fernet 实例,从 SECRET_KEY 派生加密密钥。"""
if cls._fernet is None:
# 从 SECRET_KEY 派生 32 字节密钥
key_bytes = hashlib.sha256(settings.secret_key.encode()).digest()
fernet_key = base64.urlsafe_b64encode(key_bytes)
cls._fernet = Fernet(fernet_key)
return cls._fernet
@classmethod
def encrypt(cls, plaintext: str) -> str:
"""加密明文,返回 base64 编码的密文。
Args:
plaintext: 要加密的明文
Returns:
base64 编码的密文
"""
if not plaintext:
return ""
fernet = cls._get_fernet()
encrypted = fernet.encrypt(plaintext.encode())
return encrypted.decode()
@classmethod
def decrypt(cls, ciphertext: str) -> str:
"""解密密文,返回明文。
Args:
ciphertext: base64 编码的密文
Returns:
解密后的明文
Raises:
SecretEncryptionError: 解密失败
"""
if not ciphertext:
return ""
try:
fernet = cls._get_fernet()
decrypted = fernet.decrypt(ciphertext.encode())
return decrypted.decode()
except InvalidToken as e:
logger.error("secret_decrypt_failed", error=str(e))
raise SecretEncryptionError("密钥解密失败,可能是 SECRET_KEY 已更改") from e
@classmethod
async def get_secret(cls, db: AsyncSession, name: str) -> str | None:
"""从数据库获取并解密密钥。
Args:
db: 数据库会话
name: 密钥名称
Returns:
解密后的密钥值,不存在返回 None
"""
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
return None
return cls.decrypt(secret.encrypted_value)
@classmethod
async def set_secret(cls, db: AsyncSession, name: str, value: str) -> ProviderSecret:
"""存储或更新加密密钥。
Args:
db: 数据库会话
name: 密钥名称
value: 密钥明文值
Returns:
ProviderSecret 实例
"""
encrypted = cls.encrypt(value)
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
secret = ProviderSecret(name=name, encrypted_value=encrypted)
db.add(secret)
else:
secret.encrypted_value = encrypted
await db.commit()
await db.refresh(secret)
logger.info("secret_stored", name=name)
return secret
@classmethod
async def delete_secret(cls, db: AsyncSession, name: str) -> bool:
"""删除密钥。
Args:
db: 数据库会话
name: 密钥名称
Returns:
是否删除成功
"""
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
return False
await db.delete(secret)
await db.commit()
logger.info("secret_deleted", name=name)
return True
@classmethod
async def list_secrets(cls, db: AsyncSession) -> list[str]:
"""列出所有密钥名称(不返回值)。
Args:
db: 数据库会话
Returns:
密钥名称列表
"""
result = await db.execute(select(ProviderSecret.name))
return [row[0] for row in result.fetchall()]
@classmethod
async def get_api_key(
cls,
db: AsyncSession,
provider_api_key: str | None,
config_ref: str | None,
) -> str | None:
"""获取 Provider 的 API Key按优先级查找。
优先级:
1. provider.api_key (数据库明文/加密)
2. provider.config_ref 指向的 ProviderSecret
3. 环境变量 (config_ref 作为变量名)
Args:
db: 数据库会话
provider_api_key: Provider 表中的 api_key 字段
config_ref: Provider 表中的 config_ref 字段
Returns:
API Key 或 None
"""
# 1. 直接使用 provider.api_key
if provider_api_key:
# 尝试解密,如果失败则当作明文
try:
decrypted = cls.decrypt(provider_api_key)
if decrypted:
return decrypted
except SecretEncryptionError:
pass
return provider_api_key
# 2. 从 ProviderSecret 表查找
if config_ref:
secret_value = await cls.get_secret(db, config_ref)
if secret_value:
return secret_value
# 3. 从环境变量查找
env_value = getattr(settings, config_ref.lower(), None)
if env_value:
return env_value
return None

View File

@@ -0,0 +1,3 @@
"""Celery tasks package."""
from . import achievements, memory, push_notifications # noqa: F401

View File

@@ -0,0 +1,82 @@
"""Celery tasks for achievements."""
import asyncio
from datetime import datetime
from sqlalchemy import select
from app.core.celery_app import celery_app
from app.core.logging import get_logger
from app.db.database import _get_session_factory
from app.db.models import Story, StoryUniverse
from app.services.achievement_extractor import extract_achievements
logger = get_logger(__name__)
@celery_app.task
def extract_story_achievements(story_id: int, universe_id: str) -> None:
"""Extract achievements and update universe."""
asyncio.run(_extract_story_achievements(story_id, universe_id))
async def _extract_story_achievements(story_id: int, universe_id: str) -> None:
session_factory = _get_session_factory()
async with session_factory() as session:
result = await session.execute(select(Story).where(Story.id == story_id))
story = result.scalar_one_or_none()
if not story:
logger.warning("achievement_task_story_missing", story_id=story_id)
return
result = await session.execute(
select(StoryUniverse).where(StoryUniverse.id == universe_id)
)
universe = result.scalar_one_or_none()
if not universe:
logger.warning("achievement_task_universe_missing", universe_id=universe_id)
return
text_content = story.story_text
if not text_content and story.pages:
# 如果是绘本,拼接每页文本
text_content = "\n".join([str(p.get("text", "")) for p in story.pages])
if not text_content:
logger.warning("achievement_task_empty_content", story_id=story_id)
return
achievements = await extract_achievements(text_content)
if not achievements:
logger.info("achievement_task_no_new", story_id=story_id)
return
existing = {
(str(item.get("type", "")).strip(), str(item.get("description", "")).strip())
for item in (universe.achievements or [])
if isinstance(item, dict)
}
merged = list(universe.achievements or [])
added_count = 0
for item in achievements:
key = (item.get("type", "").strip(), item.get("description", "").strip())
if key in existing:
continue
merged.append({
"type": key[0],
"description": key[1],
"obtained_at": datetime.now().isoformat(),
"source_story_id": story_id,
})
existing.add(key)
added_count += 1
universe.achievements = merged
await session.commit()
logger.info(
"achievement_task_success",
story_id=story_id,
universe_id=universe_id,
added=added_count,
)

View File

@@ -0,0 +1,29 @@
import asyncio
from app.core.celery_app import celery_app
from app.core.logging import get_logger
from app.db.database import _get_session_factory
from app.services.memory_service import prune_expired_memories
logger = get_logger(__name__)
@celery_app.task
def prune_memories_task():
"""Daily task to prune expired memories."""
logger.info("prune_memories_task_started")
async def _run():
# Ensure engine is initialized in this process
session_factory = _get_session_factory()
async with session_factory() as session:
return await prune_expired_memories(session)
try:
# Create a new event loop for this task execution
count = asyncio.run(_run())
logger.info("prune_memories_task_completed", deleted_count=count)
return f"Deleted {count} expired memories"
except Exception as exc:
logger.error("prune_memories_task_failed", error=str(exc))
raise

View File

@@ -0,0 +1,108 @@
"""Celery tasks for push notifications."""
import asyncio
from datetime import datetime, time
from zoneinfo import ZoneInfo
from sqlalchemy import select
from app.core.celery_app import celery_app
from app.core.logging import get_logger
from app.db.database import _get_session_factory
from app.db.models import PushConfig, PushEvent
logger = get_logger(__name__)
LOCAL_TZ = ZoneInfo("Asia/Shanghai")
QUIET_HOURS_START = time(21, 0)
QUIET_HOURS_END = time(9, 0)
TRIGGER_WINDOW_MINUTES = 30
@celery_app.task
def check_push_notifications() -> None:
"""Check push configs and create push events."""
asyncio.run(_check_push_notifications())
def _is_quiet_hours(current: time) -> bool:
if QUIET_HOURS_START < QUIET_HOURS_END:
return QUIET_HOURS_START <= current < QUIET_HOURS_END
return current >= QUIET_HOURS_START or current < QUIET_HOURS_END
def _within_window(current: time, target: time) -> bool:
current_minutes = current.hour * 60 + current.minute
target_minutes = target.hour * 60 + target.minute
return 0 <= current_minutes - target_minutes < TRIGGER_WINDOW_MINUTES
async def _already_sent_today(
session,
child_profile_id: str,
now: datetime,
) -> bool:
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
end = now.replace(hour=23, minute=59, second=59, microsecond=999999)
result = await session.execute(
select(PushEvent.id).where(
PushEvent.child_profile_id == child_profile_id,
PushEvent.status == "sent",
PushEvent.sent_at >= start,
PushEvent.sent_at <= end,
)
)
return result.scalar_one_or_none() is not None
async def _check_push_notifications() -> None:
session_factory = _get_session_factory()
now = datetime.now(LOCAL_TZ)
current_day = now.weekday()
current_time = now.time()
async with session_factory() as session:
result = await session.execute(
select(PushConfig).where(PushConfig.enabled.is_(True))
)
configs = result.scalars().all()
for config in configs:
if not config.push_time:
continue
if config.push_days and current_day not in config.push_days:
continue
if not _within_window(current_time, config.push_time):
continue
if _is_quiet_hours(current_time):
session.add(
PushEvent(
user_id=config.user_id,
child_profile_id=config.child_profile_id,
trigger_type="time",
status="suppressed",
reason="quiet_hours",
sent_at=now,
)
)
continue
if await _already_sent_today(session, config.child_profile_id, now):
continue
session.add(
PushEvent(
user_id=config.user_id,
child_profile_id=config.child_profile_id,
trigger_type="time",
status="sent",
reason=None,
sent_at=now,
)
)
logger.info(
"push_event_sent",
child_profile_id=config.child_profile_id,
user_id=config.user_id,
)
await session.commit()