Initial commit: clean project structure
- Backend: FastAPI + SQLAlchemy + Celery (Python 3.11+) - Frontend: Vue 3 + TypeScript + Pinia + Tailwind - Admin Frontend: separate Vue 3 app for management - Docker Compose: 9 services orchestration - Specs: design prototypes, memory system PRD, product roadmap Cleanup performed: - Removed temporary debug scripts from backend root - Removed deprecated admin_app.py (embedded UI) - Removed duplicate docs from admin-frontend - Updated .gitignore for Vite cache and egg-info
This commit is contained in:
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
307
backend/app/api/admin_providers.py
Normal file
307
backend/app/api/admin_providers.py
Normal file
@@ -0,0 +1,307 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.admin_auth import admin_guard
|
||||
from app.db.admin_models import Provider
|
||||
from app.db.database import get_db
|
||||
from app.services.cost_tracker import cost_tracker
|
||||
from app.services.secret_service import SecretService
|
||||
|
||||
router = APIRouter(dependencies=[Depends(admin_guard)])
|
||||
|
||||
|
||||
class ProviderCreate(BaseModel):
|
||||
name: str
|
||||
type: str = Field(..., pattern="^(text|image|tts|storybook)$")
|
||||
adapter: str
|
||||
model: str | None = None
|
||||
api_base: str | None = None
|
||||
api_key: str | None = None # 可选,优先于 config_ref
|
||||
timeout_ms: int = 60000
|
||||
max_retries: int = 1
|
||||
weight: int = 1
|
||||
priority: int = 0
|
||||
enabled: bool = True
|
||||
config_json: dict | None = None
|
||||
config_ref: str | None = None # 环境变量 key 名称(回退)
|
||||
updated_by: str | None = None
|
||||
|
||||
|
||||
class ProviderUpdate(ProviderCreate):
|
||||
enabled: bool | None = None
|
||||
api_key: str | None = None
|
||||
config_json: dict | None = None
|
||||
|
||||
|
||||
class ProviderResponse(BaseModel):
|
||||
"""Provider 响应模型,隐藏敏感字段。"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
adapter: str
|
||||
model: str | None = None
|
||||
api_base: str | None = None
|
||||
has_api_key: bool = False # 仅标识是否配置了 api_key,不返回明文
|
||||
timeout_ms: int = 60000
|
||||
max_retries: int = 1
|
||||
weight: int = 1
|
||||
priority: int = 0
|
||||
enabled: bool = True
|
||||
config_ref: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
from app.services.provider_router import DEFAULT_PROVIDERS
|
||||
|
||||
|
||||
@router.get("/providers/adapters")
|
||||
async def list_available_adapters():
|
||||
"""获取所有可用的适配器类型 (定义的类)。"""
|
||||
return AdapterRegistry.list_adapters()
|
||||
|
||||
|
||||
@router.get("/providers/defaults")
|
||||
async def get_env_defaults():
|
||||
"""获取当前环境变量定义的默认策略 (Read-Only)。"""
|
||||
return DEFAULT_PROVIDERS
|
||||
|
||||
|
||||
@router.get("/providers", response_model=list[ProviderResponse])
|
||||
async def list_providers(db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Provider))
|
||||
providers = result.scalars().all()
|
||||
# 转换为响应模型,隐藏 api_key 明文
|
||||
return [
|
||||
ProviderResponse(
|
||||
id=p.id,
|
||||
name=p.name,
|
||||
type=p.type,
|
||||
adapter=p.adapter,
|
||||
model=p.model,
|
||||
api_base=p.api_base,
|
||||
has_api_key=bool(p.api_key), # 仅标识是否有 key
|
||||
timeout_ms=p.timeout_ms,
|
||||
max_retries=p.max_retries,
|
||||
weight=p.weight,
|
||||
priority=p.priority,
|
||||
enabled=p.enabled,
|
||||
config_ref=p.config_ref,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
|
||||
|
||||
def _to_response(provider: Provider) -> ProviderResponse:
|
||||
"""将 Provider 转换为响应模型,隐藏敏感字段。"""
|
||||
return ProviderResponse(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
type=provider.type,
|
||||
adapter=provider.adapter,
|
||||
model=provider.model,
|
||||
api_base=provider.api_base,
|
||||
has_api_key=bool(provider.api_key),
|
||||
timeout_ms=provider.timeout_ms,
|
||||
max_retries=provider.max_retries,
|
||||
weight=provider.weight,
|
||||
priority=provider.priority,
|
||||
enabled=provider.enabled,
|
||||
config_ref=provider.config_ref,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/providers", response_model=ProviderResponse)
|
||||
async def create_provider(payload: ProviderCreate, db: AsyncSession = Depends(get_db)):
|
||||
data = payload.model_dump()
|
||||
# 加密 API Key
|
||||
if data.get("api_key"):
|
||||
data["api_key"] = SecretService.encrypt(data["api_key"])
|
||||
provider = Provider(**data)
|
||||
db.add(provider)
|
||||
await db.commit()
|
||||
await db.refresh(provider)
|
||||
return _to_response(provider)
|
||||
|
||||
|
||||
@router.put("/providers/{provider_id}", response_model=ProviderResponse)
|
||||
async def update_provider(
|
||||
provider_id: str, payload: ProviderUpdate, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
data = payload.model_dump(exclude_unset=True)
|
||||
# 加密 API Key
|
||||
if "api_key" in data and data["api_key"]:
|
||||
data["api_key"] = SecretService.encrypt(data["api_key"])
|
||||
for k, v in data.items():
|
||||
setattr(provider, k, v)
|
||||
await db.commit()
|
||||
await db.refresh(provider)
|
||||
return _to_response(provider)
|
||||
|
||||
|
||||
@router.delete("/providers/{provider_id}")
|
||||
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
await db.delete(provider)
|
||||
await db.commit()
|
||||
return {"message": "deleted"}
|
||||
|
||||
|
||||
# ==================== 密钥管理 API ====================
|
||||
|
||||
|
||||
class SecretCreate(BaseModel):
|
||||
"""密钥创建请求。"""
|
||||
|
||||
name: str = Field(..., description="密钥名称,如 CQTAI_API_KEY")
|
||||
value: str = Field(..., description="密钥明文值")
|
||||
|
||||
|
||||
class SecretResponse(BaseModel):
|
||||
"""密钥响应,不返回明文。"""
|
||||
|
||||
name: str
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
|
||||
|
||||
@router.get("/secrets", response_model=list[str])
|
||||
async def list_secrets(db: AsyncSession = Depends(get_db)):
|
||||
"""列出所有密钥名称(不返回值)。"""
|
||||
return await SecretService.list_secrets(db)
|
||||
|
||||
|
||||
@router.post("/secrets", response_model=SecretResponse)
|
||||
async def create_or_update_secret(payload: SecretCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""创建或更新密钥。"""
|
||||
secret = await SecretService.set_secret(db, payload.name, payload.value)
|
||||
return SecretResponse(
|
||||
name=secret.name,
|
||||
created_at=secret.created_at.isoformat() if secret.created_at else None,
|
||||
updated_at=secret.updated_at.isoformat() if secret.updated_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/secrets/{name}")
|
||||
async def delete_secret(name: str, db: AsyncSession = Depends(get_db)):
|
||||
"""删除密钥。"""
|
||||
deleted = await SecretService.delete_secret(db, name)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Secret not found")
|
||||
return {"message": "deleted"}
|
||||
|
||||
|
||||
@router.get("/secrets/{name}/verify")
|
||||
async def verify_secret(name: str, db: AsyncSession = Depends(get_db)):
|
||||
"""验证密钥是否存在且可解密(不返回明文)。"""
|
||||
value = await SecretService.get_secret(db, name)
|
||||
if value is None:
|
||||
raise HTTPException(status_code=404, detail="Secret not found")
|
||||
return {"name": name, "valid": True, "length": len(value)}
|
||||
|
||||
|
||||
# ==================== 成本追踪 API ====================
|
||||
|
||||
|
||||
class BudgetUpdate(BaseModel):
|
||||
"""预算更新请求。"""
|
||||
|
||||
daily_limit_usd: float | None = None
|
||||
monthly_limit_usd: float | None = None
|
||||
alert_threshold: float | None = Field(default=None, ge=0, le=1)
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
@router.get("/costs/summary/{user_id}")
|
||||
async def get_user_cost_summary(user_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""获取用户成本摘要。"""
|
||||
return await cost_tracker.get_cost_summary(db, user_id)
|
||||
|
||||
|
||||
@router.get("/costs/all")
|
||||
async def get_all_costs_summary(db: AsyncSession = Depends(get_db)):
|
||||
"""获取所有用户成本汇总(管理员)。"""
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.db.admin_models import CostRecord
|
||||
|
||||
# 按用户汇总
|
||||
result = await db.execute(
|
||||
select(
|
||||
CostRecord.user_id,
|
||||
func.sum(CostRecord.estimated_cost).label("total_cost"),
|
||||
func.count().label("call_count"),
|
||||
).group_by(CostRecord.user_id)
|
||||
)
|
||||
users = [
|
||||
{"user_id": row[0], "total_cost_usd": float(row[1]), "call_count": row[2]}
|
||||
for row in result.all()
|
||||
]
|
||||
|
||||
# 按能力汇总
|
||||
result = await db.execute(
|
||||
select(
|
||||
CostRecord.capability,
|
||||
func.sum(CostRecord.estimated_cost).label("total_cost"),
|
||||
func.count().label("call_count"),
|
||||
).group_by(CostRecord.capability)
|
||||
)
|
||||
capabilities = [
|
||||
{"capability": row[0], "total_cost_usd": float(row[1]), "call_count": row[2]}
|
||||
for row in result.all()
|
||||
]
|
||||
|
||||
return {"by_user": users, "by_capability": capabilities}
|
||||
|
||||
|
||||
@router.get("/budgets/{user_id}")
|
||||
async def get_user_budget(user_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""获取用户预算配置。"""
|
||||
budget = await cost_tracker.get_user_budget(db, user_id)
|
||||
if not budget:
|
||||
return {"user_id": user_id, "budget": None}
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"budget": {
|
||||
"daily_limit_usd": float(budget.daily_limit_usd),
|
||||
"monthly_limit_usd": float(budget.monthly_limit_usd),
|
||||
"alert_threshold": float(budget.alert_threshold),
|
||||
"enabled": budget.enabled,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/budgets/{user_id}")
|
||||
async def set_user_budget(
|
||||
user_id: str, payload: BudgetUpdate, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""设置用户预算。"""
|
||||
budget = await cost_tracker.set_user_budget(
|
||||
db,
|
||||
user_id,
|
||||
daily_limit=payload.daily_limit_usd,
|
||||
monthly_limit=payload.monthly_limit_usd,
|
||||
alert_threshold=payload.alert_threshold,
|
||||
enabled=payload.enabled,
|
||||
)
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"budget": {
|
||||
"daily_limit_usd": float(budget.daily_limit_usd),
|
||||
"monthly_limit_usd": float(budget.monthly_limit_usd),
|
||||
"alert_threshold": float(budget.alert_threshold),
|
||||
"enabled": budget.enabled,
|
||||
},
|
||||
}
|
||||
14
backend/app/api/admin_reload.py
Normal file
14
backend/app/api/admin_reload.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.admin_auth import admin_guard
|
||||
from app.db.database import get_db
|
||||
from app.services.provider_cache import reload_providers
|
||||
|
||||
router = APIRouter(dependencies=[Depends(admin_guard)])
|
||||
|
||||
|
||||
@router.post("/providers/reload")
|
||||
async def reload(db: AsyncSession = Depends(get_db)):
|
||||
cache = await reload_providers(db)
|
||||
return {k: len(v) for k, v in cache.items()}
|
||||
272
backend/app/api/auth.py
Normal file
272
backend/app/api/auth.py
Normal file
@@ -0,0 +1,272 @@
|
||||
import secrets
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.deps import get_current_user
|
||||
from app.core.security import create_access_token
|
||||
from app.db.database import get_db
|
||||
from app.db.models import User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# OAuth endpoints
|
||||
GITHUB_AUTHORIZE_URL = "https://github.com/login/oauth/authorize"
|
||||
GITHUB_TOKEN_URL = "https://github.com/login/oauth/access_token"
|
||||
GITHUB_USER_URL = "https://api.github.com/user"
|
||||
|
||||
GOOGLE_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
GOOGLE_USER_URL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
STATE_COOKIE = "oauth_state"
|
||||
STATE_MAX_AGE = 600 # 10 minutes
|
||||
|
||||
|
||||
def _set_state_cookie(response: RedirectResponse, provider: str, state: str) -> None:
|
||||
response.set_cookie(
|
||||
key=STATE_COOKIE,
|
||||
value=f"{provider}:{state}",
|
||||
httponly=True,
|
||||
secure=not settings.debug,
|
||||
samesite="lax",
|
||||
max_age=STATE_MAX_AGE,
|
||||
)
|
||||
|
||||
|
||||
def _validate_state(state_from_query: str | None, state_cookie: str | None, provider: str):
|
||||
if not state_from_query or not state_cookie:
|
||||
raise HTTPException(status_code=400, detail="Missing OAuth state")
|
||||
expected_prefix = f"{provider}:"
|
||||
if not state_cookie.startswith(expected_prefix):
|
||||
raise HTTPException(status_code=400, detail="OAuth state mismatch")
|
||||
expected_state = state_cookie.removeprefix(expected_prefix)
|
||||
if not secrets.compare_digest(state_from_query, expected_state):
|
||||
raise HTTPException(status_code=400, detail="OAuth state mismatch")
|
||||
|
||||
|
||||
@router.get("/github/signin")
|
||||
async def github_signin():
|
||||
"""Start GitHub OAuth with state protection."""
|
||||
state = secrets.token_urlsafe(16)
|
||||
params = {
|
||||
"client_id": settings.github_client_id,
|
||||
"redirect_uri": f"{settings.base_url}/auth/github/callback",
|
||||
"scope": "read:user user:email",
|
||||
"state": state,
|
||||
}
|
||||
url = f"{GITHUB_AUTHORIZE_URL}?{urlencode(params)}"
|
||||
response = RedirectResponse(url=url)
|
||||
_set_state_cookie(response, "github", state)
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/github/callback")
|
||||
async def github_callback(
|
||||
code: str,
|
||||
state: str | None = Query(default=None),
|
||||
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Handle GitHub OAuth callback."""
|
||||
_validate_state(state, state_cookie, "github")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_resp = await client.post(
|
||||
GITHUB_TOKEN_URL,
|
||||
data={
|
||||
"client_id": settings.github_client_id,
|
||||
"client_secret": settings.github_client_secret,
|
||||
"code": code,
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
token_resp.raise_for_status()
|
||||
token_data = token_resp.json()
|
||||
access_token = token_data.get("access_token")
|
||||
if not access_token:
|
||||
raise HTTPException(status_code=502, detail="GitHub login failed")
|
||||
|
||||
user_resp = await client.get(
|
||||
GITHUB_USER_URL,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
user_resp.raise_for_status()
|
||||
user_data = user_resp.json()
|
||||
except httpx.HTTPStatusError:
|
||||
raise HTTPException(status_code=502, detail="GitHub login failed")
|
||||
|
||||
github_id = user_data.get("id")
|
||||
if github_id is None:
|
||||
raise HTTPException(status_code=502, detail="GitHub login failed")
|
||||
|
||||
return await _handle_oauth_user(
|
||||
db=db,
|
||||
provider="github",
|
||||
user_id=str(github_id),
|
||||
name=user_data.get("name") or user_data.get("login") or "GitHub User",
|
||||
avatar_url=user_data.get("avatar_url"),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/google/signin")
|
||||
async def google_signin():
|
||||
"""Start Google OAuth with state protection."""
|
||||
state = secrets.token_urlsafe(16)
|
||||
params = {
|
||||
"client_id": settings.google_client_id,
|
||||
"redirect_uri": f"{settings.base_url}/auth/google/callback",
|
||||
"response_type": "code",
|
||||
"scope": "openid email profile",
|
||||
"state": state,
|
||||
}
|
||||
url = f"{GOOGLE_AUTHORIZE_URL}?{urlencode(params)}"
|
||||
response = RedirectResponse(url=url)
|
||||
_set_state_cookie(response, "google", state)
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/google/callback")
|
||||
async def google_callback(
|
||||
code: str,
|
||||
state: str | None = Query(default=None),
|
||||
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Handle Google OAuth callback."""
|
||||
_validate_state(state, state_cookie, "google")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_resp = await client.post(
|
||||
GOOGLE_TOKEN_URL,
|
||||
data={
|
||||
"client_id": settings.google_client_id,
|
||||
"client_secret": settings.google_client_secret,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": f"{settings.base_url}/auth/google/callback",
|
||||
},
|
||||
)
|
||||
token_resp.raise_for_status()
|
||||
token_data = token_resp.json()
|
||||
access_token = token_data.get("access_token")
|
||||
if not access_token:
|
||||
raise HTTPException(status_code=502, detail="Google login failed")
|
||||
|
||||
user_resp = await client.get(
|
||||
GOOGLE_USER_URL,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
user_resp.raise_for_status()
|
||||
user_data = user_resp.json()
|
||||
except httpx.HTTPStatusError:
|
||||
raise HTTPException(status_code=502, detail="Google login failed")
|
||||
|
||||
google_id = user_data.get("id")
|
||||
if google_id is None:
|
||||
raise HTTPException(status_code=502, detail="Google login failed")
|
||||
|
||||
return await _handle_oauth_user(
|
||||
db=db,
|
||||
provider="google",
|
||||
user_id=str(google_id),
|
||||
name=user_data.get("name") or user_data.get("email") or "Google User",
|
||||
avatar_url=user_data.get("picture"),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_oauth_user(
|
||||
db: AsyncSession,
|
||||
provider: str,
|
||||
user_id: str,
|
||||
name: str,
|
||||
avatar_url: str | None,
|
||||
) -> RedirectResponse:
|
||||
"""Create/update user and issue session cookie."""
|
||||
full_id = f"{provider}:{user_id}"
|
||||
|
||||
result = await db.execute(select(User).where(User.id == full_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
user = User(
|
||||
id=full_id,
|
||||
name=name,
|
||||
avatar_url=avatar_url,
|
||||
provider=provider,
|
||||
)
|
||||
db.add(user)
|
||||
else:
|
||||
user.name = name
|
||||
user.avatar_url = avatar_url
|
||||
|
||||
await db.commit()
|
||||
|
||||
token = create_access_token({"sub": user.id})
|
||||
|
||||
frontend_url = "http://localhost:5173"
|
||||
if settings.cors_origins and len(settings.cors_origins) > 0:
|
||||
frontend_url = settings.cors_origins[0]
|
||||
|
||||
response = RedirectResponse(url=f"{frontend_url}/my-stories", status_code=302)
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=not settings.debug,
|
||||
samesite="lax",
|
||||
max_age=60 * 60 * 24 * 7, # align with ACCESS_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
response.delete_cookie(STATE_COOKIE)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/signout")
|
||||
async def signout():
|
||||
"""Sign out and clear cookies."""
|
||||
response = RedirectResponse(url=settings.cors_origins[0], status_code=302)
|
||||
response.delete_cookie("access_token", samesite="lax", secure=not settings.debug)
|
||||
response.delete_cookie(STATE_COOKIE, samesite="lax", secure=not settings.debug)
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/session")
|
||||
async def get_session(user: User | None = Depends(get_current_user)):
|
||||
"""Fetch current session info."""
|
||||
if not user:
|
||||
return {"user": None}
|
||||
return {
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"avatar_url": user.avatar_url,
|
||||
"provider": user.provider,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/dev/signin")
|
||||
async def dev_signin(db: AsyncSession = Depends(get_db)):
|
||||
"""Developer backdoor login. Only works in DEBUG mode."""
|
||||
# if not settings.debug:
|
||||
# raise HTTPException(status_code=403, detail="Developer login disabled")
|
||||
|
||||
try:
|
||||
return await _handle_oauth_user(
|
||||
db=db,
|
||||
provider="github",
|
||||
user_id="dev_user_001",
|
||||
name="Developer",
|
||||
avatar_url="https://api.dicebear.com/7.x/avataaars/svg?seed=Developer"
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"Dev login failed: {str(e)}")
|
||||
268
backend/app/api/memories.py
Normal file
268
backend/app/api/memories.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Memory management APIs."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.db.database import get_db
|
||||
from app.db.models import ChildProfile, User
|
||||
from app.services import memory_service
|
||||
from app.services.memory_service import MemoryType
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class MemoryItemResponse(BaseModel):
|
||||
"""Memory item response."""
|
||||
|
||||
id: str
|
||||
type: str
|
||||
value: dict
|
||||
base_weight: float
|
||||
ttl_days: int | None
|
||||
created_at: str
|
||||
last_used_at: str | None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class MemoryListResponse(BaseModel):
|
||||
"""Memory list response."""
|
||||
|
||||
memories: list[MemoryItemResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class CreateMemoryRequest(BaseModel):
|
||||
"""Create memory request."""
|
||||
|
||||
type: str = Field(..., description="记忆类型")
|
||||
value: dict = Field(..., description="记忆内容")
|
||||
universe_id: str | None = Field(default=None, description="关联的故事宇宙 ID")
|
||||
weight: float | None = Field(default=None, description="权重")
|
||||
ttl_days: int | None = Field(default=None, description="过期天数")
|
||||
|
||||
|
||||
class CreateCharacterMemoryRequest(BaseModel):
|
||||
"""Create character memory request."""
|
||||
|
||||
name: str = Field(..., description="角色名称")
|
||||
description: str | None = Field(default=None, description="角色描述")
|
||||
source_story_id: int | None = Field(default=None, description="来源故事 ID")
|
||||
affinity_score: float = Field(default=1.0, ge=0.0, le=1.0, description="喜爱程度")
|
||||
universe_id: str | None = Field(default=None, description="关联的故事宇宙 ID")
|
||||
|
||||
|
||||
class CreateScaryElementRequest(BaseModel):
|
||||
"""Create scary element memory request."""
|
||||
|
||||
keyword: str = Field(..., description="回避的关键词")
|
||||
category: str = Field(default="other", description="分类")
|
||||
source_story_id: int | None = Field(default=None, description="来源故事 ID")
|
||||
|
||||
|
||||
async def _verify_profile_ownership(
|
||||
profile_id: str, user: User, db: AsyncSession
|
||||
) -> ChildProfile:
|
||||
"""验证档案所有权。"""
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="档案不存在")
|
||||
return profile
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/memories", response_model=MemoryListResponse)
|
||||
async def list_memories(
|
||||
profile_id: str,
|
||||
memory_type: str | None = None,
|
||||
universe_id: str | None = None,
|
||||
limit: int = 50,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取档案的记忆列表。"""
|
||||
await _verify_profile_ownership(profile_id, user, db)
|
||||
|
||||
memories = await memory_service.get_profile_memories(
|
||||
db=db,
|
||||
profile_id=profile_id,
|
||||
memory_type=memory_type,
|
||||
universe_id=universe_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return MemoryListResponse(
|
||||
memories=[
|
||||
MemoryItemResponse(
|
||||
id=m.id,
|
||||
type=m.type,
|
||||
value=m.value,
|
||||
base_weight=m.base_weight,
|
||||
ttl_days=m.ttl_days,
|
||||
created_at=m.created_at.isoformat() if m.created_at else "",
|
||||
last_used_at=m.last_used_at.isoformat() if m.last_used_at else None,
|
||||
)
|
||||
for m in memories
|
||||
],
|
||||
total=len(memories),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/profiles/{profile_id}/memories", response_model=MemoryItemResponse)
|
||||
async def create_memory(
|
||||
profile_id: str,
|
||||
payload: CreateMemoryRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建新的记忆项。"""
|
||||
await _verify_profile_ownership(profile_id, user, db)
|
||||
|
||||
# 验证类型
|
||||
valid_types = [
|
||||
MemoryType.RECENT_STORY,
|
||||
MemoryType.FAVORITE_CHARACTER,
|
||||
MemoryType.SCARY_ELEMENT,
|
||||
MemoryType.VOCABULARY_GROWTH,
|
||||
MemoryType.EMOTIONAL_HIGHLIGHT,
|
||||
MemoryType.READING_PREFERENCE,
|
||||
MemoryType.MILESTONE,
|
||||
MemoryType.SKILL_MASTERED,
|
||||
]
|
||||
if payload.type not in valid_types:
|
||||
raise HTTPException(status_code=400, detail=f"无效的记忆类型: {payload.type}")
|
||||
|
||||
memory = await memory_service.create_memory(
|
||||
db=db,
|
||||
profile_id=profile_id,
|
||||
memory_type=payload.type,
|
||||
value=payload.value,
|
||||
universe_id=payload.universe_id,
|
||||
weight=payload.weight,
|
||||
ttl_days=payload.ttl_days,
|
||||
)
|
||||
|
||||
return MemoryItemResponse(
|
||||
id=memory.id,
|
||||
type=memory.type,
|
||||
value=memory.value,
|
||||
base_weight=memory.base_weight,
|
||||
ttl_days=memory.ttl_days,
|
||||
created_at=memory.created_at.isoformat() if memory.created_at else "",
|
||||
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/profiles/{profile_id}/memories/character", response_model=MemoryItemResponse)
|
||||
async def create_character_memory(
|
||||
profile_id: str,
|
||||
payload: CreateCharacterMemoryRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""添加喜欢的角色。"""
|
||||
await _verify_profile_ownership(profile_id, user, db)
|
||||
|
||||
memory = await memory_service.create_character_memory(
|
||||
db=db,
|
||||
profile_id=profile_id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
source_story_id=payload.source_story_id,
|
||||
affinity_score=payload.affinity_score,
|
||||
universe_id=payload.universe_id,
|
||||
)
|
||||
|
||||
return MemoryItemResponse(
|
||||
id=memory.id,
|
||||
type=memory.type,
|
||||
value=memory.value,
|
||||
base_weight=memory.base_weight,
|
||||
ttl_days=memory.ttl_days,
|
||||
created_at=memory.created_at.isoformat() if memory.created_at else "",
|
||||
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/profiles/{profile_id}/memories/scary", response_model=MemoryItemResponse)
|
||||
async def create_scary_element_memory(
|
||||
profile_id: str,
|
||||
payload: CreateScaryElementRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""添加回避元素。"""
|
||||
await _verify_profile_ownership(profile_id, user, db)
|
||||
|
||||
memory = await memory_service.create_scary_element_memory(
|
||||
db=db,
|
||||
profile_id=profile_id,
|
||||
keyword=payload.keyword,
|
||||
category=payload.category,
|
||||
source_story_id=payload.source_story_id,
|
||||
)
|
||||
|
||||
return MemoryItemResponse(
|
||||
id=memory.id,
|
||||
type=memory.type,
|
||||
value=memory.value,
|
||||
base_weight=memory.base_weight,
|
||||
ttl_days=memory.ttl_days,
|
||||
created_at=memory.created_at.isoformat() if memory.created_at else "",
|
||||
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/profiles/{profile_id}/memories/{memory_id}")
|
||||
async def delete_memory(
|
||||
profile_id: str,
|
||||
memory_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除记忆项。"""
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db.models import MemoryItem
|
||||
|
||||
await _verify_profile_ownership(profile_id, user, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(MemoryItem).where(
|
||||
MemoryItem.id == memory_id,
|
||||
MemoryItem.child_profile_id == profile_id,
|
||||
)
|
||||
)
|
||||
memory = result.scalar_one_or_none()
|
||||
|
||||
if not memory:
|
||||
raise HTTPException(status_code=404, detail="记忆不存在")
|
||||
|
||||
await db.delete(memory)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Deleted"}
|
||||
|
||||
|
||||
@router.get("/memory-types")
|
||||
async def list_memory_types():
|
||||
"""获取所有可用的记忆类型及其配置。"""
|
||||
types = []
|
||||
for type_name, config in MemoryType.CONFIG.items():
|
||||
types.append({
|
||||
"type": type_name,
|
||||
"default_weight": config[0],
|
||||
"default_ttl_days": config[1],
|
||||
"description": config[2],
|
||||
})
|
||||
return {"types": types}
|
||||
280
backend/app/api/profiles.py
Normal file
280
backend/app/api/profiles.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""Child profile APIs."""
|
||||
|
||||
from datetime import date
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.db.database import get_db
|
||||
from app.db.models import ChildProfile, Story, StoryUniverse, User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
MAX_PROFILES_PER_USER = 5
|
||||
|
||||
|
||||
class ChildProfileCreate(BaseModel):
|
||||
"""Create profile payload."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=50)
|
||||
birth_date: date | None = None
|
||||
gender: str | None = Field(default=None, pattern="^(male|female|other)$")
|
||||
interests: list[str] = Field(default_factory=list)
|
||||
growth_themes: list[str] = Field(default_factory=list)
|
||||
avatar_url: str | None = None
|
||||
|
||||
|
||||
class ChildProfileUpdate(BaseModel):
|
||||
"""Update profile payload."""
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=50)
|
||||
birth_date: date | None = None
|
||||
gender: str | None = Field(default=None, pattern="^(male|female|other)$")
|
||||
interests: list[str] | None = None
|
||||
growth_themes: list[str] | None = None
|
||||
avatar_url: str | None = None
|
||||
|
||||
|
||||
class ChildProfileResponse(BaseModel):
|
||||
"""Profile response."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
avatar_url: str | None
|
||||
birth_date: date | None
|
||||
gender: str | None
|
||||
age: int | None
|
||||
interests: list[str]
|
||||
growth_themes: list[str]
|
||||
stories_count: int
|
||||
total_reading_time: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ChildProfileListResponse(BaseModel):
|
||||
"""Profile list response."""
|
||||
|
||||
profiles: list[ChildProfileResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class TimelineEvent(BaseModel):
|
||||
"""Timeline event item."""
|
||||
|
||||
date: str
|
||||
type: Literal["story", "achievement", "milestone"]
|
||||
title: str
|
||||
description: str | None = None
|
||||
image_url: str | None = None
|
||||
metadata: dict | None = None
|
||||
|
||||
|
||||
class TimelineResponse(BaseModel):
|
||||
"""Timeline response."""
|
||||
|
||||
events: list[TimelineEvent]
|
||||
|
||||
|
||||
@router.get("/profiles", response_model=ChildProfileListResponse)
|
||||
async def list_profiles(
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List child profiles for current user."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile)
|
||||
.where(ChildProfile.user_id == user.id)
|
||||
.order_by(ChildProfile.created_at.desc())
|
||||
)
|
||||
profiles = result.scalars().all()
|
||||
|
||||
return ChildProfileListResponse(profiles=profiles, total=len(profiles))
|
||||
|
||||
|
||||
@router.post("/profiles", response_model=ChildProfileResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_profile(
|
||||
payload: ChildProfileCreate,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a new child profile."""
|
||||
count = await db.scalar(
|
||||
select(func.count(ChildProfile.id)).where(ChildProfile.user_id == user.id)
|
||||
)
|
||||
if count and count >= MAX_PROFILES_PER_USER:
|
||||
raise HTTPException(status_code=400, detail="最多只能创建 5 个孩子档案")
|
||||
|
||||
existing = await db.scalar(
|
||||
select(ChildProfile.id).where(
|
||||
ChildProfile.user_id == user.id,
|
||||
ChildProfile.name == payload.name,
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="该档案名称已存在")
|
||||
|
||||
profile = ChildProfile(user_id=user.id, **payload.model_dump())
|
||||
db.add(profile)
|
||||
await db.commit()
|
||||
await db.refresh(profile)
|
||||
|
||||
return profile
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}", response_model=ChildProfileResponse)
|
||||
async def get_profile(
|
||||
profile_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get one child profile."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="档案不存在")
|
||||
|
||||
return profile
|
||||
|
||||
|
||||
@router.put("/profiles/{profile_id}", response_model=ChildProfileResponse)
|
||||
async def update_profile(
|
||||
profile_id: str,
|
||||
payload: ChildProfileUpdate,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update a child profile."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="档案不存在")
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
if "name" in updates:
|
||||
existing = await db.scalar(
|
||||
select(ChildProfile.id).where(
|
||||
ChildProfile.user_id == user.id,
|
||||
ChildProfile.name == updates["name"],
|
||||
ChildProfile.id != profile_id,
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="该档案名称已存在")
|
||||
|
||||
for key, value in updates.items():
|
||||
setattr(profile, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(profile)
|
||||
|
||||
return profile
|
||||
|
||||
|
||||
@router.delete("/profiles/{profile_id}")
|
||||
async def delete_profile(
|
||||
profile_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Delete a child profile."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="档案不存在")
|
||||
|
||||
await db.delete(profile)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Deleted"}
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/timeline", response_model=TimelineResponse)
|
||||
async def get_profile_timeline(
|
||||
profile_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get profile growth timeline."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="档案不存在")
|
||||
|
||||
events: list[TimelineEvent] = []
|
||||
|
||||
# 1. Milestone: Profile Created
|
||||
events.append(TimelineEvent(
|
||||
date=profile.created_at.isoformat(),
|
||||
type="milestone",
|
||||
title="初次相遇",
|
||||
description=f"创建了档案 {profile.name}"
|
||||
))
|
||||
|
||||
# 2. Stories
|
||||
stories_result = await db.execute(
|
||||
select(Story).where(Story.child_profile_id == profile_id)
|
||||
)
|
||||
for s in stories_result.scalars():
|
||||
events.append(TimelineEvent(
|
||||
date=s.created_at.isoformat(),
|
||||
type="story",
|
||||
title=s.title,
|
||||
image_url=s.image_url,
|
||||
metadata={"story_id": s.id, "mode": s.mode}
|
||||
))
|
||||
|
||||
# 3. Achievements (from Universe)
|
||||
universes_result = await db.execute(
|
||||
select(StoryUniverse).where(StoryUniverse.child_profile_id == profile_id)
|
||||
)
|
||||
for u in universes_result.scalars():
|
||||
if u.achievements:
|
||||
for ach in u.achievements:
|
||||
if isinstance(ach, dict):
|
||||
obt_at = ach.get("obtained_at")
|
||||
# Fallback
|
||||
if not obt_at:
|
||||
obt_at = u.updated_at.isoformat()
|
||||
|
||||
events.append(TimelineEvent(
|
||||
date=obt_at,
|
||||
type="achievement",
|
||||
title=f"获得成就:{ach.get('type')}",
|
||||
description=ach.get('description'),
|
||||
metadata={"universe_id": u.id, "source_story_id": ach.get("source_story_id")}
|
||||
))
|
||||
|
||||
# Sort by date desc
|
||||
events.sort(key=lambda x: x.date, reverse=True)
|
||||
|
||||
return TimelineResponse(events=events)
|
||||
120
backend/app/api/push_configs.py
Normal file
120
backend/app/api/push_configs.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Push configuration APIs."""
|
||||
|
||||
from datetime import time
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.db.database import get_db
|
||||
from app.db.models import ChildProfile, PushConfig, User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class PushConfigUpsert(BaseModel):
|
||||
"""Upsert push config payload."""
|
||||
|
||||
child_profile_id: str
|
||||
push_time: time | None = None
|
||||
push_days: list[int] | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class PushConfigResponse(BaseModel):
|
||||
"""Push config response."""
|
||||
|
||||
id: str
|
||||
child_profile_id: str
|
||||
push_time: time | None
|
||||
push_days: list[int]
|
||||
enabled: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PushConfigListResponse(BaseModel):
|
||||
"""Push config list response."""
|
||||
|
||||
configs: list[PushConfigResponse]
|
||||
total: int
|
||||
|
||||
|
||||
def _validate_push_days(push_days: list[int]) -> list[int]:
|
||||
invalid = [day for day in push_days if day < 0 or day > 6]
|
||||
if invalid:
|
||||
raise HTTPException(status_code=400, detail="推送日期必须在 0-6 之间")
|
||||
return list(dict.fromkeys(push_days))
|
||||
|
||||
|
||||
@router.get("/push-configs", response_model=PushConfigListResponse)
|
||||
async def list_push_configs(
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List push configs for current user."""
|
||||
result = await db.execute(
|
||||
select(PushConfig).where(PushConfig.user_id == user.id)
|
||||
)
|
||||
configs = result.scalars().all()
|
||||
return PushConfigListResponse(configs=configs, total=len(configs))
|
||||
|
||||
|
||||
@router.put("/push-configs", response_model=PushConfigResponse)
|
||||
async def upsert_push_config(
|
||||
payload: PushConfigUpsert,
|
||||
response: Response,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create or update push config for a child profile."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == payload.child_profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
||||
|
||||
result = await db.execute(
|
||||
select(PushConfig).where(PushConfig.child_profile_id == payload.child_profile_id)
|
||||
)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config is None:
|
||||
if payload.push_time is None or payload.push_days is None:
|
||||
raise HTTPException(status_code=400, detail="创建配置需要提供推送时间和日期")
|
||||
push_days = _validate_push_days(payload.push_days)
|
||||
config = PushConfig(
|
||||
user_id=user.id,
|
||||
child_profile_id=payload.child_profile_id,
|
||||
push_time=payload.push_time,
|
||||
push_days=push_days,
|
||||
enabled=True if payload.enabled is None else payload.enabled,
|
||||
)
|
||||
db.add(config)
|
||||
await db.commit()
|
||||
await db.refresh(config)
|
||||
response.status_code = status.HTTP_201_CREATED
|
||||
return config
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
if "push_days" in updates and updates["push_days"] is not None:
|
||||
updates["push_days"] = _validate_push_days(updates["push_days"])
|
||||
if "push_time" in updates and updates["push_time"] is None:
|
||||
raise HTTPException(status_code=400, detail="推送时间不能为空")
|
||||
|
||||
for key, value in updates.items():
|
||||
if key == "child_profile_id":
|
||||
continue
|
||||
if value is not None:
|
||||
setattr(config, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(config)
|
||||
return config
|
||||
120
backend/app/api/reading_events.py
Normal file
120
backend/app/api/reading_events.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Reading event APIs."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.db.database import get_db
|
||||
from app.db.models import ChildProfile, MemoryItem, ReadingEvent, Story, User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
EVENT_WEIGHTS: dict[str, float] = {
|
||||
"completed": 1.0,
|
||||
"replayed": 1.5,
|
||||
"started": 0.1,
|
||||
"skipped": -0.5,
|
||||
}
|
||||
|
||||
|
||||
class ReadingEventCreate(BaseModel):
|
||||
"""Reading event payload."""
|
||||
|
||||
child_profile_id: str
|
||||
story_id: int | None = None
|
||||
event_type: Literal["started", "completed", "skipped", "replayed"]
|
||||
reading_time: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class ReadingEventResponse(BaseModel):
|
||||
"""Reading event response."""
|
||||
|
||||
id: int
|
||||
child_profile_id: str
|
||||
story_id: int | None
|
||||
event_type: str
|
||||
reading_time: int
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
@router.post("/reading-events", response_model=ReadingEventResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_reading_event(
|
||||
payload: ReadingEventCreate,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a reading event and update profile stats/memory."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == payload.child_profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
||||
|
||||
story = None
|
||||
if payload.story_id is not None:
|
||||
result = await db.execute(
|
||||
select(Story).where(
|
||||
Story.id == payload.story_id,
|
||||
Story.user_id == user.id,
|
||||
)
|
||||
)
|
||||
story = result.scalar_one_or_none()
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="故事不存在")
|
||||
|
||||
if payload.reading_time:
|
||||
profile.total_reading_time = (profile.total_reading_time or 0) + payload.reading_time
|
||||
|
||||
if payload.event_type in {"completed", "replayed"} and payload.story_id is not None:
|
||||
existing = await db.scalar(
|
||||
select(ReadingEvent.id).where(
|
||||
ReadingEvent.child_profile_id == payload.child_profile_id,
|
||||
ReadingEvent.story_id == payload.story_id,
|
||||
ReadingEvent.event_type.in_(["completed", "replayed"]),
|
||||
)
|
||||
)
|
||||
if existing is None:
|
||||
profile.stories_count = (profile.stories_count or 0) + 1
|
||||
|
||||
event = ReadingEvent(
|
||||
child_profile_id=payload.child_profile_id,
|
||||
story_id=payload.story_id,
|
||||
event_type=payload.event_type,
|
||||
reading_time=payload.reading_time,
|
||||
)
|
||||
db.add(event)
|
||||
|
||||
weight = EVENT_WEIGHTS.get(payload.event_type, 0.0)
|
||||
if story and weight > 0:
|
||||
db.add(
|
||||
MemoryItem(
|
||||
child_profile_id=payload.child_profile_id,
|
||||
universe_id=story.universe_id,
|
||||
type="recent_story",
|
||||
value={
|
||||
"story_id": story.id,
|
||||
"title": story.title,
|
||||
"event_type": payload.event_type,
|
||||
},
|
||||
base_weight=weight,
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
ttl_days=90,
|
||||
)
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(event)
|
||||
|
||||
return event
|
||||
605
backend/app/api/stories.py
Normal file
605
backend/app/api/stories.py
Normal file
@@ -0,0 +1,605 @@
|
||||
"""Story related APIs."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Literal
|
||||
|
||||
from cachetools import TTLCache
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.core.logging import get_logger
|
||||
from app.db.database import get_db
|
||||
from app.db.models import ChildProfile, Story, StoryUniverse, User
|
||||
from app.services.provider_router import (
|
||||
generate_image,
|
||||
generate_story_content,
|
||||
generate_storybook,
|
||||
text_to_speech,
|
||||
)
|
||||
from app.tasks.achievements import extract_story_achievements
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
MAX_DATA_LENGTH = 2000
|
||||
MAX_EDU_THEME_LENGTH = 200
|
||||
MAX_TTS_LENGTH = 4000
|
||||
|
||||
RATE_LIMIT_WINDOW = 60 # seconds
|
||||
RATE_LIMIT_REQUESTS = 10
|
||||
RATE_LIMIT_CACHE_SIZE = 10000 # 最大跟踪用户数
|
||||
|
||||
_request_log: TTLCache[str, list[float]] = TTLCache(
|
||||
maxsize=RATE_LIMIT_CACHE_SIZE, ttl=RATE_LIMIT_WINDOW * 2
|
||||
)
|
||||
|
||||
|
||||
def _check_rate_limit(user_id: str):
|
||||
now = time.time()
|
||||
timestamps = _request_log.get(user_id, [])
|
||||
timestamps = [t for t in timestamps if now - t <= RATE_LIMIT_WINDOW]
|
||||
if len(timestamps) >= RATE_LIMIT_REQUESTS:
|
||||
raise HTTPException(status_code=429, detail="Too many requests, please slow down.")
|
||||
timestamps.append(now)
|
||||
_request_log[user_id] = timestamps
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
"""Story generation request."""
|
||||
|
||||
type: Literal["keywords", "full_story"]
|
||||
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
|
||||
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
|
||||
child_profile_id: str | None = None
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class StoryResponse(BaseModel):
|
||||
"""Story response."""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
story_text: str
|
||||
cover_prompt: str | None
|
||||
image_url: str | None
|
||||
mode: str
|
||||
child_profile_id: str | None = None
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class StoryListItem(BaseModel):
|
||||
"""Story list item."""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
image_url: str | None
|
||||
created_at: str
|
||||
mode: str
|
||||
|
||||
|
||||
class FullStoryResponse(BaseModel):
|
||||
"""完整故事响应(含图片和音频状态)。"""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
story_text: str
|
||||
cover_prompt: str | None
|
||||
image_url: str | None
|
||||
audio_ready: bool
|
||||
mode: str
|
||||
errors: dict[str, str | None] = Field(default_factory=dict)
|
||||
child_profile_id: str | None = None
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
from app.services.memory_service import build_enhanced_memory_context
|
||||
|
||||
|
||||
async def _validate_profile_and_universe(
|
||||
request: GenerateRequest,
|
||||
user: User,
|
||||
db: AsyncSession,
|
||||
) -> tuple[str | None, str | None]:
|
||||
if not request.child_profile_id and not request.universe_id:
|
||||
return None, None
|
||||
|
||||
profile_id = request.child_profile_id
|
||||
universe_id = request.universe_id
|
||||
|
||||
if profile_id:
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
||||
|
||||
if universe_id:
|
||||
result = await db.execute(
|
||||
select(StoryUniverse)
|
||||
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
|
||||
.where(
|
||||
StoryUniverse.id == universe_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
universe = result.scalar_one_or_none()
|
||||
if not universe:
|
||||
raise HTTPException(status_code=404, detail="故事宇宙不存在")
|
||||
if profile_id and universe.child_profile_id != profile_id:
|
||||
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
|
||||
if not profile_id:
|
||||
profile_id = universe.child_profile_id
|
||||
|
||||
return profile_id, universe_id
|
||||
|
||||
|
||||
@router.post("/stories/generate", response_model=StoryResponse)
|
||||
async def generate_story(
|
||||
request: GenerateRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Generate or enhance a story."""
|
||||
_check_rate_limit(user.id)
|
||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
try:
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
|
||||
|
||||
story = Story(
|
||||
user_id=user.id,
|
||||
child_profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
title=result.title,
|
||||
story_text=result.story_text,
|
||||
cover_prompt=result.cover_prompt_suggestion,
|
||||
mode=result.mode,
|
||||
)
|
||||
db.add(story)
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
return StoryResponse(
|
||||
id=story.id,
|
||||
title=story.title,
|
||||
story_text=story.story_text,
|
||||
cover_prompt=story.cover_prompt,
|
||||
image_url=story.image_url,
|
||||
mode=story.mode,
|
||||
child_profile_id=story.child_profile_id,
|
||||
universe_id=story.universe_id,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stories/generate/full", response_model=FullStoryResponse)
|
||||
async def generate_story_full(
|
||||
request: GenerateRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""生成完整故事(故事 + 并行生成图片和音频)。
|
||||
|
||||
部分成功策略:故事必须成功,图片/音频失败不影响整体。
|
||||
"""
|
||||
_check_rate_limit(user.id)
|
||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
# Step 1: 故事生成(必须成功)
|
||||
try:
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("story_generation_failed", error=str(exc))
|
||||
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
|
||||
|
||||
# 保存故事
|
||||
story = Story(
|
||||
user_id=user.id,
|
||||
child_profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
title=result.title,
|
||||
story_text=result.story_text,
|
||||
cover_prompt=result.cover_prompt_suggestion,
|
||||
mode=result.mode,
|
||||
)
|
||||
db.add(story)
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
# Step 2: 生成封面图片(音频按需生成,避免浪费)
|
||||
errors: dict[str, str | None] = {}
|
||||
image_url: str | None = None
|
||||
|
||||
if story.cover_prompt:
|
||||
try:
|
||||
image_url = await generate_image(story.cover_prompt)
|
||||
story.image_url = image_url
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
errors["image"] = str(exc)
|
||||
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
|
||||
|
||||
# 注意:音频不在此处预生成,用户通过 /api/audio/{id} 按需获取
|
||||
# 这样避免生成后丢弃造成的成本浪费
|
||||
|
||||
return FullStoryResponse(
|
||||
id=story.id,
|
||||
title=story.title,
|
||||
story_text=story.story_text,
|
||||
cover_prompt=story.cover_prompt,
|
||||
image_url=image_url,
|
||||
audio_ready=False, # 音频需要用户主动请求
|
||||
mode=story.mode,
|
||||
errors=errors,
|
||||
child_profile_id=story.child_profile_id,
|
||||
universe_id=story.universe_id,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stories/generate/stream")
|
||||
async def generate_story_stream(
|
||||
request: GenerateRequest,
|
||||
req: Request,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""流式生成故事(SSE)。
|
||||
|
||||
事件流程:
|
||||
- started: 返回 story_id
|
||||
- story_ready: 返回 title, content
|
||||
- story_failed: 返回 error
|
||||
- image_ready: 返回 image_url
|
||||
- image_failed: 返回 error
|
||||
- complete: 结束流
|
||||
"""
|
||||
_check_rate_limit(user.id)
|
||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[dict, None]:
|
||||
story_id = str(uuid.uuid4())
|
||||
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
|
||||
|
||||
# Step 1: 生成故事
|
||||
try:
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("sse_story_generation_failed", error=str(e))
|
||||
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
|
||||
return
|
||||
|
||||
# 保存故事
|
||||
story = Story(
|
||||
user_id=user.id,
|
||||
child_profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
title=result.title,
|
||||
story_text=result.story_text,
|
||||
cover_prompt=result.cover_prompt_suggestion,
|
||||
mode=result.mode,
|
||||
)
|
||||
db.add(story)
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
yield {
|
||||
"event": "story_ready",
|
||||
"data": json.dumps({
|
||||
"id": story.id,
|
||||
"title": story.title,
|
||||
"content": story.story_text,
|
||||
"cover_prompt": story.cover_prompt,
|
||||
"mode": story.mode,
|
||||
"child_profile_id": story.child_profile_id,
|
||||
"universe_id": story.universe_id,
|
||||
}),
|
||||
}
|
||||
|
||||
# Step 2: 并行生成图片(音频按需)
|
||||
if story.cover_prompt:
|
||||
try:
|
||||
image_url = await generate_image(story.cover_prompt)
|
||||
story.image_url = image_url
|
||||
await db.commit()
|
||||
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
|
||||
except Exception as e:
|
||||
logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e))
|
||||
yield {"event": "image_failed", "data": json.dumps({"error": str(e)})}
|
||||
|
||||
yield {"event": "complete", "data": json.dumps({"story_id": story.id})}
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
|
||||
# ==================== Storybook API ====================
|
||||
|
||||
|
||||
class StorybookRequest(BaseModel):
|
||||
"""Storybook 生成请求。"""
|
||||
|
||||
keywords: str = Field(..., min_length=1, max_length=200)
|
||||
page_count: int = Field(default=6, ge=4, le=12)
|
||||
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
|
||||
generate_images: bool = Field(default=False, description="是否同时生成插图")
|
||||
child_profile_id: str | None = None
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class StorybookPageResponse(BaseModel):
|
||||
"""故事书单页响应。"""
|
||||
|
||||
page_number: int
|
||||
text: str
|
||||
image_prompt: str
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
class StorybookResponse(BaseModel):
|
||||
"""故事书响应。"""
|
||||
|
||||
id: int | None = None
|
||||
title: str
|
||||
main_character: str
|
||||
art_style: str
|
||||
pages: list[StorybookPageResponse]
|
||||
cover_prompt: str
|
||||
cover_url: str | None = None
|
||||
|
||||
|
||||
@router.post("/storybook/generate", response_model=StorybookResponse)
|
||||
async def generate_storybook_api(
|
||||
request: StorybookRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""生成分页故事书并保存。
|
||||
|
||||
返回故事书结构,包含每页文字和图像提示词。
|
||||
"""
|
||||
_check_rate_limit(user.id)
|
||||
|
||||
# 验证档案和宇宙
|
||||
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数
|
||||
# 这里我们直接手动验证,或重构验证函数。为了简单,手动调用部分逻辑。
|
||||
|
||||
# 构建临时的 GenerateRequest 用于验证验证函数签名(或者直接手动查库更好)
|
||||
profile_id = request.child_profile_id
|
||||
universe_id = request.universe_id
|
||||
|
||||
if profile_id:
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
if not result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
||||
|
||||
if universe_id:
|
||||
result = await db.execute(
|
||||
select(StoryUniverse)
|
||||
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
|
||||
.where(
|
||||
StoryUniverse.id == universe_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
universe = result.scalar_one_or_none()
|
||||
if not universe:
|
||||
raise HTTPException(status_code=404, detail="故事宇宙不存在")
|
||||
if profile_id and universe.child_profile_id != profile_id:
|
||||
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
|
||||
if not profile_id:
|
||||
profile_id = universe.child_profile_id
|
||||
|
||||
logger.info(
|
||||
"storybook_request",
|
||||
user_id=user.id,
|
||||
keywords=request.keywords,
|
||||
page_count=request.page_count,
|
||||
profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
)
|
||||
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
try:
|
||||
# 注意:generate_storybook 目前可能不支持记忆上下文注入
|
||||
# 我们需要看看 generate_storybook 的签名
|
||||
# 如果不支持,记忆功能在绘本模式下暂不可用,但基本参数传递是支持的
|
||||
storybook = await generate_storybook(
|
||||
keywords=request.keywords,
|
||||
page_count=request.page_count,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("storybook_generation_failed", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
|
||||
|
||||
# ==============================================================================
|
||||
# 核心升级: 并行全量生成 (Parallel Full Rendering)
|
||||
# ==============================================================================
|
||||
final_cover_url = storybook.cover_url
|
||||
|
||||
if request.generate_images:
|
||||
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
|
||||
|
||||
# 1. 准备所有生图任务 (封面 + 所有内页)
|
||||
tasks = []
|
||||
|
||||
# 封面任务
|
||||
async def _gen_cover():
|
||||
if storybook.cover_prompt and not storybook.cover_url:
|
||||
try:
|
||||
return await generate_image(storybook.cover_prompt, db=db)
|
||||
except Exception as e:
|
||||
logger.warning("cover_gen_failed", error=str(e))
|
||||
return storybook.cover_url
|
||||
|
||||
tasks.append(_gen_cover())
|
||||
|
||||
# 内页任务
|
||||
async def _gen_page(page):
|
||||
if page.image_prompt and not page.image_url:
|
||||
try:
|
||||
url = await generate_image(page.image_prompt, db=db)
|
||||
page.image_url = url
|
||||
except Exception as e:
|
||||
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
|
||||
|
||||
for page in storybook.pages:
|
||||
tasks.append(_gen_page(page))
|
||||
|
||||
# 2. 并发执行所有任务
|
||||
# 使用 return_exceptions=True 防止单张失败影响整体
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 3. 更新封面结果 (results[0] 是封面任务的返回值)
|
||||
cover_res = results[0]
|
||||
if isinstance(cover_res, str):
|
||||
final_cover_url = cover_res
|
||||
|
||||
logger.info("storybook_parallel_generation_complete")
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
# 构建并保存 Story 对象
|
||||
# 将 pages 对象转换为字典列表以存入 JSON 字段
|
||||
pages_data = [
|
||||
{
|
||||
"page_number": p.page_number,
|
||||
"text": p.text,
|
||||
"image_prompt": p.image_prompt,
|
||||
"image_url": p.image_url,
|
||||
}
|
||||
for p in storybook.pages
|
||||
]
|
||||
|
||||
story = Story(
|
||||
user_id=user.id,
|
||||
child_profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
title=storybook.title,
|
||||
mode="storybook",
|
||||
pages=pages_data, # 存入 JSON 字段
|
||||
story_text=None, # 绘本模式下,主文本可为空,或者可以存个摘要
|
||||
cover_prompt=storybook.cover_prompt,
|
||||
image_url=final_cover_url,
|
||||
)
|
||||
db.add(story)
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
# 构建响应 (使用更新后的 pages_data)
|
||||
response_pages = [
|
||||
StorybookPageResponse(
|
||||
page_number=p["page_number"],
|
||||
text=p["text"],
|
||||
image_prompt=p["image_prompt"],
|
||||
image_url=p.get("image_url"),
|
||||
)
|
||||
for p in pages_data
|
||||
]
|
||||
|
||||
return StorybookResponse(
|
||||
id=story.id,
|
||||
title=storybook.title,
|
||||
main_character=storybook.main_character,
|
||||
art_style=storybook.art_style,
|
||||
pages=response_pages,
|
||||
cover_prompt=storybook.cover_prompt,
|
||||
cover_url=final_cover_url,
|
||||
)
|
||||
|
||||
|
||||
class AchievementItem(BaseModel):
|
||||
type: str
|
||||
description: str
|
||||
obtained_at: str | None = None
|
||||
|
||||
|
||||
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
|
||||
async def get_story_achievements(
|
||||
story_id: int,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get achievements unlocked by a specific story."""
|
||||
# 使用 joinedload 避免 N+1 查询
|
||||
result = await db.execute(
|
||||
select(Story)
|
||||
.options(joinedload(Story.story_universe))
|
||||
.where(Story.id == story_id, Story.user_id == user.id)
|
||||
)
|
||||
story = result.scalar_one_or_none()
|
||||
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
|
||||
if not story.universe_id or not story.story_universe:
|
||||
return []
|
||||
|
||||
universe = story.story_universe
|
||||
if not universe.achievements:
|
||||
return []
|
||||
|
||||
results = []
|
||||
for ach in universe.achievements:
|
||||
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
|
||||
results.append(AchievementItem(
|
||||
type=ach.get("type", "Unknown"),
|
||||
description=ach.get("description", ""),
|
||||
obtained_at=ach.get("obtained_at")
|
||||
))
|
||||
|
||||
return results
|
||||
201
backend/app/api/universes.py
Normal file
201
backend/app/api/universes.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Story universe APIs."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.db.database import get_db
|
||||
from app.db.models import ChildProfile, StoryUniverse, User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class StoryUniverseCreate(BaseModel):
|
||||
"""Create universe payload."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
protagonist: dict[str, Any]
|
||||
recurring_characters: list[dict[str, Any]] = Field(default_factory=list)
|
||||
world_settings: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class StoryUniverseUpdate(BaseModel):
|
||||
"""Update universe payload."""
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=100)
|
||||
protagonist: dict[str, Any] | None = None
|
||||
recurring_characters: list[dict[str, Any]] | None = None
|
||||
world_settings: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AchievementCreate(BaseModel):
|
||||
"""Achievement payload."""
|
||||
|
||||
type: str = Field(..., min_length=1, max_length=50)
|
||||
description: str = Field(..., min_length=1, max_length=200)
|
||||
|
||||
|
||||
class StoryUniverseResponse(BaseModel):
|
||||
"""Universe response."""
|
||||
|
||||
id: str
|
||||
child_profile_id: str
|
||||
name: str
|
||||
protagonist: dict[str, Any]
|
||||
recurring_characters: list[dict[str, Any]]
|
||||
world_settings: dict[str, Any]
|
||||
achievements: list[dict[str, Any]]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class StoryUniverseListResponse(BaseModel):
|
||||
"""Universe list response."""
|
||||
|
||||
universes: list[StoryUniverseResponse]
|
||||
total: int
|
||||
|
||||
|
||||
async def _get_profile_or_404(
|
||||
profile_id: str,
|
||||
user: User,
|
||||
db: AsyncSession,
|
||||
) -> ChildProfile:
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="档案不存在")
|
||||
return profile
|
||||
|
||||
|
||||
async def _get_universe_or_404(
|
||||
universe_id: str,
|
||||
user: User,
|
||||
db: AsyncSession,
|
||||
) -> StoryUniverse:
|
||||
result = await db.execute(
|
||||
select(StoryUniverse)
|
||||
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
|
||||
.where(
|
||||
StoryUniverse.id == universe_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
universe = result.scalar_one_or_none()
|
||||
if not universe:
|
||||
raise HTTPException(status_code=404, detail="宇宙不存在")
|
||||
return universe
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/universes", response_model=StoryUniverseListResponse)
|
||||
async def list_universes(
|
||||
profile_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List universes for a child profile."""
|
||||
await _get_profile_or_404(profile_id, user, db)
|
||||
result = await db.execute(
|
||||
select(StoryUniverse)
|
||||
.where(StoryUniverse.child_profile_id == profile_id)
|
||||
.order_by(StoryUniverse.updated_at.desc())
|
||||
)
|
||||
universes = result.scalars().all()
|
||||
return StoryUniverseListResponse(universes=universes, total=len(universes))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/profiles/{profile_id}/universes",
|
||||
response_model=StoryUniverseResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_universe(
|
||||
profile_id: str,
|
||||
payload: StoryUniverseCreate,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a story universe."""
|
||||
await _get_profile_or_404(profile_id, user, db)
|
||||
universe = StoryUniverse(child_profile_id=profile_id, **payload.model_dump())
|
||||
db.add(universe)
|
||||
await db.commit()
|
||||
await db.refresh(universe)
|
||||
return universe
|
||||
|
||||
|
||||
@router.get("/universes/{universe_id}", response_model=StoryUniverseResponse)
|
||||
async def get_universe(
|
||||
universe_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get one universe."""
|
||||
universe = await _get_universe_or_404(universe_id, user, db)
|
||||
return universe
|
||||
|
||||
|
||||
@router.put("/universes/{universe_id}", response_model=StoryUniverseResponse)
|
||||
async def update_universe(
|
||||
universe_id: str,
|
||||
payload: StoryUniverseUpdate,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update a story universe."""
|
||||
universe = await _get_universe_or_404(universe_id, user, db)
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
for key, value in updates.items():
|
||||
setattr(universe, key, value)
|
||||
await db.commit()
|
||||
await db.refresh(universe)
|
||||
return universe
|
||||
|
||||
|
||||
@router.delete("/universes/{universe_id}")
|
||||
async def delete_universe(
|
||||
universe_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Delete a story universe."""
|
||||
universe = await _get_universe_or_404(universe_id, user, db)
|
||||
await db.delete(universe)
|
||||
await db.commit()
|
||||
return {"message": "Deleted"}
|
||||
|
||||
|
||||
@router.post("/universes/{universe_id}/achievements", response_model=StoryUniverseResponse)
|
||||
async def add_achievement(
|
||||
universe_id: str,
|
||||
payload: AchievementCreate,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Add an achievement to a universe."""
|
||||
universe = await _get_universe_or_404(universe_id, user, db)
|
||||
|
||||
achievements = list(universe.achievements or [])
|
||||
key = (payload.type.strip(), payload.description.strip())
|
||||
existing = {
|
||||
(str(item.get("type", "")).strip(), str(item.get("description", "")).strip())
|
||||
for item in achievements
|
||||
if isinstance(item, dict)
|
||||
}
|
||||
if key not in existing:
|
||||
achievements.append({"type": key[0], "description": key[1]})
|
||||
universe.achievements = achievements
|
||||
await db.commit()
|
||||
await db.refresh(universe)
|
||||
|
||||
return universe
|
||||
Reference in New Issue
Block a user