Initial commit: clean project structure

- Backend: FastAPI + SQLAlchemy + Celery (Python 3.11+)
- Frontend: Vue 3 + TypeScript + Pinia + Tailwind
- Admin Frontend: separate Vue 3 app for management
- Docker Compose: orchestration of 9 services
- Specs: design prototypes, memory system PRD, product roadmap

Cleanup performed:
- Removed temporary debug scripts from backend root
- Removed deprecated admin_app.py (embedded UI)
- Removed duplicate docs from admin-frontend
- Updated .gitignore for Vite cache and egg-info
Author: zhangtuo
Date: 2026-01-20 18:20:03 +08:00
Commit: e9d7f8832a
241 changed files with 33070 additions and 0 deletions

backend/.env.example

@@ -0,0 +1,115 @@
# ==============================================
# DREAMWEAVER environment variable template
# ==============================================
# Usage:
# 1. Copy this file to .env
# 2. Fill in your API keys
# 3. Start the stack with docker-compose.yml
# ==============================================
# ----------------------------------------------
# 1. Infrastructure [required]
# ----------------------------------------------
# ⚠️ No changes are needed when starting via Docker; the defaults work as-is
# ⚠️ Only edit this section to connect to an external database
POSTGRES_USER=dreamweaver
POSTGRES_PASSWORD=dreamweaver_password
POSTGRES_DB=dreamweaver_db
POSTGRES_PORT=5432
REDIS_PORT=6379
DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
CELERY_BROKER_URL=redis://redis:6379/0
CELERY_RESULT_BACKEND=redis://redis:6379/0
# Web Security
SECRET_KEY=change-me-to-a-secure-random-string-in-production
DEBUG=true
# ----------------------------------------------
# 2. AI Engines [core]
# ----------------------------------------------
# [Routing strategy]
# Default provider lists, ordered by priority
# Text generation: Gemini first, then OpenAI
TEXT_PROVIDERS=["gemini", "openai"]
# Image generation: CQTAI first (Flux/NanoBanana)
IMAGE_PROVIDERS=["cqtai"]
# Speech synthesis: MiniMax first, then ElevenLabs, then EdgeTTS (free)
TTS_PROVIDERS=["minimax", "elevenlabs", "edge_tts"]
# [Model parameters]
TEXT_MODEL=gemini-2.0-flash
IMAGE_MODEL=nano-banana
IMAGE_RESOLUTION=1K
# TTS_MODEL=speech-2.6-turbo (MiniMax) / zh-CN-XiaoxiaoNeural (Edge)
# [API key pool]
# Fill in the keys you have; leave the rest empty
# ⚠️ Note: unless you use a relay (OneAPI) or a private enterprise endpoint, leave API_BASE empty (the official address is used automatically)
# Google Gemini
TEXT_API_KEY=
TEXT_API_BASE=
# CQTAI / GoQuantum (Image)
CQTAI_API_KEY=
# CQTAI_API_BASE=https://api.cqtai.com/v1
# Antigravity (Image - OpenAI Compatible)
ANTIGRAVITY_API_KEY=
ANTIGRAVITY_API_BASE=http://127.0.0.1:8045/v1
# Models: gemini-3-pro-image, gemini-3-pro-image-16-9, etc.
# MiniMax (TTS)
MINIMAX_API_KEY=
# MINIMAX_GROUP_ID is required by the MiniMax v1/v2 APIs (usually visible in the MiniMax console)
MINIMAX_GROUP_ID=
MINIMAX_API_BASE=
# ElevenLabs (TTS)
ELEVENLABS_API_KEY=
# ELEVENLABS_API_BASE=https://api.elevenlabs.io/v1
# OpenAI (if needed)
OPENAI_API_KEY=
OPENAI_API_BASE=
# ----------------------------------------------
# 3. OAuth Config [optional]
# ----------------------------------------------
# If a pair is left empty, that login method is unavailable
GITHUB_CLIENT_ID=
GITHUB_CLIENT_SECRET=
GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET=
# ----------------------------------------------
# 4. Admin Console
# ----------------------------------------------
# Enable the /admin routes and API (recommended: false in production)
ENABLE_ADMIN_CONSOLE=true
# Admin Basic Auth credentials
ADMIN_USERNAME=admin
ADMIN_PASSWORD=admin
# ----------------------------------------------
# 5. Deployment & Network
# ----------------------------------------------
# [External base URL]
# Used for OAuth callback validation (maps to port 52000 in docker-compose)
BASE_URL=http://localhost:52000
# [CORS allowlist]
# Covers the user frontend (52080), the admin frontend (52888), and local dev ports
CORS_ORIGINS=["http://localhost:52080", "http://localhost:52888", "http://localhost:5173", "http://localhost:5174"]
# [Local dev override]
# If you run `python -m uvicorn ...` directly on the host instead of via Docker,
# uncomment the lines below to connect to the localhost database:
# DATABASE_URL=postgresql+asyncpg://dreamweaver:dreamweaver_password@localhost:52432/dreamweaver_db
# CELERY_BROKER_URL=redis://localhost:52379/0

backend/.gitignore

@@ -0,0 +1,27 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
.venv/
venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
# Environment variables
.env
# Tests
.pytest_cache/
.coverage
htmlcov/
# Misc
*.log
.DS_Store

backend/Dockerfile

@@ -0,0 +1,27 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies (if needed)
# RUN apt-get update && apt-get install -y gcc libpq-dev && rm -rf /var/lib/apt/lists/*
# Copy project metadata
COPY pyproject.toml .
# Copy source code
COPY app ./app
COPY alembic ./alembic
COPY alembic.ini .
# Install dependencies
# pip installs the current directory (.); dependencies are resolved from pyproject.toml
RUN pip install --no-cache-dir .
# Create the static files directory (for generated images)
RUN mkdir -p static/images
# Expose the port
EXPOSE 8000
# Startup command
# For production, prefer gunicorn or uvicorn --workers
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

backend/alembic.ini

@@ -0,0 +1,38 @@
[alembic]
script_location = alembic
sqlalchemy.url = postgresql+asyncpg://user:password@localhost/db
[post_write_hooks]
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s

backend/alembic/README.md

@@ -0,0 +1,20 @@
# Alembic usage
1. Install the dependency (inside the backend virtualenv):
```
pip install alembic
```
2. Set the environment so that `DATABASE_URL` points at the target database.
3. Run the migrations:
```
alembic upgrade head
```
4. Generate a new migration (when models change):
```
alembic revision -m "message" --autogenerate
```
Note: `alembic/env.py` reads the database URL from `app.core.config` and imports the admin/provider models.

backend/alembic/env.py

@@ -0,0 +1,68 @@
import asyncio
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from alembic import context
from app.core.config import settings
from app.db import models, admin_models # ensure models are imported
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
fileConfig(config.config_file_name)
# override sqlalchemy.url from settings
config.set_main_option("sqlalchemy.url", settings.database_url)
target_metadata = models.Base.metadata
def run_migrations_offline():
"""Run migrations in 'offline' mode."""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection):
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online():
"""Run migrations in 'online' mode."""
connectable = async_engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
connect_args={"statement_cache_size": 0},
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())


@@ -0,0 +1,45 @@
"""init providers and story mode"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0001_init_providers"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"providers",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("name", sa.String(length=100), nullable=False),
sa.Column("type", sa.String(length=50), nullable=False),
sa.Column("adapter", sa.String(length=100), nullable=False),
sa.Column("model", sa.String(length=200), nullable=True),
sa.Column("api_base", sa.String(length=300), nullable=True),
sa.Column("timeout_ms", sa.Integer(), server_default="60000", nullable=False),
sa.Column("max_retries", sa.Integer(), server_default="1", nullable=False),
sa.Column("weight", sa.Integer(), server_default="1", nullable=False),
sa.Column("priority", sa.Integer(), server_default="0", nullable=False),
sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False),
sa.Column("config_ref", sa.String(length=100), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_by", sa.String(length=100), nullable=True),
)
with op.batch_alter_table("stories", schema=None) as batch_op:
batch_op.add_column(
sa.Column("mode", sa.String(length=20), server_default="generated", nullable=False)
)
batch_op.alter_column("mode", server_default=None)
def downgrade() -> None:
with op.batch_alter_table("stories", schema=None) as batch_op:
batch_op.drop_column("mode")
op.drop_table("providers")


@@ -0,0 +1,29 @@
"""add api_key to providers
Revision ID: 0002_add_api_key_to_providers
Revises: 0001_init_providers_and_story_mode
Create Date: 2025-01-01
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0002_add_api_key"
down_revision = "0001_init_providers"
branch_labels = None
depends_on = None
def upgrade() -> None:
    # Add a nullable api_key column; used in preference to config_ref
with op.batch_alter_table("providers", schema=None) as batch_op:
batch_op.add_column(
sa.Column("api_key", sa.String(length=500), nullable=True)
)
def downgrade() -> None:
with op.batch_alter_table("providers", schema=None) as batch_op:
batch_op.drop_column("api_key")


@@ -0,0 +1,100 @@
"""add provider monitoring tables
Revision ID: 0003_add_monitoring
Revises: 0002_add_api_key
Create Date: 2025-01-01
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0003_add_monitoring"
down_revision = "0002_add_api_key"
branch_labels = None
depends_on = None
def upgrade() -> None:
    # Create the provider_metrics table
op.create_table(
"provider_metrics",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("provider_id", sa.String(length=36), nullable=False),
sa.Column(
"timestamp",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("success", sa.Boolean(), nullable=False),
sa.Column("latency_ms", sa.Integer(), nullable=True),
sa.Column("cost_usd", sa.Numeric(precision=10, scale=6), nullable=True),
sa.Column("error_message", sa.Text(), nullable=True),
sa.Column("request_id", sa.String(length=100), nullable=True),
sa.ForeignKeyConstraint(
["provider_id"],
["providers.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_provider_metrics_provider_id",
"provider_metrics",
["provider_id"],
unique=False,
)
op.create_index(
"ix_provider_metrics_timestamp",
"provider_metrics",
["timestamp"],
unique=False,
)
    # Create the provider_health table
op.create_table(
"provider_health",
sa.Column("provider_id", sa.String(length=36), nullable=False),
sa.Column("is_healthy", sa.Boolean(), server_default=sa.text("true"), nullable=True),
sa.Column("last_check", sa.DateTime(timezone=True), nullable=True),
sa.Column("consecutive_failures", sa.Integer(), server_default=sa.text("0"), nullable=True),
sa.Column("last_error", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(
["provider_id"],
["providers.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("provider_id"),
)
    # Create the provider_secrets table
op.create_table(
"provider_secrets",
sa.Column("id", sa.String(length=36), nullable=False),
sa.Column("name", sa.String(length=100), nullable=False),
sa.Column("encrypted_value", sa.Text(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
def downgrade() -> None:
op.drop_table("provider_secrets")
op.drop_table("provider_health")
op.drop_index("ix_provider_metrics_timestamp", table_name="provider_metrics")
op.drop_index("ix_provider_metrics_provider_id", table_name="provider_metrics")
op.drop_table("provider_metrics")


@@ -0,0 +1,42 @@
"""add child profiles
Revision ID: 0004_add_child_profiles
Revises: 0003_add_monitoring
Create Date: 2025-12-22
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0004_add_child_profiles"
down_revision = "0003_add_monitoring"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"child_profiles",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("user_id", sa.String(255), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("name", sa.String(50), nullable=False),
sa.Column("avatar_url", sa.String(500)),
sa.Column("birth_date", sa.Date()),
sa.Column("gender", sa.String(10)),
sa.Column("interests", sa.JSON(), server_default="[]", nullable=False),
sa.Column("growth_themes", sa.JSON(), server_default="[]", nullable=False),
sa.Column("reading_preferences", sa.JSON(), server_default="{}", nullable=False),
sa.Column("stories_count", sa.Integer(), server_default="0", nullable=False),
sa.Column("total_reading_time", sa.Integer(), server_default="0", nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.UniqueConstraint("user_id", "name", name="uq_child_profile_user_name"),
)
op.create_index("idx_child_profiles_user_id", "child_profiles", ["user_id"])
def downgrade():
op.drop_index("idx_child_profiles_user_id", table_name="child_profiles")
op.drop_table("child_profiles")


@@ -0,0 +1,67 @@
"""add story universes and story links
Revision ID: 0005_add_story_universes_and_story_links
Revises: 0004_add_child_profiles
Create Date: 2025-12-22
"""
from alembic import op
import sqlalchemy as sa
revision = "0005_add_story_universes_and_story_links"
down_revision = "0004_add_child_profiles"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"story_universes",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column(
"child_profile_id",
sa.String(36),
sa.ForeignKey("child_profiles.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("name", sa.String(100), nullable=False),
sa.Column("protagonist", sa.JSON(), nullable=False),
sa.Column("recurring_characters", sa.JSON(), server_default="[]", nullable=False),
sa.Column("world_settings", sa.JSON(), server_default="{}", nullable=False),
sa.Column("achievements", sa.JSON(), server_default="[]", nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_index("idx_story_universes_child_id", "story_universes", ["child_profile_id"])
op.create_index("idx_story_universes_updated_at", "story_universes", ["updated_at"])
op.add_column("stories", sa.Column("child_profile_id", sa.String(36), nullable=True))
op.add_column("stories", sa.Column("universe_id", sa.String(36), nullable=True))
op.create_foreign_key(
"fk_stories_child_profile",
"stories",
"child_profiles",
["child_profile_id"],
["id"],
ondelete="SET NULL",
)
op.create_foreign_key(
"fk_stories_universe",
"stories",
"story_universes",
["universe_id"],
["id"],
ondelete="SET NULL",
)
def downgrade():
op.drop_constraint("fk_stories_universe", "stories", type_="foreignkey")
op.drop_constraint("fk_stories_child_profile", "stories", type_="foreignkey")
op.drop_column("stories", "universe_id")
op.drop_column("stories", "child_profile_id")
op.drop_index("idx_story_universes_updated_at", table_name="story_universes")
op.drop_index("idx_story_universes_child_id", table_name="story_universes")
op.drop_table("story_universes")


@@ -0,0 +1,78 @@
"""add reading events and memory items
Revision ID: 0006_add_reading_events_and_memory_items
Revises: 0005_add_story_universes_and_story_links
Create Date: 2025-12-22
"""
from alembic import op
import sqlalchemy as sa
revision = "0006_add_reading_events_and_memory_items"
down_revision = "0005_add_story_universes_and_story_links"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"reading_events",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column(
"child_profile_id",
sa.String(36),
sa.ForeignKey("child_profiles.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"story_id",
sa.Integer(),
sa.ForeignKey("stories.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("event_type", sa.String(20), nullable=False),
sa.Column("reading_time", sa.Integer(), server_default="0", nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_index("idx_reading_events_profile", "reading_events", ["child_profile_id"])
op.create_index("idx_reading_events_story", "reading_events", ["story_id"])
op.create_index("idx_reading_events_created", "reading_events", ["created_at"])
op.create_table(
"memory_items",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column(
"child_profile_id",
sa.String(36),
sa.ForeignKey("child_profiles.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"universe_id",
sa.String(36),
sa.ForeignKey("story_universes.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("type", sa.String(50), nullable=False),
sa.Column("value", sa.JSON(), nullable=False),
sa.Column("base_weight", sa.Float(), server_default="1.0", nullable=False),
sa.Column("last_used_at", sa.DateTime(timezone=True)),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("ttl_days", sa.Integer()),
)
op.create_index("idx_memory_items_profile", "memory_items", ["child_profile_id"])
op.create_index("idx_memory_items_universe", "memory_items", ["universe_id"])
op.create_index("idx_memory_items_last_used", "memory_items", ["last_used_at"])
def downgrade():
op.drop_index("idx_memory_items_last_used", table_name="memory_items")
op.drop_index("idx_memory_items_universe", table_name="memory_items")
op.drop_index("idx_memory_items_profile", table_name="memory_items")
op.drop_table("memory_items")
op.drop_index("idx_reading_events_created", table_name="reading_events")
op.drop_index("idx_reading_events_story", table_name="reading_events")
op.drop_index("idx_reading_events_profile", table_name="reading_events")
op.drop_table("reading_events")


@@ -0,0 +1,68 @@
"""Add push configs and events.
Revision ID: 0007_add_push_configs_and_events
Revises: 0006_add_reading_events_and_memory_items
Create Date: 2025-12-24 16:40:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0007_add_push_configs_and_events"
down_revision = "0006_add_reading_events_and_memory_items"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"push_configs",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("user_id", sa.String(length=255), nullable=False),
sa.Column("child_profile_id", sa.String(length=36), nullable=False),
sa.Column("push_time", sa.Time(), nullable=True),
sa.Column("push_days", sa.JSON(), nullable=False, server_default="[]"),
sa.Column("enabled", sa.Boolean(), nullable=False, server_default=sa.true()),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["child_profile_id"], ["child_profiles.id"], ondelete="CASCADE"),
sa.UniqueConstraint("child_profile_id", name="uq_push_config_child"),
)
op.create_index("ix_push_configs_user_id", "push_configs", ["user_id"])
op.create_index("ix_push_configs_child_profile_id", "push_configs", ["child_profile_id"])
op.create_table(
"push_events",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("user_id", sa.String(length=255), nullable=False),
sa.Column("child_profile_id", sa.String(length=36), nullable=False),
sa.Column("trigger_type", sa.String(length=20), nullable=False),
sa.Column("status", sa.String(length=20), nullable=False),
sa.Column("reason", sa.Text(), nullable=True),
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["child_profile_id"], ["child_profiles.id"], ondelete="CASCADE"),
)
op.create_index("ix_push_events_user_id", "push_events", ["user_id"])
op.create_index("ix_push_events_child_profile_id", "push_events", ["child_profile_id"])
op.create_index("ix_push_events_sent_at", "push_events", ["sent_at"])
def downgrade() -> None:
op.drop_index("ix_push_events_sent_at", table_name="push_events")
op.drop_index("ix_push_events_child_profile_id", table_name="push_events")
op.drop_index("ix_push_events_user_id", table_name="push_events")
op.drop_table("push_events")
op.drop_index("ix_push_configs_child_profile_id", table_name="push_configs")
op.drop_index("ix_push_configs_user_id", table_name="push_configs")
op.drop_table("push_configs")


@@ -0,0 +1,25 @@
"""add pages column to stories
Revision ID: 0008_add_pages_to_stories
Revises: 0007_add_push_configs_and_events
Create Date: 2026-01-20
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '0008_add_pages_to_stories'
down_revision = '0007_add_push_configs_and_events'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column('stories', sa.Column('pages', postgresql.JSON(), nullable=True))
def downgrade() -> None:
op.drop_column('stories', 'pages')

backend/app/__init__.py

backend/app/admin_main.py

@@ -0,0 +1,61 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api import admin_providers, admin_reload
from app.core.config import settings
from app.core.logging import get_logger, setup_logging
from app.db.database import init_db
setup_logging()
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Admin App lifespan manager."""
logger.info("admin_app_starting")
await init_db()
    # Admin-specific caches could be loaded or warmed up here
yield
logger.info("admin_app_shutdown")
app = FastAPI(
title=f"{settings.app_name} Admin Console",
description="Administrative Control Plane for DreamWeaver.",
version="0.1.0",
lifespan=lifespan,
)
# The admin console typically allows looser CORS, or dedicated management domains
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,  # or a dedicated ADMIN_CORS_ORIGINS
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Mount the routers according to the config switch
if settings.enable_admin_console:
app.include_router(admin_providers.router, prefix="/admin", tags=["admin-providers"])
app.include_router(admin_reload.router, prefix="/admin", tags=["admin-reload"])
else:
@app.get("/admin/{path:path}")
@app.post("/admin/{path:path}")
@app.put("/admin/{path:path}")
@app.delete("/admin/{path:path}")
async def admin_disabled(path: str):
from fastapi import HTTPException
raise HTTPException(
status_code=403,
detail="Admin console is disabled in environment configuration."
)
@app.get("/health")
async def health_check():
return {"status": "ok", "service": "admin-backend"}


@@ -0,0 +1,307 @@
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.admin_auth import admin_guard
from app.db.admin_models import Provider
from app.db.database import get_db
from app.services.cost_tracker import cost_tracker
from app.services.secret_service import SecretService
router = APIRouter(dependencies=[Depends(admin_guard)])
class ProviderCreate(BaseModel):
name: str
type: str = Field(..., pattern="^(text|image|tts|storybook)$")
adapter: str
model: str | None = None
api_base: str | None = None
    api_key: str | None = None  # optional; takes precedence over config_ref
timeout_ms: int = 60000
max_retries: int = 1
weight: int = 1
priority: int = 0
enabled: bool = True
config_json: dict | None = None
    config_ref: str | None = None  # env var key name (fallback)
updated_by: str | None = None
class ProviderUpdate(ProviderCreate):
enabled: bool | None = None
api_key: str | None = None
config_json: dict | None = None
class ProviderResponse(BaseModel):
"""Provider 响应模型,隐藏敏感字段。"""
id: str
name: str
type: str
adapter: str
model: str | None = None
api_base: str | None = None
    has_api_key: bool = False  # only indicates whether an api_key is set; plaintext is never returned
timeout_ms: int = 60000
max_retries: int = 1
weight: int = 1
priority: int = 0
enabled: bool = True
config_ref: str | None = None
model_config = ConfigDict(from_attributes=True)
from app.services.adapters.registry import AdapterRegistry
from app.services.provider_router import DEFAULT_PROVIDERS
@router.get("/providers/adapters")
async def list_available_adapters():
"""获取所有可用的适配器类型 (定义的类)。"""
return AdapterRegistry.list_adapters()
@router.get("/providers/defaults")
async def get_env_defaults():
"""获取当前环境变量定义的默认策略 (Read-Only)。"""
return DEFAULT_PROVIDERS
@router.get("/providers", response_model=list[ProviderResponse])
async def list_providers(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Provider))
providers = result.scalars().all()
    # Convert to response models, hiding the api_key plaintext
return [
ProviderResponse(
id=p.id,
name=p.name,
type=p.type,
adapter=p.adapter,
model=p.model,
api_base=p.api_base,
            has_api_key=bool(p.api_key),  # only flags whether a key exists
timeout_ms=p.timeout_ms,
max_retries=p.max_retries,
weight=p.weight,
priority=p.priority,
enabled=p.enabled,
config_ref=p.config_ref,
)
for p in providers
]
def _to_response(provider: Provider) -> ProviderResponse:
"""将 Provider 转换为响应模型,隐藏敏感字段。"""
return ProviderResponse(
id=provider.id,
name=provider.name,
type=provider.type,
adapter=provider.adapter,
model=provider.model,
api_base=provider.api_base,
has_api_key=bool(provider.api_key),
timeout_ms=provider.timeout_ms,
max_retries=provider.max_retries,
weight=provider.weight,
priority=provider.priority,
enabled=provider.enabled,
config_ref=provider.config_ref,
)
@router.post("/providers", response_model=ProviderResponse)
async def create_provider(payload: ProviderCreate, db: AsyncSession = Depends(get_db)):
data = payload.model_dump()
    # Encrypt the API key
if data.get("api_key"):
data["api_key"] = SecretService.encrypt(data["api_key"])
provider = Provider(**data)
db.add(provider)
await db.commit()
await db.refresh(provider)
return _to_response(provider)
@router.put("/providers/{provider_id}", response_model=ProviderResponse)
async def update_provider(
provider_id: str, payload: ProviderUpdate, db: AsyncSession = Depends(get_db)
):
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
data = payload.model_dump(exclude_unset=True)
    # Encrypt the API key
if "api_key" in data and data["api_key"]:
data["api_key"] = SecretService.encrypt(data["api_key"])
for k, v in data.items():
setattr(provider, k, v)
await db.commit()
await db.refresh(provider)
return _to_response(provider)
@router.delete("/providers/{provider_id}")
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
await db.delete(provider)
await db.commit()
return {"message": "deleted"}
# ==================== Secrets management API ====================
class SecretCreate(BaseModel):
    """Secret creation request."""
    name: str = Field(..., description="Secret name, e.g. CQTAI_API_KEY")
    value: str = Field(..., description="Plaintext secret value")
class SecretResponse(BaseModel):
"""密钥响应,不返回明文。"""
name: str
created_at: str | None = None
updated_at: str | None = None
@router.get("/secrets", response_model=list[str])
async def list_secrets(db: AsyncSession = Depends(get_db)):
"""列出所有密钥名称(不返回值)。"""
return await SecretService.list_secrets(db)
@router.post("/secrets", response_model=SecretResponse)
async def create_or_update_secret(payload: SecretCreate, db: AsyncSession = Depends(get_db)):
"""创建或更新密钥。"""
secret = await SecretService.set_secret(db, payload.name, payload.value)
return SecretResponse(
name=secret.name,
created_at=secret.created_at.isoformat() if secret.created_at else None,
updated_at=secret.updated_at.isoformat() if secret.updated_at else None,
)
@router.delete("/secrets/{name}")
async def delete_secret(name: str, db: AsyncSession = Depends(get_db)):
"""删除密钥。"""
deleted = await SecretService.delete_secret(db, name)
if not deleted:
raise HTTPException(status_code=404, detail="Secret not found")
return {"message": "deleted"}
@router.get("/secrets/{name}/verify")
async def verify_secret(name: str, db: AsyncSession = Depends(get_db)):
"""验证密钥是否存在且可解密(不返回明文)。"""
value = await SecretService.get_secret(db, name)
if value is None:
raise HTTPException(status_code=404, detail="Secret not found")
return {"name": name, "valid": True, "length": len(value)}
# ==================== Cost tracking API ====================
class BudgetUpdate(BaseModel):
"""预算更新请求。"""
daily_limit_usd: float | None = None
monthly_limit_usd: float | None = None
alert_threshold: float | None = Field(default=None, ge=0, le=1)
enabled: bool | None = None
@router.get("/costs/summary/{user_id}")
async def get_user_cost_summary(user_id: str, db: AsyncSession = Depends(get_db)):
"""获取用户成本摘要。"""
return await cost_tracker.get_cost_summary(db, user_id)
@router.get("/costs/all")
async def get_all_costs_summary(db: AsyncSession = Depends(get_db)):
"""获取所有用户成本汇总(管理员)。"""
from sqlalchemy import func
from app.db.admin_models import CostRecord
    # Aggregate by user
result = await db.execute(
select(
CostRecord.user_id,
func.sum(CostRecord.estimated_cost).label("total_cost"),
func.count().label("call_count"),
).group_by(CostRecord.user_id)
)
users = [
{"user_id": row[0], "total_cost_usd": float(row[1]), "call_count": row[2]}
for row in result.all()
]
    # Aggregate by capability
result = await db.execute(
select(
CostRecord.capability,
func.sum(CostRecord.estimated_cost).label("total_cost"),
func.count().label("call_count"),
).group_by(CostRecord.capability)
)
capabilities = [
{"capability": row[0], "total_cost_usd": float(row[1]), "call_count": row[2]}
for row in result.all()
]
return {"by_user": users, "by_capability": capabilities}
@router.get("/budgets/{user_id}")
async def get_user_budget(user_id: str, db: AsyncSession = Depends(get_db)):
"""获取用户预算配置。"""
budget = await cost_tracker.get_user_budget(db, user_id)
if not budget:
return {"user_id": user_id, "budget": None}
return {
"user_id": user_id,
"budget": {
"daily_limit_usd": float(budget.daily_limit_usd),
"monthly_limit_usd": float(budget.monthly_limit_usd),
"alert_threshold": float(budget.alert_threshold),
"enabled": budget.enabled,
},
}
@router.post("/budgets/{user_id}")
async def set_user_budget(
user_id: str, payload: BudgetUpdate, db: AsyncSession = Depends(get_db)
):
"""设置用户预算。"""
budget = await cost_tracker.set_user_budget(
db,
user_id,
daily_limit=payload.daily_limit_usd,
monthly_limit=payload.monthly_limit_usd,
alert_threshold=payload.alert_threshold,
enabled=payload.enabled,
)
return {
"user_id": user_id,
"budget": {
"daily_limit_usd": float(budget.daily_limit_usd),
"monthly_limit_usd": float(budget.monthly_limit_usd),
"alert_threshold": float(budget.alert_threshold),
"enabled": budget.enabled,
},
}
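Every route in this router is guarded by `admin_guard`, and the .env template documents Basic Auth credentials for the admin console. Assuming `admin_guard` implements standard HTTP Basic (its actual mechanism is not shown in this diff), a stdlib sketch of the header a client would attach:

```python
import base64


def basic_auth_header(username: str, password: str) -> dict[str, str]:
    """Build the Authorization header for HTTP Basic auth."""
    token = base64.b64encode(f"{username}:{password}".encode()).decode()
    return {"Authorization": f"Basic {token}"}
```

A client could then call, e.g., `httpx.get(f"{base_url}/admin/providers", headers=basic_auth_header("admin", "admin"))` against a dev deployment using the template's default credentials.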


@@ -0,0 +1,14 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.admin_auth import admin_guard
from app.db.database import get_db
from app.services.provider_cache import reload_providers
router = APIRouter(dependencies=[Depends(admin_guard)])
@router.post("/providers/reload")
async def reload(db: AsyncSession = Depends(get_db)):
cache = await reload_providers(db)
return {k: len(v) for k, v in cache.items()}

backend/app/api/auth.py

@@ -0,0 +1,272 @@
import secrets
from urllib.parse import urlencode
import httpx
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.deps import get_current_user
from app.core.security import create_access_token
from app.db.database import get_db
from app.db.models import User
router = APIRouter()
# OAuth endpoints
GITHUB_AUTHORIZE_URL = "https://github.com/login/oauth/authorize"
GITHUB_TOKEN_URL = "https://github.com/login/oauth/access_token"
GITHUB_USER_URL = "https://api.github.com/user"
GOOGLE_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/v2/auth"
GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
GOOGLE_USER_URL = "https://www.googleapis.com/oauth2/v2/userinfo"
STATE_COOKIE = "oauth_state"
STATE_MAX_AGE = 600 # 10 minutes
def _set_state_cookie(response: RedirectResponse, provider: str, state: str) -> None:
response.set_cookie(
key=STATE_COOKIE,
value=f"{provider}:{state}",
httponly=True,
secure=not settings.debug,
samesite="lax",
max_age=STATE_MAX_AGE,
)
def _validate_state(state_from_query: str | None, state_cookie: str | None, provider: str):
if not state_from_query or not state_cookie:
raise HTTPException(status_code=400, detail="Missing OAuth state")
expected_prefix = f"{provider}:"
if not state_cookie.startswith(expected_prefix):
raise HTTPException(status_code=400, detail="OAuth state mismatch")
expected_state = state_cookie.removeprefix(expected_prefix)
if not secrets.compare_digest(state_from_query, expected_state):
raise HTTPException(status_code=400, detail="OAuth state mismatch")
@router.get("/github/signin")
async def github_signin():
"""Start GitHub OAuth with state protection."""
state = secrets.token_urlsafe(16)
params = {
"client_id": settings.github_client_id,
"redirect_uri": f"{settings.base_url}/auth/github/callback",
"scope": "read:user user:email",
"state": state,
}
url = f"{GITHUB_AUTHORIZE_URL}?{urlencode(params)}"
response = RedirectResponse(url=url)
_set_state_cookie(response, "github", state)
return response
@router.get("/github/callback")
async def github_callback(
code: str,
state: str | None = Query(default=None),
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
db: AsyncSession = Depends(get_db),
):
"""Handle GitHub OAuth callback."""
_validate_state(state, state_cookie, "github")
try:
async with httpx.AsyncClient() as client:
token_resp = await client.post(
GITHUB_TOKEN_URL,
data={
"client_id": settings.github_client_id,
"client_secret": settings.github_client_secret,
"code": code,
},
headers={"Accept": "application/json"},
)
token_resp.raise_for_status()
token_data = token_resp.json()
access_token = token_data.get("access_token")
if not access_token:
raise HTTPException(status_code=502, detail="GitHub login failed")
user_resp = await client.get(
GITHUB_USER_URL,
headers={"Authorization": f"Bearer {access_token}"},
)
user_resp.raise_for_status()
user_data = user_resp.json()
    except httpx.HTTPError:  # covers both HTTP status and transport errors
raise HTTPException(status_code=502, detail="GitHub login failed")
github_id = user_data.get("id")
if github_id is None:
raise HTTPException(status_code=502, detail="GitHub login failed")
return await _handle_oauth_user(
db=db,
provider="github",
user_id=str(github_id),
name=user_data.get("name") or user_data.get("login") or "GitHub User",
avatar_url=user_data.get("avatar_url"),
)
@router.get("/google/signin")
async def google_signin():
"""Start Google OAuth with state protection."""
state = secrets.token_urlsafe(16)
params = {
"client_id": settings.google_client_id,
"redirect_uri": f"{settings.base_url}/auth/google/callback",
"response_type": "code",
"scope": "openid email profile",
"state": state,
}
url = f"{GOOGLE_AUTHORIZE_URL}?{urlencode(params)}"
response = RedirectResponse(url=url)
_set_state_cookie(response, "google", state)
return response
@router.get("/google/callback")
async def google_callback(
code: str,
state: str | None = Query(default=None),
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
db: AsyncSession = Depends(get_db),
):
"""Handle Google OAuth callback."""
_validate_state(state, state_cookie, "google")
try:
async with httpx.AsyncClient() as client:
token_resp = await client.post(
GOOGLE_TOKEN_URL,
data={
"client_id": settings.google_client_id,
"client_secret": settings.google_client_secret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": f"{settings.base_url}/auth/google/callback",
},
)
token_resp.raise_for_status()
token_data = token_resp.json()
access_token = token_data.get("access_token")
if not access_token:
raise HTTPException(status_code=502, detail="Google login failed")
user_resp = await client.get(
GOOGLE_USER_URL,
headers={"Authorization": f"Bearer {access_token}"},
)
user_resp.raise_for_status()
user_data = user_resp.json()
    except httpx.HTTPError:  # covers both HTTP status and transport errors
raise HTTPException(status_code=502, detail="Google login failed")
google_id = user_data.get("id")
if google_id is None:
raise HTTPException(status_code=502, detail="Google login failed")
return await _handle_oauth_user(
db=db,
provider="google",
user_id=str(google_id),
name=user_data.get("name") or user_data.get("email") or "Google User",
avatar_url=user_data.get("picture"),
)
async def _handle_oauth_user(
db: AsyncSession,
provider: str,
user_id: str,
name: str,
avatar_url: str | None,
) -> RedirectResponse:
"""Create/update user and issue session cookie."""
full_id = f"{provider}:{user_id}"
result = await db.execute(select(User).where(User.id == full_id))
user = result.scalar_one_or_none()
if not user:
user = User(
id=full_id,
name=name,
avatar_url=avatar_url,
provider=provider,
)
db.add(user)
else:
user.name = name
user.avatar_url = avatar_url
await db.commit()
token = create_access_token({"sub": user.id})
frontend_url = "http://localhost:5173"
if settings.cors_origins and len(settings.cors_origins) > 0:
frontend_url = settings.cors_origins[0]
response = RedirectResponse(url=f"{frontend_url}/my-stories", status_code=302)
response.set_cookie(
key="access_token",
value=token,
httponly=True,
secure=not settings.debug,
samesite="lax",
max_age=60 * 60 * 24 * 7, # align with ACCESS_TOKEN_EXPIRE_DAYS
)
response.delete_cookie(STATE_COOKIE)
return response
@router.post("/signout")
async def signout():
"""Sign out and clear cookies."""
    frontend_url = settings.cors_origins[0] if settings.cors_origins else "http://localhost:5173"
    response = RedirectResponse(url=frontend_url, status_code=302)
response.delete_cookie("access_token", samesite="lax", secure=not settings.debug)
response.delete_cookie(STATE_COOKIE, samesite="lax", secure=not settings.debug)
return response
@router.get("/session")
async def get_session(user: User | None = Depends(get_current_user)):
"""Fetch current session info."""
if not user:
return {"user": None}
return {
"user": {
"id": user.id,
"name": user.name,
"avatar_url": user.avatar_url,
"provider": user.provider,
}
}
@router.get("/dev/signin")
async def dev_signin(db: AsyncSession = Depends(get_db)):
"""Developer backdoor login. Only works in DEBUG mode."""
# if not settings.debug:
# raise HTTPException(status_code=403, detail="Developer login disabled")
try:
return await _handle_oauth_user(
db=db,
provider="github",
user_id="dev_user_001",
name="Developer",
avatar_url="https://api.dicebear.com/7.x/avataaars/svg?seed=Developer"
)
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Dev login failed: {str(e)}")
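Both providers share the same state-cookie scheme: the random `state` is bound to a provider name in a single `provider:state` cookie, then checked with a constant-time comparison on callback. A standalone sketch of the issue/validate round trip (function names here are illustrative, not part of this module):

```python
import secrets


def issue_state(provider: str) -> tuple[str, str]:
    """Return (state for the authorize URL, cookie value binding it to a provider)."""
    state = secrets.token_urlsafe(16)
    return state, f"{provider}:{state}"


def state_is_valid(state_from_query: str, cookie_value: str, provider: str) -> bool:
    """Mirror _validate_state: provider prefix must match, then compare the
    remaining state with secrets.compare_digest (constant-time)."""
    prefix = f"{provider}:"
    if not cookie_value.startswith(prefix):
        return False
    return secrets.compare_digest(state_from_query, cookie_value.removeprefix(prefix))


state, cookie = issue_state("github")
ok = state_is_valid(state, cookie, "github")
cross_provider = state_is_valid(state, cookie, "google")  # provider mismatch
tampered = state_is_valid("not-the-state", cookie, "github")
```

The provider prefix prevents a state issued for one provider's flow from being replayed against the other's callback; `compare_digest` avoids leaking the correct state through timing differences.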

backend/app/api/memories.py

@@ -0,0 +1,268 @@
"""Memory management APIs."""
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, User
from app.services import memory_service
from app.services.memory_service import MemoryType
router = APIRouter()
class MemoryItemResponse(BaseModel):
"""Memory item response."""
id: str
type: str
value: dict
base_weight: float
ttl_days: int | None
created_at: str
last_used_at: str | None
class Config:
from_attributes = True
class MemoryListResponse(BaseModel):
"""Memory list response."""
memories: list[MemoryItemResponse]
total: int
class CreateMemoryRequest(BaseModel):
"""Create memory request."""
    type: str = Field(..., description="Memory type")
    value: dict = Field(..., description="Memory content")
    universe_id: str | None = Field(default=None, description="Associated story universe ID")
    weight: float | None = Field(default=None, description="Weight")
    ttl_days: int | None = Field(default=None, description="Days until expiry")
class CreateCharacterMemoryRequest(BaseModel):
"""Create character memory request."""
    name: str = Field(..., description="Character name")
    description: str | None = Field(default=None, description="Character description")
    source_story_id: int | None = Field(default=None, description="Source story ID")
    affinity_score: float = Field(default=1.0, ge=0.0, le=1.0, description="Affinity level")
    universe_id: str | None = Field(default=None, description="Associated story universe ID")
class CreateScaryElementRequest(BaseModel):
"""Create scary element memory request."""
    keyword: str = Field(..., description="Keyword to avoid")
    category: str = Field(default="other", description="Category")
    source_story_id: int | None = Field(default=None, description="Source story ID")
async def _verify_profile_ownership(
profile_id: str, user: User, db: AsyncSession
) -> ChildProfile:
"""验证档案所有权。"""
from sqlalchemy import select
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
return profile
@router.get("/profiles/{profile_id}/memories", response_model=MemoryListResponse)
async def list_memories(
profile_id: str,
memory_type: str | None = None,
universe_id: str | None = None,
limit: int = 50,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""获取档案的记忆列表。"""
await _verify_profile_ownership(profile_id, user, db)
memories = await memory_service.get_profile_memories(
db=db,
profile_id=profile_id,
memory_type=memory_type,
universe_id=universe_id,
limit=limit,
)
return MemoryListResponse(
memories=[
MemoryItemResponse(
id=m.id,
type=m.type,
value=m.value,
base_weight=m.base_weight,
ttl_days=m.ttl_days,
created_at=m.created_at.isoformat() if m.created_at else "",
last_used_at=m.last_used_at.isoformat() if m.last_used_at else None,
)
for m in memories
],
total=len(memories),
)
@router.post("/profiles/{profile_id}/memories", response_model=MemoryItemResponse)
async def create_memory(
profile_id: str,
payload: CreateMemoryRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""创建新的记忆项。"""
await _verify_profile_ownership(profile_id, user, db)
    # Validate the memory type
valid_types = [
MemoryType.RECENT_STORY,
MemoryType.FAVORITE_CHARACTER,
MemoryType.SCARY_ELEMENT,
MemoryType.VOCABULARY_GROWTH,
MemoryType.EMOTIONAL_HIGHLIGHT,
MemoryType.READING_PREFERENCE,
MemoryType.MILESTONE,
MemoryType.SKILL_MASTERED,
]
if payload.type not in valid_types:
raise HTTPException(status_code=400, detail=f"无效的记忆类型: {payload.type}")
memory = await memory_service.create_memory(
db=db,
profile_id=profile_id,
memory_type=payload.type,
value=payload.value,
universe_id=payload.universe_id,
weight=payload.weight,
ttl_days=payload.ttl_days,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.post("/profiles/{profile_id}/memories/character", response_model=MemoryItemResponse)
async def create_character_memory(
profile_id: str,
payload: CreateCharacterMemoryRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""添加喜欢的角色。"""
await _verify_profile_ownership(profile_id, user, db)
memory = await memory_service.create_character_memory(
db=db,
profile_id=profile_id,
name=payload.name,
description=payload.description,
source_story_id=payload.source_story_id,
affinity_score=payload.affinity_score,
universe_id=payload.universe_id,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.post("/profiles/{profile_id}/memories/scary", response_model=MemoryItemResponse)
async def create_scary_element_memory(
profile_id: str,
payload: CreateScaryElementRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""添加回避元素。"""
await _verify_profile_ownership(profile_id, user, db)
memory = await memory_service.create_scary_element_memory(
db=db,
profile_id=profile_id,
keyword=payload.keyword,
category=payload.category,
source_story_id=payload.source_story_id,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.delete("/profiles/{profile_id}/memories/{memory_id}")
async def delete_memory(
profile_id: str,
memory_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""删除记忆项。"""
from sqlalchemy import select
from app.db.models import MemoryItem
await _verify_profile_ownership(profile_id, user, db)
result = await db.execute(
select(MemoryItem).where(
MemoryItem.id == memory_id,
MemoryItem.child_profile_id == profile_id,
)
)
memory = result.scalar_one_or_none()
if not memory:
raise HTTPException(status_code=404, detail="记忆不存在")
await db.delete(memory)
await db.commit()
return {"message": "Deleted"}
@router.get("/memory-types")
async def list_memory_types():
"""获取所有可用的记忆类型及其配置。"""
types = []
for type_name, config in MemoryType.CONFIG.items():
types.append({
"type": type_name,
"default_weight": config[0],
"default_ttl_days": config[1],
"description": config[2],
})
return {"types": types}

backend/app/api/profiles.py

@@ -0,0 +1,280 @@
"""Child profile APIs."""
from datetime import date
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, Story, StoryUniverse, User
router = APIRouter()
MAX_PROFILES_PER_USER = 5
class ChildProfileCreate(BaseModel):
"""Create profile payload."""
name: str = Field(..., min_length=1, max_length=50)
birth_date: date | None = None
gender: str | None = Field(default=None, pattern="^(male|female|other)$")
interests: list[str] = Field(default_factory=list)
growth_themes: list[str] = Field(default_factory=list)
avatar_url: str | None = None
class ChildProfileUpdate(BaseModel):
"""Update profile payload."""
name: str | None = Field(default=None, min_length=1, max_length=50)
birth_date: date | None = None
gender: str | None = Field(default=None, pattern="^(male|female|other)$")
interests: list[str] | None = None
growth_themes: list[str] | None = None
avatar_url: str | None = None
class ChildProfileResponse(BaseModel):
"""Profile response."""
id: str
name: str
avatar_url: str | None
birth_date: date | None
gender: str | None
age: int | None
interests: list[str]
growth_themes: list[str]
stories_count: int
total_reading_time: int
class Config:
from_attributes = True
class ChildProfileListResponse(BaseModel):
"""Profile list response."""
profiles: list[ChildProfileResponse]
total: int
class TimelineEvent(BaseModel):
"""Timeline event item."""
date: str
type: Literal["story", "achievement", "milestone"]
title: str
description: str | None = None
image_url: str | None = None
metadata: dict | None = None
class TimelineResponse(BaseModel):
"""Timeline response."""
events: list[TimelineEvent]
@router.get("/profiles", response_model=ChildProfileListResponse)
async def list_profiles(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List child profiles for current user."""
result = await db.execute(
select(ChildProfile)
.where(ChildProfile.user_id == user.id)
.order_by(ChildProfile.created_at.desc())
)
profiles = result.scalars().all()
return ChildProfileListResponse(profiles=profiles, total=len(profiles))
@router.post("/profiles", response_model=ChildProfileResponse, status_code=status.HTTP_201_CREATED)
async def create_profile(
payload: ChildProfileCreate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create a new child profile."""
count = await db.scalar(
select(func.count(ChildProfile.id)).where(ChildProfile.user_id == user.id)
)
if count and count >= MAX_PROFILES_PER_USER:
raise HTTPException(status_code=400, detail="最多只能创建 5 个孩子档案")
existing = await db.scalar(
select(ChildProfile.id).where(
ChildProfile.user_id == user.id,
ChildProfile.name == payload.name,
)
)
if existing:
raise HTTPException(status_code=409, detail="该档案名称已存在")
profile = ChildProfile(user_id=user.id, **payload.model_dump())
db.add(profile)
await db.commit()
await db.refresh(profile)
return profile
@router.get("/profiles/{profile_id}", response_model=ChildProfileResponse)
async def get_profile(
profile_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get one child profile."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
return profile
@router.put("/profiles/{profile_id}", response_model=ChildProfileResponse)
async def update_profile(
profile_id: str,
payload: ChildProfileUpdate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Update a child profile."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
updates = payload.model_dump(exclude_unset=True)
if "name" in updates:
existing = await db.scalar(
select(ChildProfile.id).where(
ChildProfile.user_id == user.id,
ChildProfile.name == updates["name"],
ChildProfile.id != profile_id,
)
)
if existing:
raise HTTPException(status_code=409, detail="该档案名称已存在")
for key, value in updates.items():
setattr(profile, key, value)
await db.commit()
await db.refresh(profile)
return profile
@router.delete("/profiles/{profile_id}")
async def delete_profile(
profile_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Delete a child profile."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
await db.delete(profile)
await db.commit()
return {"message": "Deleted"}
@router.get("/profiles/{profile_id}/timeline", response_model=TimelineResponse)
async def get_profile_timeline(
profile_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get profile growth timeline."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="档案不存在")
events: list[TimelineEvent] = []
# 1. Milestone: Profile Created
events.append(TimelineEvent(
date=profile.created_at.isoformat(),
type="milestone",
title="初次相遇",
description=f"创建了档案 {profile.name}"
))
# 2. Stories
stories_result = await db.execute(
select(Story).where(Story.child_profile_id == profile_id)
)
for s in stories_result.scalars():
events.append(TimelineEvent(
date=s.created_at.isoformat(),
type="story",
title=s.title,
image_url=s.image_url,
metadata={"story_id": s.id, "mode": s.mode}
))
# 3. Achievements (from Universe)
universes_result = await db.execute(
select(StoryUniverse).where(StoryUniverse.child_profile_id == profile_id)
)
for u in universes_result.scalars():
if u.achievements:
for ach in u.achievements:
if isinstance(ach, dict):
obt_at = ach.get("obtained_at")
# Fallback
if not obt_at:
obt_at = u.updated_at.isoformat()
events.append(TimelineEvent(
date=obt_at,
type="achievement",
title=f"获得成就:{ach.get('type')}",
description=ach.get('description'),
metadata={"universe_id": u.id, "source_story_id": ach.get("source_story_id")}
))
# Sort by date desc
events.sort(key=lambda x: x.date, reverse=True)
return TimelineResponse(events=events)
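The timeline sorts heterogeneous events by their ISO-8601 date strings. That works because ISO-8601 timestamps with a consistent format and offset sort lexicographically in chronological order; a minimal sketch:

```python
from datetime import datetime, timezone

# ISO-8601 strings with a uniform UTC offset compare lexicographically in
# chronological order, so mixed event types can be sorted by string alone.
events = [
    {"type": "story", "date": datetime(2026, 1, 5, tzinfo=timezone.utc).isoformat()},
    {"type": "milestone", "date": datetime(2026, 1, 1, tzinfo=timezone.utc).isoformat()},
    {"type": "achievement", "date": datetime(2026, 1, 10, tzinfo=timezone.utc).isoformat()},
]
events.sort(key=lambda e: e["date"], reverse=True)
order = [e["type"] for e in events]  # newest first
```

The caveat: mixing naive and offset-aware `isoformat()` output (some strings with `+00:00`, some without) breaks strict lexicographic ordering, so all timestamps feeding the timeline should be produced the same way.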


@@ -0,0 +1,120 @@
"""Push configuration APIs."""
from datetime import time
from fastapi import APIRouter, Depends, HTTPException, Response, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, PushConfig, User
router = APIRouter()
class PushConfigUpsert(BaseModel):
"""Upsert push config payload."""
child_profile_id: str
push_time: time | None = None
push_days: list[int] | None = None
enabled: bool | None = None
class PushConfigResponse(BaseModel):
"""Push config response."""
id: str
child_profile_id: str
push_time: time | None
push_days: list[int]
enabled: bool
class Config:
from_attributes = True
class PushConfigListResponse(BaseModel):
"""Push config list response."""
configs: list[PushConfigResponse]
total: int
def _validate_push_days(push_days: list[int]) -> list[int]:
invalid = [day for day in push_days if day < 0 or day > 6]
if invalid:
raise HTTPException(status_code=400, detail="推送日期必须在 0-6 之间")
return list(dict.fromkeys(push_days))
@router.get("/push-configs", response_model=PushConfigListResponse)
async def list_push_configs(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List push configs for current user."""
result = await db.execute(
select(PushConfig).where(PushConfig.user_id == user.id)
)
configs = result.scalars().all()
return PushConfigListResponse(configs=configs, total=len(configs))
@router.put("/push-configs", response_model=PushConfigResponse)
async def upsert_push_config(
payload: PushConfigUpsert,
response: Response,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create or update push config for a child profile."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == payload.child_profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
result = await db.execute(
select(PushConfig).where(PushConfig.child_profile_id == payload.child_profile_id)
)
config = result.scalar_one_or_none()
if config is None:
if payload.push_time is None or payload.push_days is None:
raise HTTPException(status_code=400, detail="创建配置需要提供推送时间和日期")
push_days = _validate_push_days(payload.push_days)
config = PushConfig(
user_id=user.id,
child_profile_id=payload.child_profile_id,
push_time=payload.push_time,
push_days=push_days,
enabled=True if payload.enabled is None else payload.enabled,
)
db.add(config)
await db.commit()
await db.refresh(config)
response.status_code = status.HTTP_201_CREATED
return config
updates = payload.model_dump(exclude_unset=True)
if "push_days" in updates and updates["push_days"] is not None:
updates["push_days"] = _validate_push_days(updates["push_days"])
if "push_time" in updates and updates["push_time"] is None:
raise HTTPException(status_code=400, detail="推送时间不能为空")
for key, value in updates.items():
if key == "child_profile_id":
continue
if value is not None:
setattr(config, key, value)
await db.commit()
await db.refresh(config)
return config
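`_validate_push_days` deduplicates with `dict.fromkeys`, which, unlike `set()`, preserves first-seen order. A standalone sketch of the same idea (raising `ValueError` instead of `HTTPException` so it runs outside FastAPI):

```python
def normalize_push_days(push_days: list[int]) -> list[int]:
    """Reject out-of-range weekdays, then dedupe while preserving
    first-seen order (dict keys keep insertion order in Python 3.7+)."""
    if any(day < 0 or day > 6 for day in push_days):
        raise ValueError("push days must be within 0-6")
    return list(dict.fromkeys(push_days))


days = normalize_push_days([5, 1, 5, 0, 1])
```

Preserving order matters if the client treats the list as a display order; `set()` would return the days in an arbitrary-looking sorted-by-hash order instead.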


@@ -0,0 +1,120 @@
"""Reading event APIs."""
from datetime import datetime, timezone
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, MemoryItem, ReadingEvent, Story, User
router = APIRouter()
EVENT_WEIGHTS: dict[str, float] = {
"completed": 1.0,
"replayed": 1.5,
"started": 0.1,
"skipped": -0.5,
}
class ReadingEventCreate(BaseModel):
"""Reading event payload."""
child_profile_id: str
story_id: int | None = None
event_type: Literal["started", "completed", "skipped", "replayed"]
reading_time: int = Field(default=0, ge=0)
class ReadingEventResponse(BaseModel):
"""Reading event response."""
id: int
child_profile_id: str
story_id: int | None
event_type: str
reading_time: int
created_at: datetime
class Config:
from_attributes = True
@router.post("/reading-events", response_model=ReadingEventResponse, status_code=status.HTTP_201_CREATED)
async def create_reading_event(
payload: ReadingEventCreate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create a reading event and update profile stats/memory."""
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == payload.child_profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
story = None
if payload.story_id is not None:
result = await db.execute(
select(Story).where(
Story.id == payload.story_id,
Story.user_id == user.id,
)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="故事不存在")
if payload.reading_time:
profile.total_reading_time = (profile.total_reading_time or 0) + payload.reading_time
if payload.event_type in {"completed", "replayed"} and payload.story_id is not None:
existing = await db.scalar(
select(ReadingEvent.id).where(
ReadingEvent.child_profile_id == payload.child_profile_id,
ReadingEvent.story_id == payload.story_id,
ReadingEvent.event_type.in_(["completed", "replayed"]),
)
)
if existing is None:
profile.stories_count = (profile.stories_count or 0) + 1
event = ReadingEvent(
child_profile_id=payload.child_profile_id,
story_id=payload.story_id,
event_type=payload.event_type,
reading_time=payload.reading_time,
)
db.add(event)
weight = EVENT_WEIGHTS.get(payload.event_type, 0.0)
if story and weight > 0:
db.add(
MemoryItem(
child_profile_id=payload.child_profile_id,
universe_id=story.universe_id,
type="recent_story",
value={
"story_id": story.id,
"title": story.title,
"event_type": payload.event_type,
},
base_weight=weight,
last_used_at=datetime.now(timezone.utc),
ttl_days=90,
)
)
await db.commit()
await db.refresh(event)
return event

backend/app/api/stories.py

@@ -0,0 +1,605 @@
"""Story related APIs."""
import asyncio
import json
import time
import uuid
from typing import AsyncGenerator, Literal
from cachetools import TTLCache
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sse_starlette.sse import EventSourceResponse
from app.core.deps import require_user
from app.core.logging import get_logger
from app.db.database import get_db
from app.db.models import ChildProfile, Story, StoryUniverse, User
from app.services.provider_router import (
generate_image,
generate_story_content,
generate_storybook,
text_to_speech,
)
from app.tasks.achievements import extract_story_achievements
logger = get_logger(__name__)
router = APIRouter()
MAX_DATA_LENGTH = 2000
MAX_EDU_THEME_LENGTH = 200
MAX_TTS_LENGTH = 4000
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10
RATE_LIMIT_CACHE_SIZE = 10000  # max number of tracked users
_request_log: TTLCache[str, list[float]] = TTLCache(
maxsize=RATE_LIMIT_CACHE_SIZE, ttl=RATE_LIMIT_WINDOW * 2
)
def _check_rate_limit(user_id: str):
now = time.time()
timestamps = _request_log.get(user_id, [])
timestamps = [t for t in timestamps if now - t <= RATE_LIMIT_WINDOW]
if len(timestamps) >= RATE_LIMIT_REQUESTS:
raise HTTPException(status_code=429, detail="Too many requests, please slow down.")
timestamps.append(now)
_request_log[user_id] = timestamps
class GenerateRequest(BaseModel):
"""Story generation request."""
type: Literal["keywords", "full_story"]
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
child_profile_id: str | None = None
universe_id: str | None = None
class StoryResponse(BaseModel):
"""Story response."""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
mode: str
child_profile_id: str | None = None
universe_id: str | None = None
class StoryListItem(BaseModel):
"""Story list item."""
id: int
title: str
image_url: str | None
created_at: str
mode: str
class FullStoryResponse(BaseModel):
"""完整故事响应(含图片和音频状态)。"""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
audio_ready: bool
mode: str
errors: dict[str, str | None] = Field(default_factory=dict)
child_profile_id: str | None = None
universe_id: str | None = None
from app.services.memory_service import build_enhanced_memory_context
async def _validate_profile_and_universe(
request: GenerateRequest,
user: User,
db: AsyncSession,
) -> tuple[str | None, str | None]:
if not request.child_profile_id and not request.universe_id:
return None, None
profile_id = request.child_profile_id
universe_id = request.universe_id
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
return profile_id, universe_id
@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
request: GenerateRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate or enhance a story."""
_check_rate_limit(user.id)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return StoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=story.image_url,
mode=story.mode,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
@router.post("/stories/generate/full", response_model=FullStoryResponse)
async def generate_story_full(
request: GenerateRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
    """Generate a full story (story plus cover image; audio is rendered on demand).

    Partial-success strategy: the story itself must succeed; an image failure
    does not fail the request.
    """
_check_rate_limit(user.id)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
    # Step 1: generate the story (must succeed)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except Exception as exc:
logger.error("story_generation_failed", error=str(exc))
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
    # Persist the story
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
    # Step 2: generate the cover image (audio is rendered on demand to avoid waste)
errors: dict[str, str | None] = {}
image_url: str | None = None
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt)
story.image_url = image_url
await db.commit()
except Exception as exc:
errors["image"] = str(exc)
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
    # Note: audio is not pre-generated here; clients fetch it on demand via
    # /api/audio/{id}, which avoids paying for audio that is never played.
return FullStoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=image_url,
        audio_ready=False,  # audio must be requested explicitly
mode=story.mode,
errors=errors,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
@router.post("/stories/generate/stream")
async def generate_story_stream(
request: GenerateRequest,
req: Request,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
    """Stream story generation over SSE.

    Event sequence:
    - started: returns story_id
    - story_ready: returns title, content
    - story_failed: returns error
    - image_ready: returns image_url
    - image_failed: returns error
    - complete: closes the stream
    """
_check_rate_limit(user.id)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
async def event_generator() -> AsyncGenerator[dict, None]:
        story_id = str(uuid.uuid4())  # provisional correlation id; the persisted story's real id arrives in story_ready
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
        # Step 1: generate the story
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except Exception as e:
logger.error("sse_story_generation_failed", error=str(e))
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
return
        # Persist the story
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
yield {
"event": "story_ready",
"data": json.dumps({
"id": story.id,
"title": story.title,
"content": story.story_text,
"cover_prompt": story.cover_prompt,
"mode": story.mode,
"child_profile_id": story.child_profile_id,
"universe_id": story.universe_id,
}),
}
        # Step 2: generate the cover image (audio stays on demand)
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt)
story.image_url = image_url
await db.commit()
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
except Exception as e:
logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e))
yield {"event": "image_failed", "data": json.dumps({"error": str(e)})}
yield {"event": "complete", "data": json.dumps({"story_id": story.id})}
return EventSourceResponse(event_generator())
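The streaming endpoint yields plain dicts and lets `EventSourceResponse` turn them into Server-Sent Events. As a minimal stdlib sketch of what that serialization looks like on the wire (field layout per the SSE spec; sse-starlette also handles `id`/`retry` fields and keep-alives, not shown here):

```python
import json

def to_sse(event: str, data: dict) -> str:
    """Serialize one event into the SSE wire format: a named event line,
    a JSON data line, and a blank-line terminator."""
    return f"event: {event}\ndata: {json.dumps(data)}\n\n"

frame = to_sse("story_ready", {"id": 1, "title": "The Lost Star"})
```

A browser `EventSource` (or any SSE client) dispatches this frame to the `story_ready` listener with the JSON string as `event.data`.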
# ==================== Storybook API ====================
class StorybookRequest(BaseModel):
    """Storybook generation request."""
keywords: str = Field(..., min_length=1, max_length=200)
page_count: int = Field(default=6, ge=4, le=12)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
    generate_images: bool = Field(default=False, description="Whether to also render illustrations")
child_profile_id: str | None = None
universe_id: str | None = None
class StorybookPageResponse(BaseModel):
    """Single storybook page response."""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
class StorybookResponse(BaseModel):
    """Storybook response."""
id: int | None = None
title: str
main_character: str
art_style: str
pages: list[StorybookPageResponse]
cover_prompt: str
cover_url: str | None = None
@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
request: StorybookRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
    """Generate a paginated storybook and persist it.

    Returns the storybook structure with per-page text and image prompts.
    """
_check_rate_limit(user.id)
    # Validate profile/universe ownership. This duplicates
    # _validate_profile_and_universe because that helper takes a
    # GenerateRequest; worth refactoring to share the logic.
profile_id = request.child_profile_id
universe_id = request.universe_id
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
if not result.scalar_one_or_none():
            raise HTTPException(status_code=404, detail="Child profile not found")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
            raise HTTPException(status_code=404, detail="Story universe not found")
if profile_id and universe.child_profile_id != profile_id:
            raise HTTPException(status_code=400, detail="Story universe does not match the child profile")
if not profile_id:
profile_id = universe.child_profile_id
logger.info(
"storybook_request",
user_id=user.id,
keywords=request.keywords,
page_count=request.page_count,
profile_id=profile_id,
universe_id=universe_id,
)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
try:
# 注意generate_storybook 目前可能不支持记忆上下文注入
# 我们需要看看 generate_storybook 的签名
# 如果不支持,记忆功能在绘本模式下暂不可用,但基本参数传递是支持的
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
    except Exception as e:
        logger.error("storybook_generation_failed", error=str(e))
        raise HTTPException(status_code=502, detail="Storybook generation failed, please try again.")
    # ==========================================================================
    # Parallel full rendering: cover + all pages generated concurrently
    # ==========================================================================
final_cover_url = storybook.cover_url
if request.generate_images:
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
        # 1. Build all image-generation tasks (cover + every page)
tasks = []
        # Cover task
async def _gen_cover():
if storybook.cover_prompt and not storybook.cover_url:
try:
return await generate_image(storybook.cover_prompt, db=db)
except Exception as e:
logger.warning("cover_gen_failed", error=str(e))
return storybook.cover_url
tasks.append(_gen_cover())
        # Page tasks
async def _gen_page(page):
if page.image_prompt and not page.image_url:
try:
url = await generate_image(page.image_prompt, db=db)
page.image_url = url
except Exception as e:
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
for page in storybook.pages:
tasks.append(_gen_page(page))
        # 2. Run all tasks concurrently; return_exceptions=True keeps a single
        # failed image from sinking the whole batch.
results = await asyncio.gather(*tasks, return_exceptions=True)
        # 3. Apply the cover result (results[0] is the cover task's return value)
cover_res = results[0]
if isinstance(cover_res, str):
final_cover_url = cover_res
logger.info("storybook_parallel_generation_complete")
    # ==========================================================================
    # Build and persist the Story object; convert page objects to dicts for
    # storage in the JSON column.
pages_data = [
{
"page_number": p.page_number,
"text": p.text,
"image_prompt": p.image_prompt,
"image_url": p.image_url,
}
for p in storybook.pages
]
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=storybook.title,
mode="storybook",
        pages=pages_data,  # stored in a JSON column
        story_text=None,  # in storybook mode the main text may be empty (or hold a summary)
cover_prompt=storybook.cover_prompt,
image_url=final_cover_url,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
    # Build the response from the updated pages_data
response_pages = [
StorybookPageResponse(
page_number=p["page_number"],
text=p["text"],
image_prompt=p["image_prompt"],
image_url=p.get("image_url"),
)
for p in pages_data
]
return StorybookResponse(
id=story.id,
title=storybook.title,
main_character=storybook.main_character,
art_style=storybook.art_style,
pages=response_pages,
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
)
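The parallel-rendering step above leans on `asyncio.gather(..., return_exceptions=True)`. A self-contained sketch of that pattern, with a hypothetical `render` coroutine standing in for `generate_image` (the URL scheme is made up for illustration):

```python
import asyncio

async def render(prompt: str) -> str:
    # Stand-in for generate_image(); one prompt fails to simulate a provider error.
    if prompt == "bad":
        raise RuntimeError("provider error")
    return f"https://cdn.example.invalid/{prompt}.png"  # hypothetical URL

async def render_all(prompts: list[str]) -> list:
    tasks = [render(p) for p in prompts]
    # return_exceptions=True keeps one failed page from cancelling the rest;
    # failures come back as exception objects in their original slot.
    return await asyncio.gather(*tasks, return_exceptions=True)

results = asyncio.run(render_all(["cover", "bad", "page2"]))
```

Because results keep their original order, `results[0]` is always the cover task, which is why the endpoint can apply it with a simple `isinstance(cover_res, str)` check.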
class AchievementItem(BaseModel):
type: str
description: str
obtained_at: str | None = None
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
async def get_story_achievements(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get achievements unlocked by a specific story."""
    # joinedload avoids an N+1 query on the universe relationship
result = await db.execute(
select(Story)
.options(joinedload(Story.story_universe))
.where(Story.id == story_id, Story.user_id == user.id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
if not story.universe_id or not story.story_universe:
return []
universe = story.story_universe
if not universe.achievements:
return []
results = []
for ach in universe.achievements:
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
results.append(AchievementItem(
type=ach.get("type", "Unknown"),
description=ach.get("description", ""),
obtained_at=ach.get("obtained_at")
))
return results

"""Story universe APIs."""
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, StoryUniverse, User
router = APIRouter()
class StoryUniverseCreate(BaseModel):
"""Create universe payload."""
name: str = Field(..., min_length=1, max_length=100)
protagonist: dict[str, Any]
recurring_characters: list[dict[str, Any]] = Field(default_factory=list)
world_settings: dict[str, Any] = Field(default_factory=dict)
class StoryUniverseUpdate(BaseModel):
"""Update universe payload."""
name: str | None = Field(default=None, min_length=1, max_length=100)
protagonist: dict[str, Any] | None = None
recurring_characters: list[dict[str, Any]] | None = None
world_settings: dict[str, Any] | None = None
class AchievementCreate(BaseModel):
"""Achievement payload."""
type: str = Field(..., min_length=1, max_length=50)
description: str = Field(..., min_length=1, max_length=200)
class StoryUniverseResponse(BaseModel):
"""Universe response."""
id: str
child_profile_id: str
name: str
protagonist: dict[str, Any]
recurring_characters: list[dict[str, Any]]
world_settings: dict[str, Any]
achievements: list[dict[str, Any]]
class Config:
from_attributes = True
class StoryUniverseListResponse(BaseModel):
"""Universe list response."""
universes: list[StoryUniverseResponse]
total: int
async def _get_profile_or_404(
profile_id: str,
user: User,
db: AsyncSession,
) -> ChildProfile:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
        raise HTTPException(status_code=404, detail="Profile not found")
return profile
async def _get_universe_or_404(
universe_id: str,
user: User,
db: AsyncSession,
) -> StoryUniverse:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
        raise HTTPException(status_code=404, detail="Universe not found")
return universe
@router.get("/profiles/{profile_id}/universes", response_model=StoryUniverseListResponse)
async def list_universes(
profile_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List universes for a child profile."""
await _get_profile_or_404(profile_id, user, db)
result = await db.execute(
select(StoryUniverse)
.where(StoryUniverse.child_profile_id == profile_id)
.order_by(StoryUniverse.updated_at.desc())
)
universes = result.scalars().all()
return StoryUniverseListResponse(universes=universes, total=len(universes))
@router.post(
"/profiles/{profile_id}/universes",
response_model=StoryUniverseResponse,
status_code=status.HTTP_201_CREATED,
)
async def create_universe(
profile_id: str,
payload: StoryUniverseCreate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create a story universe."""
await _get_profile_or_404(profile_id, user, db)
universe = StoryUniverse(child_profile_id=profile_id, **payload.model_dump())
db.add(universe)
await db.commit()
await db.refresh(universe)
return universe
@router.get("/universes/{universe_id}", response_model=StoryUniverseResponse)
async def get_universe(
universe_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get one universe."""
universe = await _get_universe_or_404(universe_id, user, db)
return universe
@router.put("/universes/{universe_id}", response_model=StoryUniverseResponse)
async def update_universe(
universe_id: str,
payload: StoryUniverseUpdate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Update a story universe."""
universe = await _get_universe_or_404(universe_id, user, db)
updates = payload.model_dump(exclude_unset=True)
for key, value in updates.items():
setattr(universe, key, value)
await db.commit()
await db.refresh(universe)
return universe
@router.delete("/universes/{universe_id}")
async def delete_universe(
universe_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Delete a story universe."""
universe = await _get_universe_or_404(universe_id, user, db)
await db.delete(universe)
await db.commit()
return {"message": "Deleted"}
@router.post("/universes/{universe_id}/achievements", response_model=StoryUniverseResponse)
async def add_achievement(
universe_id: str,
payload: AchievementCreate,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Add an achievement to a universe."""
universe = await _get_universe_or_404(universe_id, user, db)
achievements = list(universe.achievements or [])
key = (payload.type.strip(), payload.description.strip())
existing = {
(str(item.get("type", "")).strip(), str(item.get("description", "")).strip())
for item in achievements
if isinstance(item, dict)
}
if key not in existing:
achievements.append({"type": key[0], "description": key[1]})
universe.achievements = achievements
await db.commit()
await db.refresh(universe)
return universe
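The `add_achievement` endpoint deduplicates on the whitespace-stripped `(type, description)` pair before appending. The same logic as a small pure function:

```python
def add_achievement(achievements: list[dict], type_: str, description: str) -> list[dict]:
    """Append an achievement unless an identical (type, description) pair
    already exists, ignoring surrounding whitespace."""
    key = (type_.strip(), description.strip())
    existing = {
        (str(item.get("type", "")).strip(), str(item.get("description", "")).strip())
        for item in achievements
        if isinstance(item, dict)
    }
    if key not in existing:
        return [*achievements, {"type": key[0], "description": key[1]}]
    return achievements

items = add_achievement([], "Courage", "Faced the dark cave")
items = add_achievement(items, "Courage", "  Faced the dark cave  ")  # whitespace-only duplicate
```

Returning a new list (rather than mutating in place) mirrors the endpoint, which must reassign `universe.achievements` for SQLAlchemy to detect the JSON-column change.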

import secrets
import time
from cachetools import TTLCache
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from app.core.config import settings
security = HTTPBasic()
# Failed-login tracking: IP -> (failure count, first failure timestamp)
_failed_attempts: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900)  # 15 minutes
MAX_ATTEMPTS = 5
LOCKOUT_SECONDS = 900  # 15 minutes
def _get_client_ip(request: Request) -> str:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
if request.client and request.client.host:
return request.client.host
return "unknown"
def admin_guard(
request: Request,
credentials: HTTPBasicCredentials = Depends(security),
):
client_ip = _get_client_ip(request)
    # Check for an active lockout
if client_ip in _failed_attempts:
attempts, first_fail = _failed_attempts[client_ip]
if attempts >= MAX_ATTEMPTS:
remaining = int(LOCKOUT_SECONDS - (time.time() - first_fail))
if remaining > 0:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    detail=f"Too many login attempts; retry in {remaining} seconds",
)
else:
del _failed_attempts[client_ip]
    # secrets.compare_digest guards against timing attacks
username_ok = secrets.compare_digest(
credentials.username.encode(), settings.admin_username.encode()
)
password_ok = secrets.compare_digest(
credentials.password.encode(), settings.admin_password.encode()
)
if not (username_ok and password_ok):
        # Record the failure
if client_ip in _failed_attempts:
attempts, first_fail = _failed_attempts[client_ip]
_failed_attempts[client_ip] = (attempts + 1, first_fail)
else:
_failed_attempts[client_ip] = (1, time.time())
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password",
)
    # Successful login: clear the failure record
if client_ip in _failed_attempts:
del _failed_attempts[client_ip]
return True
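The guard's lockout bookkeeping (count failures per IP, lock after `MAX_ATTEMPTS` within a window) can be sketched without FastAPI or `cachetools`; a plain dict stands in for the `TTLCache`, and `now` is passed explicitly so the window logic is testable:

```python
MAX_ATTEMPTS = 5
LOCKOUT_SECONDS = 900

failed: dict[str, tuple[int, float]] = {}  # ip -> (attempt count, first failure time)

def register_failure(ip: str, now: float) -> None:
    attempts, first = failed.get(ip, (0, now))
    failed[ip] = (attempts + 1, first)

def is_locked(ip: str, now: float) -> bool:
    attempts, first = failed.get(ip, (0, now))
    if attempts >= MAX_ATTEMPTS:
        if now - first < LOCKOUT_SECONDS:
            return True
        failed.pop(ip, None)  # window elapsed: reset the counter
    return False

for _ in range(MAX_ATTEMPTS):
    register_failure("203.0.113.7", now=1000.0)
locked = is_locked("203.0.113.7", now=1200.0)  # within the 900 s window
unlocked = is_locked("203.0.113.7", now=1000.0 + LOCKOUT_SECONDS)  # window elapsed
```

The real implementation additionally relies on the `TTLCache` ttl to evict stale entries even if `is_locked` is never called again for that IP.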

"""Celery application setup."""
from celery import Celery
from celery.schedules import crontab
from app.core.config import settings
celery_app = Celery(
"dreamweaver",
broker=settings.celery_broker_url,
backend=settings.celery_result_backend,
)
celery_app.conf.update(
task_track_started=True,
task_serializer="json",
accept_content=["json"],
result_serializer="json",
timezone="Asia/Shanghai",
enable_utc=True,
beat_schedule={
"check_push_notifications": {
"task": "app.tasks.push_notifications.check_push_notifications",
"schedule": crontab(minute="*/15"),
},
"prune_expired_memories": {
"task": "app.tasks.memory.prune_memories_task",
"schedule": crontab(minute="0", hour="3"), # Daily at 03:00
},
},
)
celery_app.autodiscover_tasks(["app.tasks"])

from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Global application settings."""
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
    # Application basics
app_name: str = "DreamWeaver"
debug: bool = False
    secret_key: str = Field(..., description="JWT signing secret")
    base_url: str = Field("http://localhost:8000", description="Public callback base URL for the backend")
    # Database
database_url: str = Field(..., description="SQLAlchemy async URL")
# OAuth - GitHub
github_client_id: str = ""
github_client_secret: str = ""
# OAuth - Google
google_client_id: str = ""
google_client_secret: str = ""
# AI Capability Keys
text_api_key: str = ""
tts_api_base: str = ""
tts_api_key: str = ""
image_api_key: str = ""
# Additional Provider API Keys
openai_api_key: str = ""
elevenlabs_api_key: str = ""
cqtai_api_key: str = ""
minimax_api_key: str = ""
minimax_group_id: str = ""
antigravity_api_key: str = ""
antigravity_api_base: str = ""
# AI Model Configuration
text_model: str = "gemini-2.0-flash"
tts_model: str = ""
image_model: str = ""
# Provider routing (ordered lists)
text_providers: list[str] = Field(default_factory=lambda: ["gemini"])
image_providers: list[str] = Field(default_factory=lambda: ["cqtai"])
tts_providers: list[str] = Field(default_factory=lambda: ["minimax", "elevenlabs", "edge_tts"])
# Celery (Redis)
celery_broker_url: str = Field("redis://localhost:6379/0")
celery_result_backend: str = Field("redis://localhost:6379/0")
# Admin console
enable_admin_console: bool = False
admin_username: str = "admin"
    admin_password: str = "admin123"  # override via environment variables in production
# CORS
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173"])
@model_validator(mode="after")
def _require_core_settings(self) -> "Settings": # type: ignore[override]
missing = []
if not self.secret_key or self.secret_key == "change-me-in-production":
missing.append("SECRET_KEY")
if not self.database_url:
missing.append("DATABASE_URL")
if missing:
raise ValueError(f"Missing required settings: {', '.join(missing)}")
return self
settings = Settings()
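The `model_validator` above fails fast at import time if core settings are missing or still the shipped placeholder. The same check as a dependency-free function (a sketch mirroring `_require_core_settings`, not the pydantic implementation itself):

```python
def validate_core_settings(env: dict[str, str]) -> list[str]:
    """Return the names of required settings that are missing or unsafe.

    Mirrors Settings._require_core_settings: SECRET_KEY must be set and must
    not be the shipped placeholder; DATABASE_URL must be set.
    """
    missing: list[str] = []
    secret = env.get("SECRET_KEY", "")
    if not secret or secret == "change-me-in-production":
        missing.append("SECRET_KEY")
    if not env.get("DATABASE_URL"):
        missing.append("DATABASE_URL")
    return missing

problems = validate_core_settings({"SECRET_KEY": "change-me-in-production"})
ok = validate_core_settings({"SECRET_KEY": "s3cr3t", "DATABASE_URL": "postgresql+asyncpg://u:p@db/app"})
```

Failing at startup rather than at first request means a misconfigured container exits immediately under Docker Compose instead of serving 500s.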

backend/app/core/deps.py (new file, 39 lines)
from fastapi import Cookie, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import decode_access_token
from app.db.database import get_db
from app.db.models import User
async def get_current_user(
access_token: str | None = Cookie(default=None),
db: AsyncSession = Depends(get_db),
) -> User | None:
    """Return the current user, or None if unauthenticated."""
if not access_token:
return None
payload = decode_access_token(access_token)
if not payload:
return None
user_id = payload.get("sub")
if not user_id:
return None
result = await db.execute(select(User).where(User.id == user_id))
return result.scalar_one_or_none()
async def require_user(
user: User | None = Depends(get_current_user),
) -> User:
    """Require an authenticated user; raise 401 otherwise."""
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
)
return user

"""Structured logging configuration."""
import logging
import sys
import structlog
from app.core.config import settings
def setup_logging():
    """Configure structlog for structured logging."""
shared_processors = [
structlog.contextvars.merge_contextvars,
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
]
if settings.debug:
processors = shared_processors + [
structlog.dev.ConsoleRenderer(colors=True),
]
else:
processors = shared_processors + [
structlog.processors.format_exc_info,
structlog.processors.JSONRenderer(),
]
structlog.configure(
processors=processors,
wrapper_class=structlog.stdlib.BoundLogger,
context_class=dict,
logger_factory=structlog.stdlib.LoggerFactory(),
cache_logger_on_first_use=True,
)
logging.basicConfig(
format="%(message)s",
stream=sys.stdout,
level=logging.DEBUG if settings.debug else logging.INFO,
)
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
    """Return a structured logger."""
return structlog.get_logger(name)
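In production mode the pipeline above ends in `JSONRenderer`, so each event becomes one JSON object per line. What that output shape looks like, sketched with only stdlib `logging` (structlog's renderer also merges bound context vars, which this sketch omits):

```python
import json
import logging

class JsonFormatter(logging.Formatter):
    """Render each log record as a single JSON line (level, logger, event)."""

    def format(self, record: logging.LogRecord) -> str:
        return json.dumps({
            "level": record.levelname.lower(),
            "logger": record.name,
            "event": record.getMessage(),
        })

record = logging.LogRecord(
    name="app.api", level=logging.WARNING, pathname=__file__, lineno=1,
    msg="image_generation_failed", args=None, exc_info=None,
)
line = JsonFormatter().format(record)
```

One-JSON-object-per-line output is what log shippers expect, which is why the debug-only `ConsoleRenderer` is swapped out outside development.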

backend/app/core/prompts.py (new file, 190 lines)
# ruff: noqa: E501
"""AI prompt templates (modernized)."""
# Random elements: inject a touch of unpredictable magic into each story
# (kept in Chinese; they are interpolated verbatim into the prompts).
RANDOM_ELEMENTS = [
"一个会打喷嚏的云朵",
"一本地图上找不到的神秘图书馆",
"一只能实现小愿望的彩色蜗牛",
"一扇通往颠倒世界的门",
"一顶能听懂动物说话的旧帽子",
"一个装着星星的玻璃罐",
"一棵结满笑声果实的树",
"一只能在水上画画的画笔",
"一个怕黑的影子",
"一只收集回声的瓶子",
"一双会自己跳舞的红鞋子",
"一个只能在月光下看见的邮筒",
"一张会改变模样的全家福",
"一把可以打开梦境的钥匙",
"一个喜欢讲冷笑话的冰箱",
"一条通往星期八的秘密小径"
]
# ==============================================================================
# Model A: Story Generation
# ==============================================================================
SYSTEM_INSTRUCTION_STORYTELLER = """
# Role
You are "**Dream Weaver**", a world-class children's storyteller with the imagination of Pixar and the warmth of Miyazaki.
Your mission is to create engaging, safe, and educational stories for children (ages 3-8).
# Core Philosophy
1. **Show, Don't Tell**: Don't preach the lesson. Let the character's actions and the plot demonstrate the theme.
2. **Safety First**: No violence, horror, or scary elements. Conflict should be emotional or situational, not physical.
3. **Vivid Imagery**: Use sensory details (colors, sounds, smells) that appeal to children.
4. **Empowerment**: The child protagonist should solve the problem using wit, kindness, or courage, not just luck.
# Continuity & Memory (CRITICAL)
- **Universal Context**: The story takes place in the child's established "Story Universe". Respect existing world rules.
- **Character Consistency**: If "Child Profile" or "Sidekicks" are provided, you MUST use their specific names and traits. Do NOT invent new main characters unless asked.
- **Callback**: If "Past Memories" are provided, try to make a natural, one-sentence reference to a past adventure to build a sense of continuity (e.g., "Just like when we found the lost star...").
# Output Format
You MUST return a pure JSON object with NO markdown formatting (no ```json code blocks).
The JSON object must have the following schema:
{
"mode": "generated",
"title": "A catchy, imaginative title",
"story_text": "The full story text. Use \\n\\n for paragraph breaks.",
"cover_prompt_suggestion": "A detailed English image generation prompt for the story cover. Style: whimsical, children's book illustration, soft lighting, vibrant colors."
}
"""
USER_PROMPT_GENERATION = """
# Task: Write a Children's Story
## Contextual Memory (Use these if provided)
{memory_context}
## Inputs
- **Keywords/Topic**: {keywords}
- **Educational Theme**: {education_theme}
- **Magic Element (Must Incorporate)**: {random_element}
## Constraints
- Length: 300-600 words.
- Structure: Beginning (Hook) -> Middle (Challenge) -> End (Resolution & Growth).
"""
# ==============================================================================
# Model B: Story Enhancement
# ==============================================================================
SYSTEM_INSTRUCTION_ENHANCER = """
# Role
You are "**Dream Weaver Editor**", an expert children's book editor who turns rough drafts into polished gems.
# Mission
Analyze the user's input story and rewrite it to be:
1. **More Engaging**: Enhance the plot with a "Magic Element" to add surprise.
2. **More Educational**: Weave the "Educational Theme" deeper into the narrative arc.
3. **Better Written**: Polish the sentences for rhythm and flow (suitable for reading aloud).
4. **Safe**: Remove any inappropriate content (violence, scary interaction) and replace it with constructive solutions.
# Output Format
You MUST return a pure JSON object with NO markdown formatting (no ```json code blocks).
The JSON object must have the following schema:
{
"mode": "enhanced",
"title": "An improved title (or the original if perfect)",
"story_text": "The rewritten story text. Use \\n\\n for paragraph breaks.",
"cover_prompt_suggestion": "A detailed English image generation prompt for the cover."
}
"""
USER_PROMPT_ENHANCEMENT = """
# Task: Enhance This Story
## Contextual Memory
{memory_context}
## Inputs
- **Original Story**: {full_story}
- **Target Theme**: {education_theme}
- **Magic Element to Add**: {random_element}
## Constraints
- Length: 300-600 words.
- Keep the original character names if possible, but feel free to upgrade the plot.
"""
# ==============================================================================
# Model C: Achievement Extraction
# ==============================================================================
# Kept simple for now: no system instruction, just a single one-shot prompt.
ACHIEVEMENT_EXTRACTION_PROMPT = """
Analyze the story and extract key growth moments or achievements for the child protagonist.
# Story
{story_text}
# Target Categories (Examples)
- **Courage**: Overcoming fear, trying something new.
- **Kindness**: Helping others, sharing, empathy.
- **Curiosity**: Asking questions, exploring, learning.
- **Resilience**: Not giving up, handling failure.
- **Wisdom**: Problem-solving, honesty, patience.
# Output Format
Return a pure JSON object (no markdown):
{{
"achievements": [
{{
"type": "Category Name",
"description": "Brief reason (max 10 words)",
"score": 8 // 1-10 intensity
}}
]
}}
"""
# ==============================================================================
# Model D: Storybook Generation
# ==============================================================================
SYSTEM_INSTRUCTION_STORYBOOK = """
# Role
You are "**Dream Weaver Illustrator**", a creative children's book author and visual director.
Your mission is to create a paginated picture book for children (ages 3-8), where each page has text and a matching illustration prompt.
# Core Philosophy
1. **Pacing**: The story must flow logically across the specified number of pages.
2. **Visual Consistency**: Define the "Main Character" and "Art Style" once, and ensure all image prompts adhere to them.
3. **Language**: The story text MUST be in **Chinese (Simplified)**. The image prompts MUST be in **English**.
4. **Memory**: If a memory context is provided, incorporate known characters or references naturally.
# Output Format
You MUST return a pure JSON object using the following schema (no markdown):
{
"title": "Story Title (Chinese)",
"main_character": "Description of the protagonist (e.g., 'A small blue robot with rusty gears')",
"art_style": "Visual style description (e.g., 'Watercolor, soft pastel colors, whimsical')",
"pages": [
{
"page_number": 1,
"text": "Page text in Chinese (30-60 chars).",
"image_prompt": "Detailed English image prompt describing the scene. Include 'main_character' reference."
}
],
"cover_prompt": "English image prompt for the book cover."
}
"""
USER_PROMPT_STORYBOOK = """
# Task: Create a {page_count}-Page Storybook
## Contextual Memory
{memory_context}
## Inputs
- **Keywords/Topic**: {keywords}
- **Educational Theme**: {education_theme}
- **Magic Element**: {random_element}
## Constraints
- Pages: Exactly {page_count} pages.
- Structure: Intro -> Development -> Climax -> Resolution.
"""
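The `USER_*` templates are filled with `str.format`, while the `SYSTEM_*` templates must be sent verbatim: their literal `{ }` JSON schemas would trip `str.format`'s placeholder parsing. A sketch using an abbreviated copy of the generation template and a translated sample of the random elements (both trimmed here for illustration):

```python
import random

# Trimmed, translated sample of RANDOM_ELEMENTS (the real list is in Chinese).
SAMPLE_ELEMENTS = ["a cloud that sneezes", "a key that opens dreams"]

# Abbreviated copy of USER_PROMPT_GENERATION; the full template adds constraints.
TEMPLATE = (
    "# Task: Write a Children's Story\n"
    "## Contextual Memory (Use these if provided)\n{memory_context}\n"
    "## Inputs\n"
    "- **Keywords/Topic**: {keywords}\n"
    "- **Educational Theme**: {education_theme}\n"
    "- **Magic Element (Must Incorporate)**: {random_element}\n"
)

prompt = TEMPLATE.format(
    memory_context="(none)",
    keywords="a brave snail",
    education_theme="patience",
    random_element=random.choice(SAMPLE_ELEMENTS),
)
```

If a `SYSTEM_*` template ever needed interpolation, its literal braces would have to be escaped as `{{` / `}}` first, as `ACHIEVEMENT_EXTRACTION_PROMPT` already does for its output schema.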

from datetime import datetime, timedelta, timezone
from jose import JWTError, jwt
from app.core.config import settings
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_DAYS = 7
def create_access_token(data: dict) -> str:
    """Create a JWT access token."""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(days=ACCESS_TOKEN_EXPIRE_DAYS)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, settings.secret_key, algorithm=ALGORITHM)
def decode_access_token(token: str) -> dict | None:
    """Decode a JWT token; return None if invalid or expired."""
try:
payload = jwt.decode(token, settings.secret_key, algorithms=[ALGORITHM])
return payload
except JWTError:
return None

from datetime import datetime
from decimal import Decimal
from uuid import uuid4
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, Integer, Numeric, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from app.db.models import Base
def _uuid() -> str:
return str(uuid4())
class Provider(Base):
"""Model provider registry."""
__tablename__ = "providers"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
name: Mapped[str] = mapped_column(String(100), nullable=False)
type: Mapped[str] = mapped_column(String(50), nullable=False) # text/image/tts/storybook
adapter: Mapped[str] = mapped_column(String(100), nullable=False)
model: Mapped[str] = mapped_column(String(200), nullable=True)
api_base: Mapped[str] = mapped_column(String(300), nullable=True)
    api_key: Mapped[str] = mapped_column(String(500), nullable=True)  # optional; takes precedence over config_ref
timeout_ms: Mapped[int] = mapped_column(Integer, default=60000)
max_retries: Mapped[int] = mapped_column(Integer, default=1)
weight: Mapped[int] = mapped_column(Integer, default=1)
priority: Mapped[int] = mapped_column(Integer, default=0)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    config_json: Mapped[dict | None] = mapped_column(JSON, nullable=True)  # extra settings (speed, vol, etc.)
    config_ref: Mapped[str] = mapped_column(String(100), nullable=True)  # environment-variable key name (fallback)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
)
updated_by: Mapped[str] = mapped_column(String(100), nullable=True)
class ProviderMetrics(Base):
    """Per-call provider metrics."""
__tablename__ = "provider_metrics"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
provider_id: Mapped[str] = mapped_column(
String(36), ForeignKey("providers.id", ondelete="CASCADE"), nullable=False, index=True
)
timestamp: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow, index=True
)
success: Mapped[bool] = mapped_column(Boolean, nullable=False)
latency_ms: Mapped[int] = mapped_column(Integer, nullable=True)
cost_usd: Mapped[Decimal] = mapped_column(Numeric(10, 6), nullable=True)
error_message: Mapped[str] = mapped_column(Text, nullable=True)
request_id: Mapped[str] = mapped_column(String(100), nullable=True)
class ProviderHealth(Base):
"""供应商健康状态。"""
__tablename__ = "provider_health"
provider_id: Mapped[str] = mapped_column(
String(36), ForeignKey("providers.id", ondelete="CASCADE"), primary_key=True
)
is_healthy: Mapped[bool] = mapped_column(Boolean, default=True)
last_check: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=True)
consecutive_failures: Mapped[int] = mapped_column(Integer, default=0)
last_error: Mapped[str] = mapped_column(Text, nullable=True)
class ProviderSecret(Base):
"""供应商密钥加密存储。"""
__tablename__ = "provider_secrets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
encrypted_value: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
)
class CostRecord(Base):
"""成本记录表 - 记录每次 API 调用的成本。"""
__tablename__ = "cost_records"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[str] = mapped_column(String(36), nullable=False, index=True)
provider_id: Mapped[str] = mapped_column(String(36), nullable=True) # May come from environment-variable config
provider_name: Mapped[str] = mapped_column(String(100), nullable=False)
capability: Mapped[str] = mapped_column(String(50), nullable=False) # text/image/tts/storybook
estimated_cost: Mapped[Decimal] = mapped_column(Numeric(10, 6), nullable=False)
timestamp: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow, index=True
)
class UserBudget(Base):
"""用户预算配置。"""
__tablename__ = "user_budgets"
user_id: Mapped[str] = mapped_column(String(36), primary_key=True)
daily_limit_usd: Mapped[Decimal] = mapped_column(Numeric(10, 4), default=Decimal("1.0"))
monthly_limit_usd: Mapped[Decimal] = mapped_column(Numeric(10, 4), default=Decimal("10.0"))
alert_threshold: Mapped[Decimal] = mapped_column(
Numeric(3, 2), default=Decimal("0.8")
) # Alert at 80% of the limit
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
)


@@ -0,0 +1,50 @@
import threading
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.core.config import settings
_engine = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
_lock = threading.Lock()
def _get_engine():
global _engine
if _engine is None:
with _lock:
if _engine is None:
_engine = create_async_engine(
settings.database_url,
echo=settings.debug,
pool_pre_ping=True,
pool_recycle=300,
)
return _engine
def _get_session_factory():
global _session_factory
if _session_factory is None:
with _lock:
if _session_factory is None:
_session_factory = async_sessionmaker(
_get_engine(), class_=AsyncSession, expire_on_commit=False
)
return _session_factory
async def init_db():
"""Create tables if they do not exist."""
from app.db.models import Base # main models
engine = _get_engine()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def get_db():
"""Yield a DB session with proper cleanup."""
session_factory = _get_session_factory()
async with session_factory() as session:
yield session
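The lazy engine and session factory above follow the double-checked locking pattern: an unlocked fast path, then a re-check under the lock so concurrent first calls construct the singleton only once. The pattern in isolation:

```python
import threading

_instance = None
_lock = threading.Lock()

def get_instance():
    global _instance
    if _instance is None:          # fast path: skip the lock once initialized
        with _lock:
            if _instance is None:  # re-check under the lock
                _instance = object()
    return _instance
```

Without the second check, two threads that both pass the first `if` would each build their own instance.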

backend/app/db/models.py Normal file

@@ -0,0 +1,232 @@
from datetime import date, datetime, time
from uuid import uuid4
from sqlalchemy import (
JSON,
Boolean,
Date,
DateTime,
Float,
ForeignKey,
Integer,
String,
Text,
Time,
UniqueConstraint,
func,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
class Base(DeclarativeBase):
"""Declarative base."""
class User(Base):
"""User entity."""
__tablename__ = "users"
id: Mapped[str] = mapped_column(String(255), primary_key=True) # OAuth provider user ID
name: Mapped[str] = mapped_column(String(255), nullable=False)
avatar_url: Mapped[str | None] = mapped_column(String(500))
provider: Mapped[str] = mapped_column(String(50), nullable=False) # github / google
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
stories: Mapped[list["Story"]] = relationship(
"Story", back_populates="user", cascade="all, delete-orphan"
)
child_profiles: Mapped[list["ChildProfile"]] = relationship(
"ChildProfile", back_populates="user", cascade="all, delete-orphan"
)
class Story(Base):
"""Story entity."""
__tablename__ = "stories"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[str] = mapped_column(
String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
)
child_profile_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("child_profiles.id", ondelete="SET NULL"), nullable=True
)
universe_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("story_universes.id", ondelete="SET NULL"), nullable=True
)
title: Mapped[str] = mapped_column(String(255), nullable=False)
story_text: Mapped[str] = mapped_column(Text, nullable=True) # Nullable (storybook mode)
pages: Mapped[list[dict] | None] = mapped_column(JSON, default=list) # Paginated storybook data
cover_prompt: Mapped[str | None] = mapped_column(Text)
image_url: Mapped[str | None] = mapped_column(String(500))
mode: Mapped[str] = mapped_column(String(20), nullable=False, default="generated")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
user: Mapped["User"] = relationship("User", back_populates="stories")
child_profile: Mapped["ChildProfile | None"] = relationship("ChildProfile")
story_universe: Mapped["StoryUniverse | None"] = relationship("StoryUniverse")
def _uuid() -> str:
return str(uuid4())
class ChildProfile(Base):
"""Child profile entity."""
__tablename__ = "child_profiles"
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_child_profile_user_name"),)
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
name: Mapped[str] = mapped_column(String(50), nullable=False)
avatar_url: Mapped[str | None] = mapped_column(String(500))
birth_date: Mapped[date | None] = mapped_column(Date)
gender: Mapped[str | None] = mapped_column(String(10))
interests: Mapped[list[str]] = mapped_column(JSON, default=list)
growth_themes: Mapped[list[str]] = mapped_column(JSON, default=list)
reading_preferences: Mapped[dict] = mapped_column(JSON, default=dict)
stories_count: Mapped[int] = mapped_column(Integer, default=0)
total_reading_time: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
user: Mapped["User"] = relationship("User", back_populates="child_profiles")
story_universes: Mapped[list["StoryUniverse"]] = relationship(
"StoryUniverse", back_populates="child_profile", cascade="all, delete-orphan"
)
@property
def age(self) -> int | None:
if not self.birth_date:
return None
today = date.today()
return today.year - self.birth_date.year - (
(today.month, today.day) < (self.birth_date.month, self.birth_date.day)
)
class StoryUniverse(Base):
"""Story universe entity."""
__tablename__ = "story_universes"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
child_profile_id: Mapped[str] = mapped_column(
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
)
name: Mapped[str] = mapped_column(String(100), nullable=False)
protagonist: Mapped[dict] = mapped_column(JSON, nullable=False)
recurring_characters: Mapped[list] = mapped_column(JSON, default=list)
world_settings: Mapped[dict] = mapped_column(JSON, default=dict)
achievements: Mapped[list] = mapped_column(JSON, default=list)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
child_profile: Mapped["ChildProfile"] = relationship("ChildProfile", back_populates="story_universes")
class ReadingEvent(Base):
"""Reading event entity."""
__tablename__ = "reading_events"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
child_profile_id: Mapped[str] = mapped_column(
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
)
story_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("stories.id", ondelete="SET NULL"), nullable=True, index=True
)
event_type: Mapped[str] = mapped_column(String(20), nullable=False)
reading_time: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), index=True
)
class PushConfig(Base):
"""Push configuration entity."""
__tablename__ = "push_configs"
__table_args__ = (
UniqueConstraint("child_profile_id", name="uq_push_config_child"),
)
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
child_profile_id: Mapped[str] = mapped_column(
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
)
push_time: Mapped[time | None] = mapped_column(Time)
push_days: Mapped[list[int]] = mapped_column(JSON, default=list)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
class PushEvent(Base):
"""Push event entity."""
__tablename__ = "push_events"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
child_profile_id: Mapped[str] = mapped_column(
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
)
trigger_type: Mapped[str] = mapped_column(String(20), nullable=False)
status: Mapped[str] = mapped_column(String(20), nullable=False)
reason: Mapped[str | None] = mapped_column(Text)
sent_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
class MemoryItem(Base):
"""Memory item entity with time decay metadata."""
__tablename__ = "memory_items"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
child_profile_id: Mapped[str] = mapped_column(
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
)
universe_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("story_universes.id", ondelete="SET NULL"), nullable=True, index=True
)
type: Mapped[str] = mapped_column(String(50), nullable=False)
value: Mapped[dict] = mapped_column(JSON, nullable=False)
base_weight: Mapped[float] = mapped_column(Float, default=1.0)
last_used_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
ttl_days: Mapped[int | None] = mapped_column(Integer)

backend/app/main.py Normal file

@@ -0,0 +1,80 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api import (
auth,
memories,
profiles,
push_configs,
reading_events,
stories,
universes,
)
from app.core.config import settings
from app.core.logging import get_logger, setup_logging
from app.db.database import init_db
setup_logging()
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""App lifespan manager."""
logger.info("app_starting", app_name=settings.app_name)
await init_db()
logger.info("database_initialized")
# Load the provider cache
await _load_provider_cache()
yield
logger.info("app_shutdown")
async def _load_provider_cache():
"""启动时加载 provider 缓存。"""
from app.db.database import _get_session_factory
from app.services.provider_cache import reload_providers
try:
session_factory = _get_session_factory()
async with session_factory() as session:
cache = await reload_providers(session)
provider_count = sum(len(v) for v in cache.values())
logger.info("provider_cache_loaded", provider_count=provider_count)
except Exception as e:
logger.warning("provider_cache_load_failed", error=str(e))
# Do not block startup; fall back to the defaults from settings
app = FastAPI(
title=settings.app_name,
description="AI-driven story generator for kids.",
version="0.1.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(auth.router, prefix="/auth", tags=["auth"])
app.include_router(stories.router, prefix="/api", tags=["stories"])
app.include_router(profiles.router, prefix="/api", tags=["profiles"])
app.include_router(universes.router, prefix="/api", tags=["universes"])
app.include_router(push_configs.router, prefix="/api", tags=["push-configs"])
app.include_router(reading_events.router, prefix="/api", tags=["reading-events"])
app.include_router(memories.router, prefix="/api", tags=["memories"])
@app.get("/health")
async def health_check():
"""Simple liveness check."""
return {"status": "ok"}


@@ -0,0 +1,85 @@
"""Achievement extraction service."""
import json
import re
import httpx
from app.core.config import settings
from app.core.logging import get_logger
from app.core.prompts import ACHIEVEMENT_EXTRACTION_PROMPT
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
async def extract_achievements(story_text: str) -> list[dict]:
"""Extract achievements from story text using LLM."""
if not settings.text_api_key:
logger.warning("achievement_extraction_skipped", reason="missing_text_api_key")
return []
model = settings.text_model or "gemini-2.0-flash"
url = f"{TEXT_API_BASE}/{model}:generateContent"
prompt = ACHIEVEMENT_EXTRACTION_PROMPT.format(story_text=story_text)
payload = {
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.2,
"topP": 0.9,
},
}
async with httpx.AsyncClient(timeout=30) as client:
response = await client.post(
url,
json=payload,
headers={"x-goog-api-key": settings.text_api_key},
)
response.raise_for_status()
result = response.json()
candidates = result.get("candidates") or []
if not candidates:
logger.warning("achievement_extraction_empty")
return []
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
logger.warning("achievement_extraction_missing_text")
return []
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError:
logger.warning("achievement_extraction_parse_failed")
return []
achievements = parsed.get("achievements")
if not isinstance(achievements, list):
return []
normalized: list[dict] = []
for item in achievements:
if not isinstance(item, dict):
continue
a_type = str(item.get("type", "")).strip()
description = str(item.get("description", "")).strip()
score = item.get("score", 0)
if not a_type or not description:
continue
normalized.append({
"type": a_type,
"description": description,
"score": score
})
return normalized
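The fence-stripping step above guards against models wrapping JSON in a markdown code block. The same regex and parse step as a self-contained sketch:

```python
import json
import re

def parse_json_response(response_text: str) -> dict:
    """Strip an optional leading ```json fence, then parse the payload."""
    clean = response_text
    if response_text.startswith("```json"):
        clean = re.sub(r"^```json\n|```$", "", response_text)
    return json.loads(clean)
```

Requesting `responseMimeType: application/json` usually avoids the fence, but the strip keeps parsing robust when it appears anyway.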


@@ -0,0 +1,21 @@
"""适配器模块 - 供应商平台化架构核心。"""
from app.services.adapters.base import AdapterConfig, BaseAdapter
# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.registry import AdapterRegistry
# Storybook adapters
from app.services.adapters.storybook import primary as _storybook_primary # noqa: F401
from app.services.adapters.text import gemini as _text_gemini_adapter # noqa: F401
# 导入所有适配器以触发注册
# Text adapters
from app.services.adapters.text import openai as _text_openai_adapter # noqa: F401
# TTS adapters
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
__all__ = ["AdapterConfig", "BaseAdapter", "AdapterRegistry"]


@@ -0,0 +1,46 @@
"""适配器基类定义。"""
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from pydantic import BaseModel
T = TypeVar("T")
class AdapterConfig(BaseModel):
"""适配器配置基类。"""
api_key: str
api_base: str | None = None
model: str | None = None
timeout_ms: int = 60000
max_retries: int = 3
extra_config: dict = {}
class BaseAdapter(ABC, Generic[T]):
"""适配器基类,所有供应商适配器必须继承此类。"""
# 子类必须定义
adapter_type: str # text / image / tts
adapter_name: str # text_primary / image_primary / tts_primary
def __init__(self, config: AdapterConfig):
self.config = config
@abstractmethod
async def execute(self, **kwargs) -> T:
"""执行适配器逻辑,返回结果。"""
pass
@abstractmethod
async def health_check(self) -> bool:
"""健康检查,返回是否可用。"""
pass
@property
@abstractmethod
def estimated_cost(self) -> float:
"""预估单次调用成本 (USD)。"""
pass


@@ -0,0 +1,3 @@
"""图像生成适配器。"""# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.image import antigravity as _image_antigravity_adapter # noqa: F401


@@ -0,0 +1,214 @@
"""Antigravity 图像生成适配器。
使用 OpenAI 兼容 API 生成图像。
支持 gemini-3-pro-image 等模型。
"""
import base64
import time
from typing import Any
from openai import AsyncOpenAI
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# Default configuration
DEFAULT_API_BASE = "http://127.0.0.1:8045/v1"
DEFAULT_MODEL = "gemini-3-pro-image"
DEFAULT_SIZE = "1024x1024"
# Supported size-to-aspect-ratio mapping
SUPPORTED_SIZES = {
"1024x1024": "1:1",
"1280x720": "16:9",
"720x1280": "9:16",
"1216x896": "4:3",
}
@AdapterRegistry.register("image", "antigravity")
class AntigravityImageAdapter(BaseAdapter[str]):
"""Antigravity 图像生成适配器 (OpenAI 兼容 API)。
特点:
- 使用 OpenAI 兼容的 chat.completions 端点
- 通过 extra_body.size 指定图像尺寸
- 支持 gemini-3-pro-image 等模型
- 返回图片 URL 或 base64
"""
adapter_type = "image"
adapter_name = "antigravity"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or DEFAULT_API_BASE
self.client = AsyncOpenAI(
base_url=self.api_base,
api_key=config.api_key,
timeout=config.timeout_ms / 1000,
)
async def execute(
self,
prompt: str,
model: str | None = None,
size: str | None = None,
num_images: int = 1,
**kwargs,
) -> str | list[str]:
"""根据提示词生成图片,返回 URL 或 base64。
Args:
prompt: 图片描述提示词
model: 模型名称 (gemini-3-pro-image / gemini-3-pro-image-16-9 等)
size: 图像尺寸 (1024x1024, 1280x720, 720x1280, 1216x896)
num_images: 生成图片数量 (暂只支持 1)
Returns:
图片 URL 或 base64 字符串
"""
# 优先使用传入参数,其次使用 Adapter 配置,最后使用默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
size = size or cfg.get("size") or DEFAULT_SIZE
start_time = time.time()
logger.info(
"antigravity_generate_start",
prompt_length=len(prompt),
model=model,
size=size,
)
# Call the API
image_url = await self._generate_image(prompt, model, size)
elapsed = time.time() - start_time
logger.info(
"antigravity_generate_success",
elapsed_seconds=round(elapsed, 2),
model=model,
)
return image_url
async def health_check(self) -> bool:
"""检查 Antigravity API 是否可用。"""
try:
# Simple connectivity test
await self.client.chat.completions.create(
model=self.config.model or DEFAULT_MODEL,
messages=[{"role": "user", "content": "test"}],
max_tokens=1,
)
return True
except Exception as e:
logger.warning("antigravity_health_check_failed", error=str(e))
return False
@property
def estimated_cost(self) -> float:
"""预估每张图片成本 (USD)。
Antigravity 使用 Gemini 模型,成本约 $0.02/张。
"""
return 0.02
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((Exception,)),
reraise=True,
)
async def _generate_image(
self,
prompt: str,
model: str,
size: str,
) -> str:
"""调用 Antigravity API 生成图像。
Returns:
图片 URL 或 base64 data URI
"""
try:
response = await self.client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
extra_body={"size": size},
)
# Parse the response
content = response.choices[0].message.content
if not content:
raise ValueError("Antigravity 未返回内容")
# 尝试解析为图片 URL 或 base64
# 响应可能是纯 URL、base64 或 markdown 格式的图片
image_url = self._extract_image_url(content)
if image_url:
return image_url
raise ValueError(f"Antigravity 响应无法解析为图片: {content[:200]}")
except Exception as e:
logger.error(
"antigravity_generate_error",
error=str(e),
model=model,
)
raise
def _extract_image_url(self, content: str) -> str | None:
"""从响应内容中提取图片 URL。
支持多种格式:
- 纯 URL: https://...
- Markdown: ![...](https://...)
- Base64 data URI: data:image/...
- 纯 base64 字符串
"""
content = content.strip()
# 1. Data URI?
if content.startswith("data:image/"):
return content
# 2. Bare URL?
if content.startswith("http://") or content.startswith("https://"):
# May span multiple lines; take only the first
return content.split("\n")[0].strip()
# 3. Markdown image format ![...](url)
import re
md_match = re.search(r"!\[.*?\]\((https?://[^\)]+)\)", content)
if md_match:
return md_match.group(1)
# 4. Looks like base64-encoded image data?
if self._looks_like_base64(content):
# Assume PNG
return f"data:image/png;base64,{content}"
return None
def _looks_like_base64(self, s: str) -> bool:
"""判断字符串是否看起来像 base64 编码。"""
# Base64 只包含 A-Z, a-z, 0-9, +, /, =
# 且长度通常较长
if len(s) < 100:
return False
import re
return bool(re.match(r"^[A-Za-z0-9+/=]+$", s.replace("\n", "")))


@@ -0,0 +1,252 @@
"""CQTAI nano 图像生成适配器。
支持异步生成 + 轮询获取结果。
API 文档: https://api.cqtai.com
"""
import asyncio
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# Default configuration
DEFAULT_API_BASE = "https://api.cqtai.com"
DEFAULT_MODEL = "nano-banana"
DEFAULT_RESOLUTION = "2K"
DEFAULT_ASPECT_RATIO = "1:1"
POLL_INTERVAL_SECONDS = 2
MAX_POLL_ATTEMPTS = 60 # Poll for at most 2 minutes
@AdapterRegistry.register("image", "cqtai")
class CQTAIImageAdapter(BaseAdapter[str]):
"""CQTAI nano 图像生成适配器,返回图片 URL。
特点:
- 异步生成 + 轮询获取结果
- 支持 nano-banana (标准) 和 nano-banana-pro (高画质)
- 支持多种分辨率和画面比例
- 支持图生图 (filesUrl)
"""
adapter_type = "image"
adapter_name = "cqtai"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or DEFAULT_API_BASE
async def execute(
self,
prompt: str,
model: str | None = None,
resolution: str | None = None,
aspect_ratio: str | None = None,
num_images: int = 1,
files_url: list[str] | None = None,
**kwargs,
) -> str | list[str]:
"""根据提示词生成图片,返回 URL 或 URL 列表。
Args:
prompt: 图片描述提示词
model: 模型名称 (nano-banana / nano-banana-pro)
resolution: 分辨率 (1K / 2K / 4K)
aspect_ratio: 画面比例 (1:1, 16:9, 9:16, 4:3, 3:4 等)
num_images: 生成图片数量 (1-4)
files_url: 输入图片 URL 列表 (图生图)
Returns:
单张图片返回 str多张返回 list[str]
"""
# 1. 优先使用传入参数
# 2. 其次使用 Adapter 配置里的 default (extra_config)
# 3. 最后使用系统默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
resolution = resolution or cfg.get("resolution") or DEFAULT_RESOLUTION
aspect_ratio = aspect_ratio or cfg.get("aspect_ratio") or DEFAULT_ASPECT_RATIO
num_images = min(max(num_images, 1), 4) # Clamp to 1-4
start_time = time.time()
logger.info(
"cqtai_generate_start",
prompt_length=len(prompt),
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
num_images=num_images,
)
# 1. Submit the generation task
task_id = await self._submit_task(
prompt=prompt,
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
num_images=num_images,
files_url=files_url or [],
)
logger.info("cqtai_task_submitted", task_id=task_id)
# 2. Poll for the result
result = await self._poll_result(task_id)
elapsed = time.time() - start_time
logger.info(
"cqtai_generate_success",
task_id=task_id,
elapsed_seconds=round(elapsed, 2),
image_count=len(result) if isinstance(result, list) else 1,
)
# Return a string for a single image, a list for multiple
if num_images == 1 and isinstance(result, list) and len(result) == 1:
return result[0]
return result
async def health_check(self) -> bool:
"""检查 CQTAI API 是否可用。"""
try:
async with httpx.AsyncClient(timeout=10) as client:
# Simple connectivity test
response = await client.get(
f"{self.api_base}/api/cqt/info/nano",
params={"id": "health_check_test"},
headers={"Authorization": self.config.api_key},
)
# Even an error response means the service is reachable
return response.status_code in (200, 400, 401, 403, 404)
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估每张图片成本 (USD)。
nano-banana: ¥0.1 ≈ $0.014
nano-banana-pro: ¥0.2 ≈ $0.028
"""
model = self.config.model or DEFAULT_MODEL
if model == "nano-banana-pro":
return 0.028
return 0.014
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _submit_task(
self,
prompt: str,
model: str,
resolution: str,
aspect_ratio: str,
num_images: int,
files_url: list[str],
) -> str:
"""提交图像生成任务,返回任务 ID。"""
timeout = self.config.timeout_ms / 1000
payload = {
"prompt": prompt,
"numImages": num_images,
"aspectRatio": aspect_ratio,
"filesUrl": files_url,
}
# Optional parameters; server defaults apply when omitted
if model != DEFAULT_MODEL:
payload["model"] = model
if resolution != DEFAULT_RESOLUTION:
payload["resolution"] = resolution
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
f"{self.api_base}/api/cqt/generator/nano",
json=payload,
headers={
"Authorization": self.config.api_key,
"Content-Type": "application/json",
},
)
response.raise_for_status()
data = response.json()
if data.get("code") != 200:
raise ValueError(f"CQTAI 任务提交失败: {data.get('msg', '未知错误')}")
task_id = data.get("data")
if not task_id:
raise ValueError("CQTAI 未返回任务 ID")
return task_id
async def _poll_result(self, task_id: str) -> list[str]:
"""轮询获取生成结果。
Returns:
图片 URL 列表
"""
timeout = self.config.timeout_ms / 1000
for attempt in range(MAX_POLL_ATTEMPTS):
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(
f"{self.api_base}/api/cqt/info/nano",
params={"id": task_id},
headers={"Authorization": self.config.api_key},
)
response.raise_for_status()
data = response.json()
if data.get("code") != 200:
raise ValueError(f"CQTAI 查询失败: {data.get('msg', '未知错误')}")
result_data = data.get("data", {})
status = result_data.get("status")
if status == "completed":
# Extract image URLs
images = result_data.get("images", [])
if not images:
# Tolerate alternative response shapes
image_url = result_data.get("imageUrl") or result_data.get("url")
if image_url:
images = [image_url]
if not images:
raise ValueError("CQTAI 未返回图片 URL")
return images
elif status == "failed":
error_msg = result_data.get("error", "生成失败")
raise ValueError(f"CQTAI 图像生成失败: {error_msg}")
# Keep waiting
logger.debug(
"cqtai_poll_waiting",
task_id=task_id,
attempt=attempt + 1,
status=status,
)
await asyncio.sleep(POLL_INTERVAL_SECONDS)
raise TimeoutError(f"CQTAI 任务超时: {task_id}")


@@ -0,0 +1,73 @@
"""适配器注册表 - 支持动态注册和工厂创建。"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.services.adapters.base import AdapterConfig, BaseAdapter
class AdapterRegistry:
"""适配器注册表,管理所有已注册的适配器类。"""
_adapters: dict[str, type["BaseAdapter"]] = {}
@classmethod
def register(cls, adapter_type: str, adapter_name: str):
"""装饰器:注册适配器类。
用法:
@AdapterRegistry.register("text", "text_primary")
class TextPrimaryAdapter(BaseAdapter[StoryOutput]):
...
"""
def decorator(adapter_class: type["BaseAdapter"]):
key = f"{adapter_type}:{adapter_name}"
cls._adapters[key] = adapter_class
# Set the class attributes automatically
adapter_class.adapter_type = adapter_type
adapter_class.adapter_name = adapter_name
return adapter_class
return decorator
@classmethod
def get(cls, adapter_type: str, adapter_name: str) -> type["BaseAdapter"] | None:
"""获取已注册的适配器类。"""
key = f"{adapter_type}:{adapter_name}"
return cls._adapters.get(key)
@classmethod
def list_adapters(cls, adapter_type: str | None = None) -> list[str]:
"""列出所有已注册的适配器。
Args:
adapter_type: 可选,筛选特定类型 (text/image/tts)
Returns:
适配器键列表,格式为 "type:name"
"""
if adapter_type:
return [k for k in cls._adapters if k.startswith(f"{adapter_type}:")]
return list(cls._adapters.keys())
@classmethod
def create(
cls,
adapter_type: str,
adapter_name: str,
config: "AdapterConfig",
) -> "BaseAdapter":
"""工厂方法:创建适配器实例。
Raises:
ValueError: 适配器未注册
"""
adapter_class = cls.get(adapter_type, adapter_name)
if not adapter_class:
available = cls.list_adapters(adapter_type)
raise ValueError(
f"Adapter '{adapter_type}:{adapter_name}' is not registered. "
f"Available: {available}"
)
return adapter_class(config)
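The decorator-plus-factory pattern above can be shown in miniature; a self-contained sketch (class and key names are illustrative only):

```python
class MiniRegistry:
    _classes: dict[str, type] = {}

    @classmethod
    def register(cls, key: str):
        """Decorator that records a class under `key` and returns it unchanged."""
        def decorator(klass: type) -> type:
            cls._classes[key] = klass
            return klass
        return decorator

    @classmethod
    def create(cls, key: str, *args):
        """Factory: instantiate the registered class, or fail loudly."""
        klass = cls._classes.get(key)
        if klass is None:
            raise ValueError(f"'{key}' is not registered")
        return klass(*args)

@MiniRegistry.register("echo")
class Echo:
    def __init__(self, value):
        self.value = value
```

Because registration happens at class-definition time, simply importing an adapter module (as the package `__init__` does) is enough to make it available to the factory.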


@@ -0,0 +1 @@
"""Storybook 适配器模块。"""


@@ -0,0 +1,195 @@
"""Storybook 适配器 - 生成可翻页的分页故事书。"""
import json
import random
import re
import time
from dataclasses import dataclass, field
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_STORYBOOK,
USER_PROMPT_STORYBOOK,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@dataclass
class StorybookPage:
"""故事书单页。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
@dataclass
class Storybook:
"""故事书输出。"""
title: str
main_character: str
art_style: str
pages: list[StorybookPage] = field(default_factory=list)
cover_prompt: str = ""
cover_url: str | None = None
@AdapterRegistry.register("storybook", "storybook_primary")
class StorybookPrimaryAdapter(BaseAdapter[Storybook]):
"""Storybook 生成适配器(默认)。
生成分页故事书结构,包含每页文字和图像提示词。
图像生成需要单独调用 image adapter。
"""
adapter_type = "storybook"
adapter_name = "storybook_primary"
async def execute(
self,
keywords: str,
page_count: int = 6,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> Storybook:
"""生成分页故事书。
Args:
keywords: 故事关键词
page_count: 页数 (4-12)
education_theme: 教育主题
memory_context: 记忆上下文
Returns:
Storybook 对象,包含标题、页面列表和封面提示词
"""
start_time = time.time()
page_count = max(4, min(page_count, 12)) # Clamp to 4-12 pages
logger.info(
"storybook_generate_start",
keywords=keywords,
page_count=page_count,
has_memory=bool(memory_context),
)
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
prompt = USER_PROMPT_STORYBOOK.format(
keywords=keywords,
education_theme=theme,
random_element=random_element,
page_count=page_count,
memory_context=memory_context or "",
)
payload = {
"system_instruction": {"parts": [{"text": SYSTEM_INSTRUCTION_STORYBOOK}]},
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.95,
"topP": 0.9,
},
}
result = await self._call_api(payload)
candidates = result.get("candidates") or []
if not candidates:
raise ValueError("Storybook 服务未返回内容")
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
raise ValueError("Storybook 服务响应缺少文本")
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"Storybook JSON 解析失败: {exc}")
        # Build the Storybook object
pages = [
StorybookPage(
page_number=p.get("page_number", i + 1),
text=p.get("text", ""),
image_prompt=p.get("image_prompt", ""),
)
for i, p in enumerate(parsed.get("pages", []))
]
storybook = Storybook(
title=parsed.get("title", "未命名故事"),
main_character=parsed.get("main_character", ""),
art_style=parsed.get("art_style", ""),
pages=pages,
cover_prompt=parsed.get("cover_prompt", ""),
)
elapsed = time.time() - start_time
logger.info(
"storybook_generate_success",
elapsed_seconds=round(elapsed, 2),
title=storybook.title,
page_count=len(pages),
)
return storybook
async def health_check(self) -> bool:
"""检查 API 是否可用。"""
try:
payload = {
"contents": [{"parts": [{"text": "Hi"}]}],
"generationConfig": {"maxOutputTokens": 10},
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估成本(仅文本生成,不含图像)。"""
return 0.002 # 比普通故事稍贵,因为输出更长
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, payload: dict) -> dict:
"""调用 API带重试机制。"""
model = self.config.model or "gemini-2.0-flash"
url = f"{TEXT_API_BASE}/{model}:generateContent?key={self.config.api_key}"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
return response.json()
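
The response-cleaning step in `execute` (stripping an optional Markdown code fence before parsing JSON) can be sketched as a standalone helper. This is a minimal illustration of that logic, not code from the repo:

```python
import json
import re


def parse_model_json(response_text: str) -> dict:
    """Strip a leading ```json fence and trailing ``` if the model wrapped its output."""
    clean = response_text
    if response_text.startswith("```json"):
        clean = re.sub(r"^```json\n|```$", "", response_text)
    return json.loads(clean)
```

Isolating it this way makes the fence-handling trivially unit-testable without hitting the API.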


@@ -0,0 +1 @@
"""文本生成适配器。"""


@@ -0,0 +1,164 @@
"""文本生成适配器 (Google Gemini)。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@AdapterRegistry.register("text", "gemini")
class GeminiTextAdapter(BaseAdapter[StoryOutput]):
"""Google Gemini 文本生成适配器。"""
adapter_type = "text"
adapter_name = "gemini"
async def execute(
self,
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> StoryOutput:
"""生成或润色故事。"""
start_time = time.time()
logger.info("request_start", adapter="gemini", input_type=input_type, data_length=len(data))
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
if input_type == "keywords":
system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
prompt = USER_PROMPT_GENERATION.format(
keywords=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
else:
system_instruction = SYSTEM_INSTRUCTION_ENHANCER
prompt = USER_PROMPT_ENHANCEMENT.format(
full_story=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
# Gemini API Payload supports 'system_instruction'
payload = {
"system_instruction": {"parts": [{"text": system_instruction}]},
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.95,
"topP": 0.9,
},
}
result = await self._call_api(payload)
candidates = result.get("candidates") or []
if not candidates:
raise ValueError("Gemini 未返回内容")
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
raise ValueError("Gemini 响应缺少文本")
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"Gemini 输出 JSON 解析失败: {exc}")
required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
if any(field not in parsed for field in required_fields):
raise ValueError("Gemini 输出缺少必要字段")
elapsed = time.time() - start_time
logger.info(
"request_success",
adapter="gemini",
elapsed_seconds=round(elapsed, 2),
title=parsed["title"],
)
return StoryOutput(
mode=parsed["mode"],
title=parsed["title"],
story_text=parsed["story_text"],
cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
)
async def health_check(self) -> bool:
"""检查 Gemini API 是否可用。"""
try:
payload = {
"contents": [{"parts": [{"text": "Hi"}]}],
"generationConfig": {"maxOutputTokens": 10},
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
return 0.001
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, payload: dict) -> dict:
"""调用 Gemini API。"""
model = self.config.model or "gemini-2.0-flash"
base_url = self.config.api_base or TEXT_API_BASE
        # Smart completion:
        # 1. If the user supplied a full path (ending in /models), use it as-is
        #    (supports both v1 and v1beta).
        if self.config.api_base and base_url.rstrip("/").endswith("/models"):
            pass
        # 2. If only a domain was given, append the /v1beta/models path this code expects.
elif self.config.api_base:
base_url = f"{base_url.rstrip('/')}/v1beta/models"
url = f"{base_url}/{model}:generateContent?key={self.config.api_key}"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
return response.json()


@@ -0,0 +1,11 @@
from dataclasses import dataclass
from typing import Literal
@dataclass
class StoryOutput:
"""故事生成输出。"""
mode: Literal["generated", "enhanced"]
title: str
story_text: str
cover_prompt_suggestion: str


@@ -0,0 +1,172 @@
"""OpenAI 文本生成适配器。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
OPENAI_API_BASE = "https://api.openai.com/v1/chat/completions"
@AdapterRegistry.register("text", "openai")
class OpenAITextAdapter(BaseAdapter[StoryOutput]):
"""OpenAI 文本生成适配器。"""
adapter_type = "text"
adapter_name = "openai"
async def execute(
self,
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> StoryOutput:
"""生成或润色故事。"""
start_time = time.time()
logger.info("openai_text_request_start", input_type=input_type, data_length=len(data))
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
if input_type == "keywords":
system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
prompt = USER_PROMPT_GENERATION.format(
keywords=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
else:
system_instruction = SYSTEM_INSTRUCTION_ENHANCER
prompt = USER_PROMPT_ENHANCEMENT.format(
full_story=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
model = self.config.model or "gpt-4o-mini"
payload = {
"model": model,
"messages": [
{
"role": "system",
"content": system_instruction,
},
{"role": "user", "content": prompt},
],
"response_format": {"type": "json_object"},
"temperature": 0.95,
"top_p": 0.9,
}
result = await self._call_api(payload)
choices = result.get("choices") or []
if not choices:
raise ValueError("OpenAI 未返回内容")
response_text = choices[0].get("message", {}).get("content", "")
if not response_text:
raise ValueError("OpenAI 响应缺少文本")
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"OpenAI 输出 JSON 解析失败: {exc}")
required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
if any(field not in parsed for field in required_fields):
raise ValueError("OpenAI 输出缺少必要字段")
elapsed = time.time() - start_time
logger.info(
"openai_text_request_success",
elapsed_seconds=round(elapsed, 2),
title=parsed["title"],
mode=parsed["mode"],
)
return StoryOutput(
mode=parsed["mode"],
title=parsed["title"],
story_text=parsed["story_text"],
cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
)
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
async def _call_api(self, payload: dict) -> dict:
"""调用 OpenAI API带重试机制。"""
url = self.config.api_base or OPENAI_API_BASE
        # Smart completion: if only a base URL was given, append the path.
if self.config.api_base and not url.endswith("/chat/completions"):
base = url.rstrip("/")
url = f"{base}/chat/completions"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"Authorization": f"Bearer {self.config.api_key}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
return response.json()
async def health_check(self) -> bool:
"""检查 OpenAI API 是否可用。"""
try:
payload = {
"model": self.config.model or "gpt-4o-mini",
"messages": [{"role": "user", "content": "Hi"}],
"max_tokens": 5,
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估文本生成成本 (USD)。"""
return 0.01
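
The required-field check that both text adapters perform on the parsed model output can be factored into a tiny validator. A sketch, assuming the field list stays exactly as shown in the adapters:

```python
REQUIRED_FIELDS = ("mode", "title", "story_text", "cover_prompt_suggestion")


def validate_story_payload(parsed: dict) -> dict:
    """Raise ValueError naming the missing fields, else return the payload unchanged."""
    missing = [f for f in REQUIRED_FIELDS if f not in parsed]
    if missing:
        raise ValueError(f"missing fields: {', '.join(missing)}")
    return parsed
```

Naming the missing fields in the error makes provider failures much easier to debug than a generic "缺少必要字段" message.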


@@ -0,0 +1,5 @@
"""TTS 语音合成适配器。"""
from app.services.adapters.tts import edge_tts as _tts_edge_tts_adapter # noqa: F401
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401


@@ -0,0 +1,66 @@
"""EdgeTTS 免费语音生成适配器。"""
import time
import edge_tts
from app.core.logging import get_logger
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# Default Chinese female voice (Xiaoxiao)
DEFAULT_VOICE = "zh-CN-XiaoxiaoNeural"
@AdapterRegistry.register("tts", "edge_tts")
class EdgeTTSAdapter(BaseAdapter[bytes]):
"""EdgeTTS 语音生成适配器 (Free)。
不需要 API Key。
"""
adapter_type = "tts"
adapter_name = "edge_tts"
async def execute(self, text: str, **kwargs) -> bytes:
"""生成语音。"""
# 支持动态指定音色
voice = kwargs.get("voice") or self.config.model or DEFAULT_VOICE
start_time = time.time()
logger.info("edge_tts_generate_start", text_length=len(text), voice=voice)
        # Stream audio chunks directly into memory via communicate() — no temp file or disk IO needed.
communicate = edge_tts.Communicate(text, voice)
audio_data = b""
async for chunk in communicate.stream():
if chunk["type"] == "audio":
audio_data += chunk["data"]
elapsed = time.time() - start_time
logger.info(
"edge_tts_generate_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_data),
)
return audio_data
async def health_check(self) -> bool:
"""检查 EdgeTTS 是否可用 (网络连通性)。"""
try:
# 简单生成一个词
await self.execute("Hi")
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
return 0.0 # Free!
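
The chunk-aggregation loop in `execute` can be verified without any network access. The fake chunks below imitate the dict shape yielded by `communicate.stream()` — a sketch, not the edge-tts library itself:

```python
from typing import Iterable


def collect_audio(chunks: Iterable[dict]) -> bytes:
    """Concatenate the payloads of "audio" chunks, skipping metadata chunks."""
    return b"".join(c["data"] for c in chunks if c["type"] == "audio")
```

edge-tts interleaves `WordBoundary` metadata chunks with audio, so filtering on `type` before touching `data` is what keeps the loop safe.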


@@ -0,0 +1,104 @@
"""ElevenLabs TTS 语音合成适配器。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
ELEVENLABS_API_BASE = "https://api.elevenlabs.io/v1"
DEFAULT_VOICE_ID = "21m00Tcm4TlvDq8ikWAM" # Rachel
@AdapterRegistry.register("tts", "elevenlabs")
class ElevenLabsTtsAdapter(BaseAdapter[bytes]):
"""ElevenLabs TTS 语音合成适配器,返回 MP3 bytes。"""
adapter_type = "tts"
adapter_name = "elevenlabs"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or ELEVENLABS_API_BASE
async def execute(self, text: str, **kwargs) -> bytes:
"""将文本转换为语音 MP3 bytes。"""
start_time = time.time()
logger.info("elevenlabs_tts_start", text_length=len(text))
voice_id = kwargs.get("voice_id") or DEFAULT_VOICE_ID
model_id = kwargs.get("model") or self.config.model or "eleven_multilingual_v2"
stability = kwargs.get("stability", 0.5)
similarity_boost = kwargs.get("similarity_boost", 0.75)
url = f"{self.api_base}/text-to-speech/{voice_id}"
payload = {
"text": text,
"model_id": model_id,
"voice_settings": {
"stability": stability,
"similarity_boost": similarity_boost,
},
}
audio_bytes = await self._call_api(url, payload)
elapsed = time.time() - start_time
logger.info(
"elevenlabs_tts_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_bytes),
)
return audio_bytes
async def health_check(self) -> bool:
"""检查 ElevenLabs API 是否可用。"""
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(
f"{self.api_base}/voices",
headers={"xi-api-key": self.config.api_key},
)
return response.status_code == 200
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估每千字符成本 (USD)。"""
return 0.03
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, url: str, payload: dict) -> bytes:
"""调用 ElevenLabs API带重试机制。"""
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"xi-api-key": self.config.api_key,
"Content-Type": "application/json",
"Accept": "audio/mpeg",
},
)
response.raise_for_status()
return response.content


@@ -0,0 +1,149 @@
"""MiniMax 语音生成适配器 (T2A V2)。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# MiniMax API configuration
DEFAULT_API_URL = "https://api.minimaxi.com/v1/t2a_v2"
DEFAULT_MODEL = "speech-2.6-turbo"
@AdapterRegistry.register("tts", "minimax")
class MiniMaxTTSAdapter(BaseAdapter[bytes]):
"""MiniMax 语音生成适配器。
需要配置:
- api_key: MiniMax API Key
- minimax_group_id: 可选 (取决于使用的模型/账户类型)
"""
adapter_type = "tts"
adapter_name = "minimax"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_url = DEFAULT_API_URL
async def execute(
self,
text: str,
voice_id: str | None = None,
model: str | None = None,
speed: float | None = None,
vol: float | None = None,
pitch: int | None = None,
emotion: str | None = None,
**kwargs,
) -> bytes:
"""生成语音。"""
# 1. 优先使用传入参数
# 2. 其次使用 Adapter 配置里的 default
# 3. 最后使用系统默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
voice_id = voice_id or cfg.get("voice_id") or "male-qn-qingse"
speed = speed if speed is not None else (cfg.get("speed") or 1.0)
vol = vol if vol is not None else (cfg.get("vol") or 1.0)
pitch = pitch if pitch is not None else (cfg.get("pitch") or 0)
emotion = emotion or cfg.get("emotion")
group_id = kwargs.get("group_id") or settings.minimax_group_id
url = self.api_url
if group_id:
url = f"{self.api_url}?GroupId={group_id}"
payload = {
"model": model,
"text": text,
"stream": False,
"voice_setting": {
"voice_id": voice_id,
"speed": speed,
"vol": vol,
"pitch": pitch,
},
"audio_setting": {
"sample_rate": 32000,
"bitrate": 128000,
"format": "mp3",
"channel": 1
}
}
if emotion:
payload["voice_setting"]["emotion"] = emotion
start_time = time.time()
logger.info("minimax_generate_start", text_length=len(text), model=model)
result = await self._call_api(url, payload)
        # Error handling
if result.get("base_resp", {}).get("status_code") != 0:
error_msg = result.get("base_resp", {}).get("status_msg", "未知错误")
raise ValueError(f"MiniMax API 错误: {error_msg}")
        # Hex decode (key logic, migrated from primary.py)
hex_audio = result.get("data", {}).get("audio")
if not hex_audio:
raise ValueError("API 响应中未找到音频数据 (data.audio)")
        try:
            audio_bytes = bytes.fromhex(hex_audio)
        except ValueError as exc:
            raise ValueError("MiniMax 返回的音频数据不是有效的 Hex 字符串") from exc
elapsed = time.time() - start_time
logger.info(
"minimax_generate_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_bytes),
)
return audio_bytes
async def health_check(self) -> bool:
"""检查 Minimax API 是否可用。"""
try:
# 尝试生成极短文本
await self.execute("Hi")
return True
except Exception:
return False
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, url: str, payload: dict) -> dict:
"""调用 API带重试机制。"""
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"Authorization": f"Bearer {self.config.api_key}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
return response.json()
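
The response-decoding rules in `execute` — check `base_resp.status_code`, then hex-decode `data.audio` — form a pure function that can be tested offline. This is an illustrative reimplementation, not the adapter itself:

```python
def decode_minimax_audio(result: dict) -> bytes:
    """Validate a T2A V2 response dict and return the decoded MP3 bytes."""
    base_resp = result.get("base_resp", {})
    if base_resp.get("status_code") != 0:
        # Non-zero status means the API rejected the request
        raise ValueError(base_resp.get("status_msg", "unknown error"))
    hex_audio = result.get("data", {}).get("audio")
    if not hex_audio:
        raise ValueError("no audio data in response (data.audio)")
    # MiniMax returns audio as a hex string rather than base64
    return bytes.fromhex(hex_audio)
```
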


@@ -0,0 +1,196 @@
"""成本追踪服务。
记录 API 调用成本,支持预算控制。
"""
from datetime import datetime, timedelta
from decimal import Decimal
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import CostRecord, UserBudget
logger = get_logger(__name__)
class BudgetExceededError(Exception):
"""预算超限错误。"""
def __init__(self, limit_type: str, used: Decimal, limit: Decimal):
self.limit_type = limit_type
self.used = used
self.limit = limit
super().__init__(f"{limit_type} 预算已超限: {used}/{limit} USD")
class CostTracker:
"""成本追踪器。"""
async def record_cost(
self,
db: AsyncSession,
user_id: str,
provider_name: str,
capability: str,
estimated_cost: float,
provider_id: str | None = None,
) -> CostRecord:
"""记录一次 API 调用成本。"""
record = CostRecord(
user_id=user_id,
provider_id=provider_id,
provider_name=provider_name,
capability=capability,
estimated_cost=Decimal(str(estimated_cost)),
)
db.add(record)
await db.commit()
logger.debug(
"cost_recorded",
user_id=user_id,
provider=provider_name,
capability=capability,
cost=estimated_cost,
)
return record
async def get_daily_cost(self, db: AsyncSession, user_id: str) -> Decimal:
"""获取用户今日成本。"""
today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
result = await db.execute(
select(func.sum(CostRecord.estimated_cost)).where(
CostRecord.user_id == user_id,
CostRecord.timestamp >= today_start,
)
)
total = result.scalar()
return Decimal(str(total)) if total else Decimal("0")
async def get_monthly_cost(self, db: AsyncSession, user_id: str) -> Decimal:
"""获取用户本月成本。"""
now = datetime.utcnow()
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
result = await db.execute(
select(func.sum(CostRecord.estimated_cost)).where(
CostRecord.user_id == user_id,
CostRecord.timestamp >= month_start,
)
)
total = result.scalar()
return Decimal(str(total)) if total else Decimal("0")
async def get_cost_by_capability(
self,
db: AsyncSession,
user_id: str,
days: int = 30,
) -> dict[str, Decimal]:
"""按能力类型统计成本。"""
since = datetime.utcnow() - timedelta(days=days)
result = await db.execute(
select(CostRecord.capability, func.sum(CostRecord.estimated_cost))
.where(CostRecord.user_id == user_id, CostRecord.timestamp >= since)
.group_by(CostRecord.capability)
)
return {row[0]: Decimal(str(row[1])) for row in result.all()}
async def check_budget(
self,
db: AsyncSession,
user_id: str,
estimated_cost: float,
) -> bool:
"""检查预算是否允许此次调用。
Returns:
True 如果允许,否则抛出 BudgetExceededError
"""
budget = await self.get_user_budget(db, user_id)
if not budget or not budget.enabled:
return True
        # Daily budget
        daily_cost = await self.get_daily_cost(db, user_id)
        if daily_cost + Decimal(str(estimated_cost)) > budget.daily_limit_usd:
            raise BudgetExceededError("daily", daily_cost, budget.daily_limit_usd)
        # Monthly budget
        monthly_cost = await self.get_monthly_cost(db, user_id)
        if monthly_cost + Decimal(str(estimated_cost)) > budget.monthly_limit_usd:
            raise BudgetExceededError("monthly", monthly_cost, budget.monthly_limit_usd)
return True
async def get_user_budget(self, db: AsyncSession, user_id: str) -> UserBudget | None:
"""获取用户预算配置。"""
result = await db.execute(
select(UserBudget).where(UserBudget.user_id == user_id)
)
return result.scalar_one_or_none()
async def set_user_budget(
self,
db: AsyncSession,
user_id: str,
daily_limit: float | None = None,
monthly_limit: float | None = None,
alert_threshold: float | None = None,
enabled: bool | None = None,
) -> UserBudget:
"""设置用户预算。"""
budget = await self.get_user_budget(db, user_id)
if budget is None:
budget = UserBudget(user_id=user_id)
db.add(budget)
if daily_limit is not None:
budget.daily_limit_usd = Decimal(str(daily_limit))
if monthly_limit is not None:
budget.monthly_limit_usd = Decimal(str(monthly_limit))
if alert_threshold is not None:
budget.alert_threshold = Decimal(str(alert_threshold))
if enabled is not None:
budget.enabled = enabled
await db.commit()
await db.refresh(budget)
return budget
async def get_cost_summary(
self,
db: AsyncSession,
user_id: str,
) -> dict:
"""获取用户成本摘要。"""
daily = await self.get_daily_cost(db, user_id)
monthly = await self.get_monthly_cost(db, user_id)
by_capability = await self.get_cost_by_capability(db, user_id)
budget = await self.get_user_budget(db, user_id)
return {
"daily_cost_usd": float(daily),
"monthly_cost_usd": float(monthly),
"by_capability": {k: float(v) for k, v in by_capability.items()},
"budget": {
"daily_limit_usd": float(budget.daily_limit_usd) if budget else None,
"monthly_limit_usd": float(budget.monthly_limit_usd) if budget else None,
"daily_usage_percent": float(daily / budget.daily_limit_usd * 100)
if budget and budget.daily_limit_usd
else None,
"monthly_usage_percent": float(monthly / budget.monthly_limit_usd * 100)
if budget and budget.monthly_limit_usd
else None,
"enabled": budget.enabled if budget else False,
},
}
# Module-level singleton
cost_tracker = CostTracker()
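
The budget comparison above converts the float estimate through `str` before building a `Decimal`, which avoids binary-float artifacts (e.g. `Decimal(0.1)` is not exactly `0.1`). A minimal sketch of that comparison, extracted for clarity:

```python
from decimal import Decimal


def would_exceed(used: Decimal, limit: Decimal, estimated_cost: float) -> bool:
    """Mirror the tracker's check: the call is blocked only if it pushes usage past the limit."""
    return used + Decimal(str(estimated_cost)) > limit
```

Note the strict `>`: a call that lands exactly on the limit is still allowed, matching `check_budget`.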


@@ -0,0 +1,471 @@
"""Memory service handles memory retrieval, scoring, and prompt injection."""
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models import ChildProfile, MemoryItem, StoryUniverse
logger = get_logger(__name__)
class MemoryType:
"""记忆类型常量及配置。"""
# 基础类型
RECENT_STORY = "recent_story"
FAVORITE_CHARACTER = "favorite_character"
SCARY_ELEMENT = "scary_element"
VOCABULARY_GROWTH = "vocabulary_growth"
EMOTIONAL_HIGHLIGHT = "emotional_highlight"
    # Types added in Phase 1
    READING_PREFERENCE = "reading_preference"  # reading preference
    MILESTONE = "milestone"  # milestone event
    SKILL_MASTERED = "skill_mastered"  # mastered skill
    # Per-type config: (default weight, default TTL in days, description)
    CONFIG = {
        RECENT_STORY: (1.0, 30, "最近阅读的故事"),
        FAVORITE_CHARACTER: (1.5, None, "喜欢的角色"),  # None = permanent
        SCARY_ELEMENT: (2.0, None, "回避的元素"),  # high weight, never expires
        VOCABULARY_GROWTH: (0.8, 90, "词汇积累"),
        EMOTIONAL_HIGHLIGHT: (1.2, 60, "情感高光"),
        READING_PREFERENCE: (1.0, None, "阅读偏好"),
        MILESTONE: (1.5, None, "里程碑事件"),
        SKILL_MASTERED: (1.0, 180, "掌握的技能"),
    }
@classmethod
def get_default_weight(cls, memory_type: str) -> float:
"""获取类型的默认权重。"""
config = cls.CONFIG.get(memory_type)
return config[0] if config else 1.0
@classmethod
def get_default_ttl(cls, memory_type: str) -> int | None:
"""获取类型的默认 TTL 天数。"""
config = cls.CONFIG.get(memory_type)
return config[1] if config else None
def _decay_factor(days: float) -> float:
"""计算时间衰减因子。"""
if days <= 7:
return 1.0
if days <= 30:
return 0.7
if days <= 90:
return 0.4
return 0.2
async def build_enhanced_memory_context(
profile_id: str | None,
universe_id: str | None,
db: AsyncSession,
) -> str | None:
"""构建增强版记忆上下文(自然语言 Prompt"""
if not profile_id and not universe_id:
return None
context_parts: list[str] = []
    # 1. Base profile (Identity Layer)
if profile_id:
profile = await db.scalar(select(ChildProfile).where(ChildProfile.id == profile_id))
if profile:
context_parts.append(f"【目标读者】\n姓名:{profile.name}")
if profile.age:
context_parts.append(f"年龄:{profile.age}")
if profile.interests:
context_parts.append(f"兴趣爱好:{''.join(profile.interests)}")
if profile.growth_themes:
context_parts.append(f"当前成长关注点:{''.join(profile.growth_themes)}")
context_parts.append("") # 空行
    # 2. Story universe (Universe Layer)
if universe_id:
universe = await db.scalar(select(StoryUniverse).where(StoryUniverse.id == universe_id))
if universe:
context_parts.append("【故事宇宙设定】")
context_parts.append(f"世界观:{universe.name}")
            # Protagonist
protagonist = universe.protagonist or {}
p_desc = f"{protagonist.get('name', '主角')} ({protagonist.get('personality', '')})"
context_parts.append(f"主角设定:{p_desc}")
            # Recurring characters
            if universe.recurring_characters:
                chars = [f"{c.get('name')} ({c.get('type')})" for c in universe.recurring_characters if isinstance(c, dict)]
                context_parts.append(f"已知伙伴:{'、'.join(chars)}")
            # Achievements
            if universe.achievements:
                badges = [str(a.get('type')) for a in universe.achievements if isinstance(a, dict)]
                if badges:
                    context_parts.append(f"已获荣誉:{'、'.join(badges[:5])}")
context_parts.append("")
    # 3. Dynamic memory (Working Memory)
if profile_id:
memories = await _fetch_scored_memories(profile_id, universe_id, db)
if memories:
memory_text = _format_memories_to_prompt(memories)
if memory_text:
context_parts.append("【关键记忆回忆】(请在故事中自然地融入或致敬以下元素)")
context_parts.append(memory_text)
return "\n".join(context_parts)
async def _fetch_scored_memories(
profile_id: str,
universe_id: str | None,
db: AsyncSession,
limit: int = 8
) -> list[MemoryItem]:
"""获取并评分记忆项,返回 Top N。"""
query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
if universe_id:
query = query.where(
(MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
)
    # Score the 50 most recent items
query = query.order_by(MemoryItem.last_used_at.desc(), MemoryItem.created_at.desc()).limit(50)
result = await db.execute(query)
items = result.scalars().all()
scored: list[tuple[float, MemoryItem]] = []
now = datetime.now(timezone.utc)
for item in items:
reference = item.last_used_at or item.created_at or now
delta_days = max((now - reference).total_seconds() / 86400, 0)
if item.ttl_days and delta_days > item.ttl_days:
continue
score = (item.base_weight or 1.0) * _decay_factor(delta_days)
        if score <= 0.1:  # drop low-scoring items
continue
scored.append((score, item))
scored.sort(key=lambda x: x[0], reverse=True)
return [item for _, item in scored[:limit]]
def _format_memories_to_prompt(memories: list[MemoryItem]) -> str:
"""将记忆项转换为自然语言指令。"""
lines = []
    # Group by type
recent_stories = []
favorites = []
scary = []
vocab = []
for m in memories:
if m.type == MemoryType.RECENT_STORY:
recent_stories.append(m)
elif m.type == MemoryType.FAVORITE_CHARACTER:
favorites.append(m)
elif m.type == MemoryType.SCARY_ELEMENT:
scary.append(m)
elif m.type == MemoryType.VOCABULARY_GROWTH:
vocab.append(m)
    # 1. Favorite characters
if favorites:
names = []
for m in favorites:
val = m.value
if isinstance(val, dict):
names.append(f"{val.get('name')} ({val.get('description', '')})")
if names:
lines.append(f"- 孩子特别喜欢这些角色,可以让他们客串出场:{', '.join(names)}")
    # 2. Elements to avoid
if scary:
items = []
for m in scary:
val = m.value
if isinstance(val, dict):
items.append(val.get('keyword', ''))
elif isinstance(val, str):
items.append(val)
if items:
lines.append(f"- 【注意禁止】不要出现以下让孩子害怕的元素:{', '.join(items)}")
    # 3. Recent stories (latest 2)
if recent_stories:
lines.append("- 近期经历(可作为彩蛋提及):")
for m in recent_stories[:2]:
val = m.value
if isinstance(val, dict):
title = val.get('title', '未知故事')
lines.append(f" * 之前读过《{title}")
    # 4. Vocabulary growth
if vocab:
words = []
for m in vocab:
val = m.value
if isinstance(val, dict):
words.append(val.get('word'))
if words:
lines.append(f"- 已掌握词汇(可适当复现以巩固):{', '.join([w for w in words if w])}")
return "\n".join(lines)
async def prune_expired_memories(db: AsyncSession) -> int:
"""清理过期的记忆项。
Returns:
删除的记录数量
"""
from sqlalchemy import delete
now = datetime.now(timezone.utc)
    # Find all items with a TTL set
stmt = select(MemoryItem).where(MemoryItem.ttl_days.is_not(None))
result = await db.execute(stmt)
candidates = result.scalars().all()
to_delete_ids = []
for item in candidates:
if not item.ttl_days:
continue
reference = item.last_used_at or item.created_at or now
delta_days = (now - reference).total_seconds() / 86400
if delta_days > item.ttl_days:
to_delete_ids.append(item.id)
if not to_delete_ids:
return 0
delete_stmt = delete(MemoryItem).where(MemoryItem.id.in_(to_delete_ids))
await db.execute(delete_stmt)
await db.commit()
logger.info("memory_pruned", count=len(to_delete_ids))
return len(to_delete_ids)
async def create_memory(
db: AsyncSession,
profile_id: str,
memory_type: str,
value: dict,
universe_id: str | None = None,
weight: float | None = None,
ttl_days: int | None = None,
) -> MemoryItem:
"""创建新的记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
memory_type: 记忆类型 (使用 MemoryType 常量)
value: 记忆内容 (JSON 格式)
universe_id: 可选,关联的故事宇宙 ID
weight: 可选,权重 (默认使用类型配置)
ttl_days: 可选,过期天数 (默认使用类型配置)
Returns:
创建的 MemoryItem
"""
memory = MemoryItem(
child_profile_id=profile_id,
universe_id=universe_id,
type=memory_type,
value=value,
base_weight=weight or MemoryType.get_default_weight(memory_type),
ttl_days=ttl_days if ttl_days is not None else MemoryType.get_default_ttl(memory_type),
)
db.add(memory)
await db.commit()
await db.refresh(memory)
logger.info(
"memory_created",
memory_id=memory.id,
profile_id=profile_id,
type=memory_type,
)
return memory
async def update_memory_usage(db: AsyncSession, memory_id: str) -> None:
"""更新记忆的最后使用时间。
Args:
db: 数据库会话
memory_id: 记忆项 ID
"""
result = await db.execute(select(MemoryItem).where(MemoryItem.id == memory_id))
memory = result.scalar_one_or_none()
if memory:
memory.last_used_at = datetime.now(timezone.utc)
await db.commit()
logger.debug("memory_usage_updated", memory_id=memory_id)
async def get_profile_memories(
db: AsyncSession,
profile_id: str,
memory_type: str | None = None,
universe_id: str | None = None,
limit: int = 50,
) -> list[MemoryItem]:
"""获取档案的记忆列表。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
memory_type: 可选,按类型筛选
universe_id: 可选,按宇宙筛选
limit: 返回数量限制
Returns:
MemoryItem 列表
"""
query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
if memory_type:
query = query.where(MemoryItem.type == memory_type)
if universe_id:
query = query.where(
(MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
)
query = query.order_by(MemoryItem.created_at.desc()).limit(limit)
result = await db.execute(query)
return list(result.scalars().all())
async def create_story_memory(
db: AsyncSession,
profile_id: str,
story_id: int,
title: str,
summary: str | None = None,
keywords: list[str] | None = None,
universe_id: str | None = None,
) -> MemoryItem:
"""为故事创建记忆项。
这是一个便捷函数,专门用于在故事阅读后创建 recent_story 类型的记忆。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
story_id: 故事 ID
title: 故事标题
summary: 故事梗概
keywords: 关键词列表
universe_id: 可选,关联的故事宇宙 ID
Returns:
创建的 MemoryItem
"""
value = {
"story_id": story_id,
"title": title,
"summary": summary or "",
"keywords": keywords or [],
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.RECENT_STORY,
value=value,
universe_id=universe_id,
)
async def create_character_memory(
db: AsyncSession,
profile_id: str,
name: str,
description: str | None = None,
source_story_id: int | None = None,
affinity_score: float = 1.0,
universe_id: str | None = None,
) -> MemoryItem:
"""Create a memory item for a favorite character.
Args:
db: database session
profile_id: child profile ID
name: character name
description: character description
source_story_id: source story ID
affinity_score: affinity level (0.0-1.0)
universe_id: optional, associated story universe ID
Returns:
The created MemoryItem
"""
value = {
"name": name,
"description": description or "",
"source_story_id": source_story_id,
"affinity_score": min(1.0, max(0.0, affinity_score)),
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.FAVORITE_CHARACTER,
value=value,
universe_id=universe_id,
)
async def create_scary_element_memory(
db: AsyncSession,
profile_id: str,
keyword: str,
category: str = "other",
source_story_id: int | None = None,
) -> MemoryItem:
"""Create a memory item for an element to avoid.
Args:
db: database session
profile_id: child profile ID
keyword: the keyword to avoid
category: category (creature/scene/action/other)
source_story_id: source story ID
Returns:
The created MemoryItem
"""
value = {
"keyword": keyword,
"category": category,
"source_story_id": source_story_id,
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.SCARY_ELEMENT,
value=value,
)


@@ -0,0 +1,31 @@
"""In-memory cache for providers loaded from DB."""
from collections import defaultdict
from typing import Literal
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.admin_models import Provider
ProviderType = Literal["text", "image", "tts", "storybook"]
_cache: dict[ProviderType, list[Provider]] = defaultdict(list)
async def reload_providers(db: AsyncSession):
result = await db.execute(select(Provider).where(Provider.enabled.is_(True)))
providers = result.scalars().all()
grouped: dict[ProviderType, list[Provider]] = defaultdict(list)
for p in providers:
grouped[p.type].append(p)
# sort by priority desc, then weight desc
for k in grouped:
grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
_cache.clear()
_cache.update(grouped)
return _cache
def get_providers(provider_type: ProviderType) -> list[Provider]:
return _cache.get(provider_type, [])


@@ -0,0 +1,248 @@
"""Provider metrics collection and health-check service."""
from datetime import datetime, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import ProviderHealth, ProviderMetrics
if TYPE_CHECKING:
from app.services.adapters.base import BaseAdapter
logger = get_logger(__name__)
# Circuit-breaker threshold: consecutive failures before tripping
CIRCUIT_BREAKER_THRESHOLD = 3
# Circuit-breaker recovery time (seconds)
CIRCUIT_BREAKER_RECOVERY_SECONDS = 60
class MetricsCollector:
"""Collects provider call metrics."""
async def record_call(
self,
db: AsyncSession,
provider_id: str,
success: bool,
latency_ms: int | None = None,
cost_usd: float | None = None,
error_message: str | None = None,
request_id: str | None = None,
) -> None:
"""Record a single API call."""
metric = ProviderMetrics(
provider_id=provider_id,
success=success,
latency_ms=latency_ms,
cost_usd=Decimal(str(cost_usd)) if cost_usd is not None else None,
error_message=error_message,
request_id=request_id,
)
db.add(metric)
await db.commit()
logger.debug(
"metrics_recorded",
provider_id=provider_id,
success=success,
latency_ms=latency_ms,
)
async def get_success_rate(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""Get the success rate within the given time window."""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(
func.count().filter(ProviderMetrics.success.is_(True)).label("success_count"),
func.count().label("total_count"),
).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
)
)
row = result.one()
success_count, total_count = row.success_count, row.total_count
if total_count == 0:
return 1.0  # assume healthy when there is no data
return success_count / total_count
async def get_avg_latency(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""Get the average latency (ms) within the given time window."""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(func.avg(ProviderMetrics.latency_ms)).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
ProviderMetrics.latency_ms.isnot(None),
)
)
avg = result.scalar()
return float(avg) if avg else 0.0
async def get_total_cost(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""Get the total cost (USD) within the given time window."""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(func.sum(ProviderMetrics.cost_usd)).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
)
)
total = result.scalar()
return float(total) if total else 0.0
class HealthChecker:
"""Provider health checker."""
async def check_provider(
self,
db: AsyncSession,
provider_id: str,
adapter: "BaseAdapter",
) -> bool:
"""Run a health check and update the stored status."""
try:
is_healthy = await adapter.health_check()
except Exception as e:
logger.warning("health_check_failed", provider_id=provider_id, error=str(e))
is_healthy = False
await self.update_health_status(
db,
provider_id,
is_healthy,
error=None if is_healthy else "Health check failed",
)
return is_healthy
async def update_health_status(
self,
db: AsyncSession,
provider_id: str,
is_healthy: bool,
error: str | None = None,
) -> None:
"""Update provider health status (includes circuit-breaker logic)."""
result = await db.execute(
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
)
health = result.scalar_one_or_none()
now = datetime.utcnow()
if health is None:
health = ProviderHealth(
provider_id=provider_id,
is_healthy=is_healthy,
last_check=now,
consecutive_failures=0 if is_healthy else 1,
last_error=error,
)
db.add(health)
else:
health.last_check = now
if is_healthy:
health.is_healthy = True
health.consecutive_failures = 0
health.last_error = None
else:
health.consecutive_failures += 1
health.last_error = error
# Circuit-breaker logic
if health.consecutive_failures >= CIRCUIT_BREAKER_THRESHOLD:
health.is_healthy = False
logger.warning(
"circuit_breaker_triggered",
provider_id=provider_id,
consecutive_failures=health.consecutive_failures,
)
await db.commit()
async def record_call_result(
self,
db: AsyncSession,
provider_id: str,
success: bool,
error: str | None = None,
) -> None:
"""Update health status based on a call result."""
await self.update_health_status(db, provider_id, success, error)
async def get_healthy_providers(
self,
db: AsyncSession,
provider_ids: list[str],
) -> list[str]:
"""Get the subset of providers that are healthy."""
if not provider_ids:
return []
# Query all recorded health statuses
result = await db.execute(
select(ProviderHealth.provider_id, ProviderHealth.is_healthy).where(
ProviderHealth.provider_id.in_(provider_ids),
)
)
health_map = {row[0]: row[1] for row in result.all()}
# Unrecorded providers default to healthy; recorded-but-unhealthy ones are excluded
return [
pid for pid in provider_ids
if pid not in health_map or health_map[pid]
]
async def is_healthy(
self,
db: AsyncSession,
provider_id: str,
) -> bool:
"""Check whether a provider is healthy."""
result = await db.execute(
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
)
health = result.scalar_one_or_none()
if health is None:
return True  # default to healthy when unrecorded
# Check whether the breaker can recover
if not health.is_healthy and health.last_check:
recovery_time = health.last_check + timedelta(seconds=CIRCUIT_BREAKER_RECOVERY_SECONDS)
if datetime.utcnow() >= recovery_time:
return True  # allow a retry
return health.is_healthy
# Global singletons
metrics_collector = MetricsCollector()
health_checker = HealthChecker()


@@ -0,0 +1,432 @@
"""Provider routing with failover - adapter-registry-based smart routing."""
import time
from enum import Enum
from typing import TYPE_CHECKING, Literal, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters import AdapterConfig, AdapterRegistry
from app.services.adapters.text.models import StoryOutput
from app.services.cost_tracker import cost_tracker
from app.services.provider_cache import get_providers
from app.services.provider_metrics import health_checker, metrics_collector
if TYPE_CHECKING:
from app.db.admin_models import Provider
logger = get_logger(__name__)
T = TypeVar("T")
ProviderType = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
"""Routing strategy enum."""
PRIORITY = "priority"  # sort by priority (default)
COST = "cost"  # sort by estimated cost
LATENCY = "latency"  # sort by latency
ROUND_ROBIN = "round_robin"  # round-robin
# Default provider mapping (used when the DB has no configuration)
# This is the code-level default strategy, matching an empty .env
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
"text": ["gemini", "openai"],
"image": ["cqtai"],
"tts": ["minimax", "elevenlabs", "edge_tts"],
"storybook": ["gemini"],
}
# API Key mapping: adapter_name -> settings attribute name
API_KEY_MAP: dict[str, str] = {
# Text
"gemini": "text_api_key",  # Gemini still reuses the text_api_key field
"text_primary": "text_api_key",  # legacy alias
"openai": "openai_api_key",
# Image
"cqtai": "cqtai_api_key",
"image_primary": "image_api_key",  # legacy alias
# TTS
"minimax": "minimax_api_key",
"elevenlabs": "elevenlabs_api_key",
"edge_tts": "tts_api_key",  # EdgeTTS reuses tts_api_key (usually empty)
"tts_primary": "tts_api_key",  # legacy alias
}
# Round-robin counters
_round_robin_counters: dict[ProviderType, int] = {
"text": 0,
"image": 0,
"tts": 0,
}
# Latency cache (in-memory, simplified implementation)
_latency_cache: dict[str, float] = {}
def _get_api_key(config_ref: str | None, adapter_name: str) -> str:
"""Resolve the API Key from config_ref or the adapter name."""
# Prefer config_ref
key_attr = API_KEY_MAP.get(config_ref or adapter_name, None)
if key_attr:
return getattr(settings, key_attr, "")
# Fall back to the adapter name
key_attr = API_KEY_MAP.get(adapter_name, None)
if key_attr:
return getattr(settings, key_attr, "")
return ""
def _get_default_config(adapter_name: str) -> AdapterConfig | None:
"""Get an adapter's default config (used when no DB record exists). Returns None for unknown adapters."""
# --- Text Defaults ---
if adapter_name in ("gemini", "text_primary"):
return AdapterConfig(
api_key=settings.text_api_key,
model=settings.text_model or "gemini-2.0-flash",
timeout_ms=60000,
)
if adapter_name == "openai":
return AdapterConfig(
api_key=getattr(settings, "openai_api_key", ""),
model="gpt-4o-mini",  # could also be read from settings if needed
timeout_ms=60000,
)
# --- Image Defaults ---
if adapter_name == "cqtai":
return AdapterConfig(
api_key=getattr(settings, "cqtai_api_key", ""),
model="nano-banana-pro",  # default to Pro
timeout_ms=120000,
)
if adapter_name == "image_primary":
# Legacy alias: anything still using image_primary should eventually be removed.
# Kept as a minimal fallback for now to avoid breaking callers.
return AdapterConfig(
api_key=settings.image_api_key,
timeout_ms=120000
)
# --- TTS Defaults ---
if adapter_name == "minimax":
# MiniMax also needs a group_id, but AdapterConfig has no such field and
# _build_config_from_provider cannot pass extra kwargs to Adapter.__init__,
# so the adapter is expected to read group_id from settings itself;
# only the base config is returned here.
return AdapterConfig(
api_key=getattr(settings, "minimax_api_key", ""),
model="speech-2.6-turbo",
timeout_ms=60000,
)
if adapter_name == "elevenlabs":
return AdapterConfig(
api_key=getattr(settings, "elevenlabs_api_key", ""),
timeout_ms=120000,
)
if adapter_name in ("edge_tts", "tts_primary"):
return AdapterConfig(
api_key=settings.tts_api_key,
api_base=settings.tts_api_base,
model=settings.tts_model or "zh-CN-XiaoxiaoNeural",
timeout_ms=120000,
)
# --- Others ---
if adapter_name in ("storybook_primary", "storybook_gemini"):
return AdapterConfig(
api_key=settings.text_api_key,  # reuse the Gemini key
model=settings.text_model,
timeout_ms=120000,
)
# Unknown adapter: return None
return None
def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
"""Build an AdapterConfig from a DB Provider record."""
api_key = getattr(provider, "api_key", None) or ""
if not api_key:
api_key = _get_api_key(provider.config_ref, provider.adapter)
default = _get_default_config(provider.adapter)
if default is None:
default = AdapterConfig(api_key="", timeout_ms=60000)
return AdapterConfig(
api_key=api_key or default.api_key,
api_base=provider.api_base or default.api_base,
model=provider.model or default.model,
timeout_ms=provider.timeout_ms or default.timeout_ms,
max_retries=provider.max_retries or default.max_retries,
extra_config=provider.config_json or {},
)
def _get_providers_with_config(
provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
"""Get the provider list together with configs.
Returns:
[(adapter_name, config, provider_or_none), ...] sorted by priority
"""
db_providers = get_providers(provider_type)
if db_providers:
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
settings_map = {
"text": settings.text_providers,
"image": settings.image_providers,
"tts": settings.tts_providers,
}
names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]
result = []
for name in names:
config = _get_default_config(name)
if config is None:
logger.warning("unknown_adapter_skipped", adapter=name, provider_type=provider_type)
continue
result.append((name, config, None))
return result
def _sort_by_strategy(
providers: list[tuple[str, AdapterConfig, "Provider | None"]],
strategy: RoutingStrategy,
provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
"""Sort the provider list by strategy."""
if strategy == RoutingStrategy.PRIORITY:
# priority desc, then weight desc
return sorted(
providers,
key=lambda x: (-(x[2].priority if x[2] else 0), -(x[2].weight if x[2] else 1)),
)
if strategy == RoutingStrategy.COST:
# estimated cost ascending
def get_cost(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
adapter_class = AdapterRegistry.get(provider_type, item[0])
if adapter_class:
try:
adapter = adapter_class(item[1])
return adapter.estimated_cost
except Exception:
pass
return float("inf")
return sorted(providers, key=get_cost)
if strategy == RoutingStrategy.LATENCY:
# historical latency ascending
def get_latency(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
return _latency_cache.get(item[0], float("inf"))
return sorted(providers, key=get_latency)
if strategy == RoutingStrategy.ROUND_ROBIN:
# round robin: rotate the list
counter = _round_robin_counters[provider_type]
_round_robin_counters[provider_type] = (counter + 1) % max(len(providers), 1)
return providers[counter:] + providers[:counter]
return providers
async def _route_with_failover(
provider_type: ProviderType,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
user_id: str | None = None,
**kwargs,
) -> T:
"""Generic provider failover routing.
Args:
provider_type: provider type (text/image/tts/storybook)
strategy: routing strategy
db: database session (optional; used for metrics and circuit-breaker checks)
user_id: user ID (optional; used for cost tracking and budget checks)
**kwargs: arguments forwarded to the adapter
"""
providers = _get_providers_with_config(provider_type)
if not providers:
raise ValueError(f"No {provider_type} providers configured.")
# Sort by strategy
sorted_providers = _sort_by_strategy(providers, strategy, provider_type)
# With a db session, filter out circuit-broken providers
if db:
healthy_providers = []
for item in sorted_providers:
name, config, db_provider = item
provider_id = db_provider.id if db_provider else name
if await health_checker.is_healthy(db, provider_id):
healthy_providers.append(item)
else:
logger.debug("provider_circuit_open", adapter=name, provider_id=provider_id)
# If every provider is circuit-broken, still try the first one (allows recovery)
if not healthy_providers:
healthy_providers = sorted_providers[:1]
sorted_providers = healthy_providers
errors: list[str] = []
for name, config, db_provider in sorted_providers:
adapter_class = AdapterRegistry.get(provider_type, name)
if not adapter_class:
errors.append(f"{name}: adapter not registered")
continue
provider_id = db_provider.id if db_provider else name
try:
logger.debug(
"provider_attempt",
provider_type=provider_type,
adapter=name,
strategy=strategy.value,
)
adapter = adapter_class(config)
# Execute and time the call
start_time = time.time()
result = await adapter.execute(**kwargs)
latency_ms = int((time.time() - start_time) * 1000)
# Update the latency cache
_latency_cache[name] = latency_ms
# Record success metrics
if db:
await metrics_collector.record_call(
db,
provider_id=provider_id,
success=True,
latency_ms=latency_ms,
cost_usd=adapter.estimated_cost,
)
await health_checker.record_call_result(db, provider_id, success=True)
# Record user cost
if user_id:
await cost_tracker.record_cost(
db,
user_id=user_id,
provider_name=name,
capability=provider_type,
estimated_cost=adapter.estimated_cost,
provider_id=provider_id if db_provider else None,
)
logger.info(
"provider_success",
provider_type=provider_type,
adapter=name,
latency_ms=latency_ms,
)
return result
except Exception as exc:
error_msg = str(exc)
logger.warning(
"provider_failed",
provider_type=provider_type,
adapter=name,
error=error_msg,
)
errors.append(f"{name}: {exc}")
# Record failure metrics
if db:
await metrics_collector.record_call(
db,
provider_id=provider_id,
success=False,
error_message=error_msg,
)
await health_checker.record_call_result(
db, provider_id, success=False, error=error_msg
)
raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
async def generate_story_content(
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
) -> StoryOutput:
"""Generate or polish a story, with failover."""
return await _route_with_failover(
"text",
strategy=strategy,
db=db,
input_type=input_type,
data=data,
education_theme=education_theme,
memory_context=memory_context,
)
async def generate_image(
prompt: str,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
**kwargs,
) -> str:
"""Generate an image and return its URL, with failover."""
return await _route_with_failover("image", strategy=strategy, db=db, prompt=prompt, **kwargs)
async def text_to_speech(
text: str,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
) -> bytes:
"""Convert text to speech and return MP3 bytes, with failover."""
return await _route_with_failover("tts", strategy=strategy, db=db, text=text)
async def generate_storybook(
keywords: str,
page_count: int = 6,
education_theme: str | None = None,
memory_context: str | None = None,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
):
"""Generate a paginated storybook, with failover."""
from app.services.adapters.storybook.primary import Storybook
result: Storybook = await _route_with_failover(
"storybook",
strategy=strategy,
db=db,
keywords=keywords,
page_count=page_count,
education_theme=education_theme,
memory_context=memory_context,
)
return result


@@ -0,0 +1,207 @@
"""Provider secret encrypted-storage service.
Uses Fernet symmetric encryption with a key derived from SECRET_KEY.
"""
import base64
import hashlib
from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.admin_models import ProviderSecret
if TYPE_CHECKING:
pass
logger = get_logger(__name__)
class SecretEncryptionError(Exception):
"""Secret encryption/decryption error."""
pass
class SecretService:
"""Provider secret encrypted-storage service."""
_fernet: Fernet | None = None
@classmethod
def _get_fernet(cls) -> Fernet:
"""Get the Fernet instance, deriving the encryption key from SECRET_KEY."""
if cls._fernet is None:
# Derive a 32-byte key from SECRET_KEY
key_bytes = hashlib.sha256(settings.secret_key.encode()).digest()
fernet_key = base64.urlsafe_b64encode(key_bytes)
cls._fernet = Fernet(fernet_key)
return cls._fernet
@classmethod
def encrypt(cls, plaintext: str) -> str:
"""Encrypt plaintext and return base64-encoded ciphertext.
Args:
plaintext: the plaintext to encrypt
Returns:
base64-encoded ciphertext
"""
if not plaintext:
return ""
fernet = cls._get_fernet()
encrypted = fernet.encrypt(plaintext.encode())
return encrypted.decode()
@classmethod
def decrypt(cls, ciphertext: str) -> str:
"""Decrypt ciphertext and return the plaintext.
Args:
ciphertext: base64-encoded ciphertext
Returns:
the decrypted plaintext
Raises:
SecretEncryptionError: decryption failed
"""
if not ciphertext:
return ""
try:
fernet = cls._get_fernet()
decrypted = fernet.decrypt(ciphertext.encode())
return decrypted.decode()
except InvalidToken as e:
logger.error("secret_decrypt_failed", error=str(e))
raise SecretEncryptionError("Failed to decrypt secret; SECRET_KEY may have changed") from e
@classmethod
async def get_secret(cls, db: AsyncSession, name: str) -> str | None:
"""Fetch and decrypt a secret from the database.
Args:
db: database session
name: secret name
Returns:
the decrypted secret value, or None if it does not exist
"""
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
return None
return cls.decrypt(secret.encrypted_value)
@classmethod
async def set_secret(cls, db: AsyncSession, name: str, value: str) -> ProviderSecret:
"""Store or update an encrypted secret.
Args:
db: database session
name: secret name
value: plaintext secret value
Returns:
the ProviderSecret instance
"""
encrypted = cls.encrypt(value)
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
secret = ProviderSecret(name=name, encrypted_value=encrypted)
db.add(secret)
else:
secret.encrypted_value = encrypted
await db.commit()
await db.refresh(secret)
logger.info("secret_stored", name=name)
return secret
@classmethod
async def delete_secret(cls, db: AsyncSession, name: str) -> bool:
"""Delete a secret.
Args:
db: database session
name: secret name
Returns:
whether the deletion succeeded
"""
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
return False
await db.delete(secret)
await db.commit()
logger.info("secret_deleted", name=name)
return True
@classmethod
async def list_secrets(cls, db: AsyncSession) -> list[str]:
"""List all secret names (values are never returned).
Args:
db: database session
Returns:
list of secret names
"""
result = await db.execute(select(ProviderSecret.name))
return [row[0] for row in result.fetchall()]
@classmethod
async def get_api_key(
cls,
db: AsyncSession,
provider_api_key: str | None,
config_ref: str | None,
) -> str | None:
"""Get a Provider's API Key, resolved in priority order.
Priority:
1. provider.api_key (plaintext or encrypted, from the database)
2. the ProviderSecret referenced by provider.config_ref
3. environment variables (config_ref as the variable name)
Args:
db: database session
provider_api_key: the api_key field on the Provider row
config_ref: the config_ref field on the Provider row
Returns:
the API Key, or None
"""
# 1. Use provider.api_key directly
if provider_api_key:
# Try to decrypt; on failure treat it as plaintext
try:
decrypted = cls.decrypt(provider_api_key)
if decrypted:
return decrypted
except SecretEncryptionError:
pass
return provider_api_key
# 2. Look up the ProviderSecret table
if config_ref:
secret_value = await cls.get_secret(db, config_ref)
if secret_value:
return secret_value
# 3. Look up environment variables
env_value = getattr(settings, config_ref.lower(), None)
if env_value:
return env_value
return None


@@ -0,0 +1,3 @@
"""Celery tasks package."""
from . import achievements, memory, push_notifications # noqa: F401


@@ -0,0 +1,82 @@
"""Celery tasks for achievements."""
import asyncio
from datetime import datetime
from sqlalchemy import select
from app.core.celery_app import celery_app
from app.core.logging import get_logger
from app.db.database import _get_session_factory
from app.db.models import Story, StoryUniverse
from app.services.achievement_extractor import extract_achievements
logger = get_logger(__name__)
@celery_app.task
def extract_story_achievements(story_id: int, universe_id: str) -> None:
"""Extract achievements and update universe."""
asyncio.run(_extract_story_achievements(story_id, universe_id))
async def _extract_story_achievements(story_id: int, universe_id: str) -> None:
session_factory = _get_session_factory()
async with session_factory() as session:
result = await session.execute(select(Story).where(Story.id == story_id))
story = result.scalar_one_or_none()
if not story:
logger.warning("achievement_task_story_missing", story_id=story_id)
return
result = await session.execute(
select(StoryUniverse).where(StoryUniverse.id == universe_id)
)
universe = result.scalar_one_or_none()
if not universe:
logger.warning("achievement_task_universe_missing", universe_id=universe_id)
return
text_content = story.story_text
if not text_content and story.pages:
# For picture books, concatenate per-page text
text_content = "\n".join([str(p.get("text", "")) for p in story.pages])
if not text_content:
logger.warning("achievement_task_empty_content", story_id=story_id)
return
achievements = await extract_achievements(text_content)
if not achievements:
logger.info("achievement_task_no_new", story_id=story_id)
return
existing = {
(str(item.get("type", "")).strip(), str(item.get("description", "")).strip())
for item in (universe.achievements or [])
if isinstance(item, dict)
}
merged = list(universe.achievements or [])
added_count = 0
for item in achievements:
key = (item.get("type", "").strip(), item.get("description", "").strip())
if key in existing:
continue
merged.append({
"type": key[0],
"description": key[1],
"obtained_at": datetime.now().isoformat(),
"source_story_id": story_id,
})
existing.add(key)
added_count += 1
universe.achievements = merged
await session.commit()
logger.info(
"achievement_task_success",
story_id=story_id,
universe_id=universe_id,
added=added_count,
)


@@ -0,0 +1,29 @@
import asyncio
from app.core.celery_app import celery_app
from app.core.logging import get_logger
from app.db.database import _get_session_factory
from app.services.memory_service import prune_expired_memories
logger = get_logger(__name__)
@celery_app.task
def prune_memories_task():
"""Daily task to prune expired memories."""
logger.info("prune_memories_task_started")
async def _run():
# Ensure engine is initialized in this process
session_factory = _get_session_factory()
async with session_factory() as session:
return await prune_expired_memories(session)
try:
# Create a new event loop for this task execution
count = asyncio.run(_run())
logger.info("prune_memories_task_completed", deleted_count=count)
return f"Deleted {count} expired memories"
except Exception as exc:
logger.error("prune_memories_task_failed", error=str(exc))
raise


@@ -0,0 +1,108 @@
"""Celery tasks for push notifications."""
import asyncio
from datetime import datetime, time
from zoneinfo import ZoneInfo
from sqlalchemy import select
from app.core.celery_app import celery_app
from app.core.logging import get_logger
from app.db.database import _get_session_factory
from app.db.models import PushConfig, PushEvent
logger = get_logger(__name__)
LOCAL_TZ = ZoneInfo("Asia/Shanghai")
QUIET_HOURS_START = time(21, 0)
QUIET_HOURS_END = time(9, 0)
TRIGGER_WINDOW_MINUTES = 30
@celery_app.task
def check_push_notifications() -> None:
"""Check push configs and create push events."""
asyncio.run(_check_push_notifications())
def _is_quiet_hours(current: time) -> bool:
if QUIET_HOURS_START < QUIET_HOURS_END:
return QUIET_HOURS_START <= current < QUIET_HOURS_END
return current >= QUIET_HOURS_START or current < QUIET_HOURS_END
def _within_window(current: time, target: time) -> bool:
current_minutes = current.hour * 60 + current.minute
target_minutes = target.hour * 60 + target.minute
return 0 <= current_minutes - target_minutes < TRIGGER_WINDOW_MINUTES
async def _already_sent_today(
session,
child_profile_id: str,
now: datetime,
) -> bool:
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
end = now.replace(hour=23, minute=59, second=59, microsecond=999999)
result = await session.execute(
select(PushEvent.id).where(
PushEvent.child_profile_id == child_profile_id,
PushEvent.status == "sent",
PushEvent.sent_at >= start,
PushEvent.sent_at <= end,
)
)
return result.scalar_one_or_none() is not None
async def _check_push_notifications() -> None:
session_factory = _get_session_factory()
now = datetime.now(LOCAL_TZ)
current_day = now.weekday()
current_time = now.time()
async with session_factory() as session:
result = await session.execute(
select(PushConfig).where(PushConfig.enabled.is_(True))
)
configs = result.scalars().all()
for config in configs:
if not config.push_time:
continue
if config.push_days and current_day not in config.push_days:
continue
if not _within_window(current_time, config.push_time):
continue
if _is_quiet_hours(current_time):
session.add(
PushEvent(
user_id=config.user_id,
child_profile_id=config.child_profile_id,
trigger_type="time",
status="suppressed",
reason="quiet_hours",
sent_at=now,
)
)
continue
if await _already_sent_today(session, config.child_profile_id, now):
continue
session.add(
PushEvent(
user_id=config.user_id,
child_profile_id=config.child_profile_id,
trigger_type="time",
status="sent",
reason=None,
sent_at=now,
)
)
logger.info(
"push_event_sent",
child_profile_id=config.child_profile_id,
user_id=config.user_id,
)
await session.commit()


@@ -0,0 +1,14 @@
# Code Review Report (2nd follow-up)
## What's fixed
- Provider cache now loads on startup via lifespan (`app/main.py`), so DB providers are honored without manual reload.
- Providers support DB-stored `api_key` precedence (`provider_router.py:77-104`) and Provider model added `api_key` column (`db/admin_models.py:25`).
- Frontend uses `/api/generate/full` and propagates image-failure warning to detail via query flag; StoryDetail displays banner when image generation failed.
- Tests added for full generation, provider failover, config-from-DB, and startup cache load.
## Remaining issue
1) **Missing DB migration for new Provider.api_key column**
- The model was updated (`backend/app/db/admin_models.py:25`), but `backend/alembic/versions/0001_init_providers_and_story_mode.py` lacks this column. Existing databases will not have `api_key`, causing runtime errors on access or insert. Add an Alembic migration that adds/drops `api_key` on the `providers` table, and update sample data if needed.
## Suggested action
- Create and apply an Alembic migration adding `api_key` (String, nullable) to `providers`. After migration, verify admin CRUD works with the new field.
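A sketch of such a migration is below. Hedged: the revision identifiers and column definition are placeholders and must be aligned with the project's actual Alembic history before use.

```python
"""Add api_key column to providers (sketch)."""
from alembic import op
import sqlalchemy as sa

# Placeholder revision ids -- must match the real migration chain.
revision = "0002_add_provider_api_key"
down_revision = "0001_init_providers_and_story_mode"

def upgrade() -> None:
    # Nullable so existing rows remain valid without a backfill.
    op.add_column("providers", sa.Column("api_key", sa.String(), nullable=True))

def downgrade() -> None:
    op.drop_column("providers", "api_key")
```

After applying it, re-verify the admin CRUD flows against the new column.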


@@ -0,0 +1,147 @@
# Memory System Development Guide
This document details the technical implementation of the memory system defined in the PRD.
## 1. Database Schema Changes
The current `MemoryItem` table structure is workable, but its fields need enriching to support emotional signals and metadata.
### 1.1 `MemoryItem` Table Improvements
Use an Alembic migration to add the following fields, or standardize the structure below inside the `value` JSON:
```python
# Recommendation: declare these fields explicitly in models.py, or strictly define the schema of the value field
class MemoryItem(Base):
# ... existing fields ...
# Suggested new/normalized fields
# JSON structure convention for the value field:
# {
#   "content": "The little rabbit defeated the big bad wolf",  # core memory text
#   "keywords": ["brave", "forest"],  # retrieval keywords
#   "emotion": "positive",  # sentiment: positive/negative/neutral
#   "source_story_id": 123,  # source story ID
#   "confidence": 0.85  # memory confidence (when AI-extracted)
# }
```
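To keep extraction code and prompt builders in agreement on these keys, the schema could be pinned down with a `TypedDict` (a sketch following the comment block above; `MemoryValue` and the validator are illustrative names, not existing code):

```python
from typing import Literal, TypedDict

class MemoryValue(TypedDict, total=False):
    content: str                # core memory text
    keywords: list[str]         # retrieval keywords
    emotion: Literal["positive", "negative", "neutral"]
    source_story_id: int
    confidence: float           # only present for AI-extracted memories

def is_valid_memory_value(value: dict) -> bool:
    """Minimal runtime check for the keys the prompt builder relies on."""
    return isinstance(value.get("content"), str) and isinstance(value.get("keywords"), list)
```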
### 1.2 `StoryUniverse` Table Improvements (Achievement System)
Achievements currently live in the `achievements` JSON field. To support richer queries (e.g. "all users who earned the courage badge"), either refactor them into a separate association table, or keep the JSON but standardize its structure.
**Current JSON structure convention**:
```json
[
{
"id": "badge_courage_01",
"type": "courage",
"name": "Little Hero",
"description": "Faced a difficulty alone in a story for the first time",
"icon_url": "badges/courage.png",
"obtained_at": "2023-10-27T10:00:00Z",
"source_story_id": 45
}
]
```
---
## 2. Core Logic Implementation
### 2.1 Memory Injection (Prompt Engineering)
Modify the `_build_memory_context` function in `backend/app/api/stories.py`.
**Goal**: produce a natural prompt fragment, not just a list of data.
**Pseudocode**:
```python
def format_memory_for_prompt(memories: list[MemoryItem]) -> str:
"""
Convert memory items into a natural-language prompt fragment.
"""
context_parts = []
# 1. Character memories
chars = [m for m in memories if m.type == 'favorite_character']
if chars:
names = ", ".join([c.value['name'] for c in chars])
context_parts.append(f"Characters the child especially likes: {names}. Try to give them cameo appearances.")
# 2. Recent plots
recent_stories = [m for m in memories if m.type == 'recent_story'][:2]
if recent_stories:
for story in recent_stories:
context_parts.append(f"Recently: {story.value['summary']}. This can be mentioned casually in dialogue.")
# 3. No-go zone (negative memories)
scary = [m for m in memories if m.type == 'scary_element']
if scary:
items = ", ".join([s.value['keyword'] for s in scary])
context_parts.append(f"[STRICTLY FORBIDDEN] Do not include the following elements that scare the child: {items}")
return "\n".join(context_parts)
```
### 2.2 Achievement Extraction and Notification Flow
The current flow lives in `app/tasks/achievements.py`. The loop needs to be closed.
**Improved flow**:
1. **Story Generation**: the story is generated successfully and stored in the database.
2. **Async Task**: the Celery task `extract_story_achievements` is triggered.
3. **LLM Analysis**: Gemini analyzes the story and extracts achievements.
4. **Update DB**: `StoryUniverse.achievements` is updated.
5. **Notification (new)**:
* Create a `Notification` / `UserMessage` record (requires a new table, or Redis Pub/Sub).
* The frontend polls or receives notifications via SSE (Server-Sent Events): "🎉 Congratulations! In this story, Xiaoming earned the [Honesty Badge]".
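The notification step could start as a tiny helper that serializes one SSE frame, before any notification table exists (a sketch; the event name `achievement` and the payload keys are assumptions, not existing code):

```python
import json

def achievement_sse_event(child_name: str, badge_name: str, story_id: int) -> str:
    """Serialize a single Server-Sent Events frame for an achievement notification."""
    data = json.dumps(
        {
            "type": "achievement",
            "message": f"🎉 Congratulations! In this story, {child_name} earned the [{badge_name}]",
            "source_story_id": story_id,
        },
        ensure_ascii=False,
    )
    # SSE frames are "event:"/"data:" lines terminated by a blank line.
    return f"event: achievement\ndata: {data}\n\n"
```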
### 2.3 Memory Pruning and Decay (Maintenance)
A scheduled background job (cron) must prune stale memories so the prompt does not grow unboundedly.
* **Frequency**: once per day.
* **Logic**:
* Delete records whose `ttl_days` has expired.
* Decay the `base_weight` of `recent_story` memories: either a daily update, or keep a static value in the database and compute the decayed weight at read time (read-time computation is recommended to reduce writes).
* When a profile's `MemoryItem` count exceeds 100, trigger a "memory summarization" task that merges several old memories into a single long-term impression.
---
## 3. API Endpoints
### 3.1 Growth Timeline
`GET /api/profiles/{id}/timeline`
**Response**:
```json
{
  "events": [
    {
      "date": "2023-10-01",
      "type": "milestone",
      "title": "First Encounter",
      "description": "Created the character [Xiaoming]"
    },
    {
      "date": "2023-10-05",
      "type": "story",
      "title": "Xiaoming and the Magic Tree",
      "image_url": "..."
    },
    {
      "date": "2023-10-05",
      "type": "achievement",
      "badge": {
        "name": "Curious Explorer",
        "icon": "..."
      }
    }
  ]
}
```
### 3.2 Memory Feedback (Manual Intervention)
`POST /api/memories/{id}/feedback`
Lets parents manually delete or correct inaccurate memories.
* **Action**: `delete` | `reinforce` (increase the memory's weight)


@@ -0,0 +1,93 @@
# DreamWeaver (梦语织机) Memory System Upgrade PRD
> Version: v1.0 | Status: Planning | Priority: High
## 1. Core Vision
Upgrade today's plain "data store" into a warm **emotional connection system**.
We are not just remembering data; we are **maintaining the child's relationship with the story world**, so that each story is no longer an isolated fragment but a brick in the child's own "story universe".
---
## 2. Pain Points and Solutions
| Role | Pain point | Solution | Expected value |
|---------|---------|---------|---------|
| **Child** | "Why doesn't the little rabbit from last time remember me?" <br> Stories lack continuity; each is a one-off experience. | **Character consistency and memory injection** <br> Stories open by recalling past events; character personalities persist. | Builds emotional attachment and immersion. |
| **Parent** | "What does this app do besides generate stories?" <br> The long-term educational value is invisible. | **Visible growth trajectory** <br> Vocabulary stats, theme shifts, and achievement badges made visible. | Raises willingness to pay; provides social currency. |
| **Platform** | Users leave as soon as they are done; no retention moat. | **Sunk cost and emotional assets** <br> The more memories accumulate, the harder it is to leave. | Improves long-term retention (LTV). |
---
## 3. Feature Architecture: A Layered Memory Model
### 3.1 Layer 1: Core Profile (Identity Layer)
*Nature: permanent, static, explicit*
* **Data**: name, age, gender.
* **Input**: entered manually by the parent during onboarding.
* **Role**: sets the baseline age-appropriateness and how the child is addressed.
### 3.2 Layer 2: Story Universe (Universe Layer)
*Nature: long-term, accumulative, semi-explicit*
* **Protagonist profile**: name, personality traits (brave/shy), appearance (glasses/curly hair).
* **Recurring sidekicks**: fixed companions that emerge from one-off stories (e.g. "Qiqi, the carrot-loving squirrel").
* **World setting**: where the stories take place (magic forest, future city, undersea world).
* **Achievement system**: virtual rewards the child earns (Courage Badge, Little Explorer).
### 3.3 Layer 3: Working Memory
*Nature: short-term, auto-decaying, implicit*
* **Key plots**: endings and central conflicts of the last 3 stories.
* **Emotional markers**: the child's reactions to specific content (inferred from "replay" and "skip" actions).
* **New vocabulary**: advanced words encountered in stories.
---
## 4. Feature Specs
### 4.1 Smart Openings (Memory Injection)
When generating a new story, the prompt must include a "memory recall" instruction.
* **Example**: "Xiaoming, remember how we helped the little squirrel find its pinecone last week? Today the squirrel brings a new friend..."
* **Strategy**: inject the top-3 memories by weight into the prompt.
### 4.2 Growth Timeline
A visual H5 page or app module that presents milestones on a timeline.
* **Node types**:
    * 🌟 **First Encounter**: the day the character was created.
    * 📖 **Reading Streaks**: 10/50/100 stories read.
    * 🏅 **Achievements**: e.g. earning the "Honesty Badge".
    * 🧠 **Ability Unlocks**: first "sci-fi" story read.
### 4.3 Achievement Ceremony
* **Trigger**: a new achievement is detected after story generation and analysis.
* **Presentation**: pop-up animation + sound effect + "Congratulations, you earned the [Courage] badge!".
* **Sharing**: allow generating an achievement poster with a QR code.
---
## 5. Memory Types
| Type key | Description | Source | Expiry policy |
|---------|------|------|---------|
| `recent_story` | Summary of a recently read story | Reading events | 30-day decay |
| `favorite_character` | Characters the child likes | Replays / high ratings | Long-lived |
| `scary_element` | Elements the child fears or dislikes | Skips / negative feedback | Long-lived (blocklist) |
| `vocabulary_growth` | Newly mastered vocabulary | Story analysis | 90-day decay |
| `emotional_highlight` | Highlight moments (e.g. especially joyful plots) | Interaction data | 60-day decay |
---
## 6. Roadmap
### Phase 1: Foundations (v0.3.0)
* [x] `MemoryItem` database table (already exists).
* [ ] Extend the `MemoryItem` type field to cover more dimensions.
* [ ] Improve `_build_memory_context` for more natural prompt injection.
* [ ] Frontend: a simple "recent memories" list.
### Phase 2: Visualization and Achievements (v0.4.0)
* [ ] Close the notification loop on the Achievement Extractor.
* [ ] Frontend: build the "My Achievements" and "Growth Timeline" pages.
* [ ] Add dynamic story-opening generation.
### Phase 3: Deeper Intelligence (v0.5.0+)
* [ ] Introduce a vector database for semantic memory retrieval (not just most-recent).
* [ ] Sentiment model: infer emotional tendencies from user behavior.


@@ -0,0 +1,246 @@
# Provider System Development Notes
## Current Version (v0.2.0)
### Completed Features
1. **CQTAI nano image adapter** (`app/services/adapters/image/cqtai.py`)
   - Async submission + polling for results
   - Supports the nano-banana / nano-banana-pro models
   - Multiple resolutions and aspect ratios
   - Image-to-image support (filesUrl)
2. **Encrypted secret storage** (`app/services/secret_service.py`)
   - Fernet symmetric encryption, with the key derived from SECRET_KEY
   - Provider API keys are encrypted at rest automatically
   - Secret management API (CRUD)
3. **Metrics collection** (`app/services/provider_metrics.py`)
   - Success rate, latency, and cost statistics
   - Time-window aggregation queries
   - Integrated into provider_router
4. **Circuit breaker** (`app/services/provider_metrics.py`)
   - Opens after 3 consecutive failures
   - Retries automatically after 60 seconds
   - Health state persisted to the database
5. **Admin console frontend** (`app/admin_app.py`)
   - Deployed on a separate port (8001)
   - Vue 3 + Tailwind CSS single-page app
   - Provider CRUD management
   - Secret management UI
   - Basic Auth authentication
### Configuration
```bash
# Start the main application
uvicorn app.main:app --port 8000
# Start the admin console (separate port)
uvicorn app.admin_app:app --port 8001
```
Environment variables:
```
CQTAI_API_KEY=your-cqtai-api-key
ENABLE_ADMIN_CONSOLE=true
ADMIN_USERNAME=admin
ADMIN_PASSWORD=your-secure-password
```
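A note on the "key derived from SECRET_KEY" scheme under Completed Features: Fernet requires a 32-byte urlsafe-base64 key, so the derivation presumably looks like the sketch below. SHA-256 is an assumption about what `secret_service.py` actually does; a production KDF such as HKDF or PBKDF2 would also work:

```python
import base64
import hashlib


def derive_fernet_key(secret_key: str) -> bytes:
    """Hash the app SECRET_KEY down to 32 bytes, then urlsafe-base64 encode,
    which is the format cryptography.fernet.Fernet(key) expects."""
    digest = hashlib.sha256(secret_key.encode("utf-8")).digest()  # 32 bytes
    return base64.urlsafe_b64encode(digest)  # 44-character key
```

The derived key is stable for a given `SECRET_KEY`, so rotating `SECRET_KEY` invalidates every stored ciphertext; that trade-off is worth documenting.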
---
## Next Version Plan (v0.3.0)
### High Priority
#### 1. Smart Load Shedding (Option B)
**Goal**: automatically divert traffic to backup channels when the primary channel is under load
**Approach**:
- Monitored metrics: concurrency, response latency, error rate
- Shedding thresholds:
```python
class LoadBalanceConfig:
    max_concurrent: int = 10      # divert when concurrency exceeds this
    max_latency_ms: int = 5000    # divert when latency exceeds this
    max_error_rate: float = 0.1   # divert when the error rate exceeds 10%
```
- Strategy: weighted round-robin, with weights adjusted dynamically by provider health
**Files involved**:
- `app/services/provider_router.py` - add the load-balancing logic
- `app/services/provider_metrics.py` - add a concurrency counter
- `app/db/admin_models.py` - add a LoadBalanceConfig model
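The shedding decision itself reduces to a per-metric threshold check. A sketch reusing the `LoadBalanceConfig` fields above (everything beyond those fields is an assumption):

```python
from dataclasses import dataclass


@dataclass
class LoadBalanceConfig:
    max_concurrent: int = 10
    max_latency_ms: int = 5000
    max_error_rate: float = 0.1


def should_divert(cfg: LoadBalanceConfig, concurrent: int,
                  p95_latency_ms: float, error_rate: float) -> bool:
    """Divert to a backup provider when any threshold is breached."""
    return (
        concurrent > cfg.max_concurrent
        or p95_latency_ms > cfg.max_latency_ms
        or error_rate > cfg.max_error_rate
    )
```

For the weighted round-robin leg, letting each healthy provider carry `weight * health_score` and picking with `random.choices(providers, weights=...)` is often enough before reaching for a full smooth-weighted algorithm.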
#### 2. Storybook Adapter
**Goal**: generate a paginated, flippable storybook
**Approach**:
- Follow the Gemini AI Story Generator format
- Output structure:
```python
class StorybookPage:
    page_number: int
    text: str
    image_prompt: str
    image_url: str | None

class Storybook:
    title: str
    pages: list[StorybookPage]
    cover_url: str | None
```
- Integrate the text + image generation pipeline
**Files involved**:
- `app/services/adapters/storybook/` - new directory
- `app/api/stories.py` - add a storybook generation endpoint
### Medium Priority
#### 3. Cost Tracking
**Goal**: record actual spend and support budget controls
**Approach**:
- Cost record table:
```python
class CostRecord:
    user_id: str
    provider_id: str
    capability: str  # text/image/tts
    estimated_cost: Decimal
    actual_cost: Decimal | None
    timestamp: datetime
```
- Budget configuration:
```python
class BudgetConfig:
    user_id: str
    daily_limit: Decimal
    monthly_limit: Decimal
    alert_threshold: float = 0.8  # alert at 80%
```
- Over-budget handling: reject the request / downgrade to a cheaper provider
**Files involved**:
- `app/db/admin_models.py` - add CostRecord, BudgetConfig
- `app/services/cost_tracker.py` - new file
- `app/api/admin_providers.py` - add cost query APIs
#### 4. Metrics Visualization
**Goal**: chart provider metrics in the admin console
**Approach**:
- Add metric query APIs:
  - GET /admin/metrics/summary - aggregate statistics
  - GET /admin/metrics/timeline - timeline data
  - GET /admin/metrics/providers/{id} - per-provider details
- Render with Chart.js or ECharts on the frontend
### Low Priority
#### 5. Multi-Tenant Provider Configuration
**Goal**: let each tenant configure its own provider list and API keys
**Approach**:
- Tenant configuration table:
```python
class TenantProviderConfig:
    tenant_id: str
    provider_type: str
    provider_ids: list[str]        # ordered by priority
    api_key_override: str | None   # stored encrypted
```
- Routing prefers the tenant configuration and falls back to the global one
#### 6. Provider Health-Check Scheduler
**Goal**: proactively probe provider health on a schedule
**Approach**:
- Celery Beat periodic task
- Probe all enabled providers every 5 minutes
- Update the ProviderHealth table
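With Celery Beat this is a single schedule entry; the task path below is an assumption about where the probe task would live:

```python
# Hypothetical entry assigned to celery_app.conf.beat_schedule; the task
# module/name ("app.tasks.health.check_all_providers") is illustrative only.
beat_schedule = {
    "provider-health-check": {
        "task": "app.tasks.health.check_all_providers",
        "schedule": 300.0,  # seconds: every 5 minutes
    },
}
```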
#### 7. Adapter Hot Loading
**Goal**: load new adapters dynamically at runtime
**Approach**:
- Plugin directory: `app/services/adapters/plugins/`
- Scan and register at startup
- Provide an API to trigger a rescan
---
## API Changelog
### Added in v0.2.0
| Method | Route | Description |
|--------|-------|-------------|
| GET | `/admin/secrets` | List all secret names |
| POST | `/admin/secrets` | Create or update a secret |
| DELETE | `/admin/secrets/{name}` | Delete a secret |
| GET | `/admin/secrets/{name}/verify` | Verify that a secret works |
### Planned (v0.3.0)
| Method | Route | Description |
|--------|-------|-------------|
| GET | `/admin/metrics/summary` | Metrics summary |
| GET | `/admin/metrics/timeline` | Timeline data |
| POST | `/api/storybook/generate` | Generate a paginated storybook |
| GET | `/admin/costs` | Cost statistics |
| POST | `/admin/budgets` | Set budgets |
---
## Adapter Development Guide
### Adding a New Adapter
1. Create the adapter file:
```python
# app/services/adapters/image/new_provider.py
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry

@AdapterRegistry.register("image", "new_provider")
class NewProviderAdapter(BaseAdapter[str]):
    adapter_type = "image"
    adapter_name = "new_provider"

    async def execute(self, prompt: str, **kwargs) -> str:
        # implement the generation logic
        pass

    async def health_check(self) -> bool:
        # implement the health check
        pass

    @property
    def estimated_cost(self) -> float:
        return 0.01  # USD
```
2. Import it in `__init__.py`:
```python
# app/services/adapters/__init__.py
from app.services.adapters.image import new_provider as _new_provider # noqa: F401
```
3. Add the configuration:
```python
# app/core/config.py
new_provider_api_key: str = ""
# app/services/provider_router.py
API_KEY_MAP["new_provider"] = "new_provider_api_key"
```
4. Update `.env.example`:
```
NEW_PROVIDER_API_KEY=
```

backend/pyproject.toml

@@ -0,0 +1,46 @@
[project]
name = "dreamweaver"
version = "0.1.0"
description = "AI-powered children's story generation app"
requires-python = ">=3.11"
dependencies = [
"fastapi>=0.115.0",
"uvicorn[standard]>=0.32.0",
"sqlalchemy[asyncio]>=2.0.0",
"asyncpg>=0.30.0",
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
"python-jose[cryptography]>=3.3.0",
"cryptography>=43.0.0",
"httpx>=0.28.0",
"alembic>=1.13.0",
"cachetools>=5.0.0",
"tenacity>=8.0.0",
"structlog>=24.0.0",
"sse-starlette>=2.0.0",
"celery>=5.4.0",
"redis>=5.0.0",
"openai>=1.0.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.24.0",
"pytest-cov>=4.0.0",
"aiosqlite>=0.20.0",
"ruff>=0.8.0",
]
[tool.setuptools.packages.find]
include = ["app*"]
[tool.ruff]
line-length = 100
target-version = "py311"
[tool.ruff.lint]
select = ["E", "F", "I", "N", "W"]
[tool.pytest.ini_options]
asyncio_mode = "auto"


@@ -0,0 +1,27 @@
import asyncio
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.db.database import async_engine
from sqlalchemy import text
async def upgrade_db():
print("🚀 Checking database schema...")
async with async_engine.begin() as conn:
# Check if column exists
result = await conn.execute(text(
"SELECT column_name FROM information_schema.columns WHERE table_name='providers' AND column_name='config_json';"
))
if result.scalar():
print("✅ Column 'config_json' already exists.")
else:
print("⚠️ Column 'config_json' missing. Adding it now...")
await conn.execute(text("ALTER TABLE providers ADD COLUMN config_json JSON;"))
print("✅ Column 'config_json' added successfully.")
if __name__ == "__main__":
if sys.platform == 'win32':
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
asyncio.run(upgrade_db())


@@ -0,0 +1,29 @@
import asyncio
import os
import sys
# Add backend to path
sys.path.append(os.path.join(os.getcwd()))
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from app.core.config import settings
async def add_column():
engine = create_async_engine(settings.database_url)
async_session = async_sessionmaker(engine, expire_on_commit=False)
async with async_session() as session:
try:
print("Adding config_json column to providers table...")
await session.execute(text("ALTER TABLE providers ADD COLUMN IF NOT EXISTS config_json JSONB DEFAULT '{}'::jsonb"))
await session.commit()
print("Successfully added config_json column.")
except Exception as e:
print(f"Error adding column: {e}")
await session.rollback()
if __name__ == "__main__":
    # asyncio is already imported at the top of the file
    if sys.platform == "win32":
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    asyncio.run(add_column())


@@ -0,0 +1,21 @@
import asyncio
import sys
import os
# Add backend to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.db.database import init_db
from app.core.logging import setup_logging
async def main():
setup_logging()
print("Initializing database...")
try:
await init_db()
print("Database initialized successfully.")
except Exception as e:
print(f"Error initializing database: {e}")
if __name__ == "__main__":
asyncio.run(main())


@@ -0,0 +1 @@
# Tests package

backend/tests/conftest.py

@@ -0,0 +1,146 @@
"""Test configuration and fixtures."""
import os
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing")
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
from app.core.security import create_access_token
from app.api.stories import _request_log
from app.db.database import get_db
from app.db.models import Base, Story, User
from app.main import app
@pytest.fixture
async def async_engine():
    """Create an in-memory database engine."""
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
    """Create a database session."""
session_factory = async_sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
async with session_factory() as session:
yield session
@pytest.fixture
async def test_user(db_session: AsyncSession) -> User:
    """Create a test user."""
user = User(
id="github:12345",
name="Test User",
avatar_url="https://example.com/avatar.png",
provider="github",
)
db_session.add(user)
await db_session.commit()
await db_session.refresh(user)
return user
@pytest.fixture
async def test_story(db_session: AsyncSession, test_user: User) -> Story:
    """Create a test story."""
story = Story(
user_id=test_user.id,
title="测试故事",
story_text="从前有一只小兔子...",
cover_prompt="A cute rabbit in a forest",
mode="generated",
)
db_session.add(story)
await db_session.commit()
await db_session.refresh(story)
return story
@pytest.fixture
def auth_token(test_user: User) -> str:
    """Generate a JWT token for the test user."""
return create_access_token({"sub": test_user.id})
@pytest.fixture
def client(db_session: AsyncSession) -> TestClient:
    """Create a test client."""
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
with TestClient(app) as c:
yield c
app.dependency_overrides.clear()
@pytest.fixture
def auth_client(client: TestClient, auth_token: str) -> TestClient:
    """Authenticated test client."""
client.cookies.set("access_token", auth_token)
return client
@pytest.fixture(autouse=True)
def clear_rate_limit_cache():
    """Isolate the rate-limit cache between test cases."""
_request_log.clear()
yield
_request_log.clear()
@pytest.fixture
def mock_text_provider():
    """Mock the text generation adapter API call."""
from app.services.adapters.text.models import StoryOutput
mock_result = StoryOutput(
mode="generated",
title="小兔子的冒险",
story_text="从前有一只小兔子...",
cover_prompt_suggestion="A cute rabbit",
)
with patch("app.api.stories.generate_story_content", new_callable=AsyncMock) as mock:
mock.return_value = mock_result
yield mock
@pytest.fixture
def mock_image_provider():
    """Mock image generation."""
with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock:
mock.return_value = "https://example.com/image.png"
yield mock
@pytest.fixture
def mock_tts_provider():
    """Mock TTS."""
with patch("app.api.stories.text_to_speech", new_callable=AsyncMock) as mock:
mock.return_value = b"fake-audio-bytes"
yield mock
@pytest.fixture
def mock_all_providers(mock_text_provider, mock_image_provider, mock_tts_provider):
    """Mock all AI providers."""
return {
"text_primary": mock_text_provider,
"image_primary": mock_image_provider,
"tts_primary": mock_tts_provider,
}


@@ -0,0 +1,65 @@
"""Authentication tests."""
import pytest
from fastapi.testclient import TestClient
from app.core.security import create_access_token, decode_access_token
class TestJWT:
    """JWT token tests."""
def test_create_and_decode_token(self):
        """Token creation and decoding."""
payload = {"sub": "github:12345"}
token = create_access_token(payload)
decoded = decode_access_token(token)
assert decoded is not None
assert decoded["sub"] == "github:12345"
def test_decode_invalid_token(self):
        """Decoding an invalid token."""
result = decode_access_token("invalid-token")
assert result is None
def test_decode_empty_token(self):
        """Decoding an empty token."""
result = decode_access_token("")
assert result is None
class TestSession:
    """Session endpoint tests."""
def test_session_without_auth(self, client: TestClient):
        """Fetch the session while logged out."""
response = client.get("/auth/session")
assert response.status_code == 200
data = response.json()
assert data["user"] is None
def test_session_with_auth(self, auth_client: TestClient, test_user):
        """Fetch the session while logged in."""
response = auth_client.get("/auth/session")
assert response.status_code == 200
data = response.json()
assert data["user"] is not None
assert data["user"]["id"] == test_user.id
assert data["user"]["name"] == test_user.name
def test_session_with_invalid_token(self, client: TestClient):
        """Fetch the session with an invalid token."""
client.cookies.set("access_token", "invalid-token")
response = client.get("/auth/session")
assert response.status_code == 200
data = response.json()
assert data["user"] is None
class TestSignout:
    """Sign-out tests."""
def test_signout(self, auth_client: TestClient):
        """Sign out."""
response = auth_client.post("/auth/signout", follow_redirects=False)
assert response.status_code == 302


@@ -0,0 +1,78 @@
"""Child profile API tests."""
from datetime import date
def _calc_age(birth_date: date) -> int:
today = date.today()
return today.year - birth_date.year - (
(today.month, today.day) < (birth_date.month, birth_date.day)
)
def test_list_profiles_empty(auth_client):
response = auth_client.get("/api/profiles")
assert response.status_code == 200
data = response.json()
assert data["profiles"] == []
assert data["total"] == 0
def test_create_update_delete_profile(auth_client):
payload = {
"name": "小明",
"birth_date": "2020-05-12",
"gender": "male",
"interests": ["太空", "机器人"],
"growth_themes": ["勇气"],
}
response = auth_client.post("/api/profiles", json=payload)
assert response.status_code == 201
data = response.json()
assert data["name"] == payload["name"]
assert data["gender"] == payload["gender"]
assert data["interests"] == payload["interests"]
assert data["growth_themes"] == payload["growth_themes"]
assert data["age"] == _calc_age(date.fromisoformat(payload["birth_date"]))
profile_id = data["id"]
update_payload = {"growth_themes": ["分享", "独立"]}
response = auth_client.put(f"/api/profiles/{profile_id}", json=update_payload)
assert response.status_code == 200
data = response.json()
assert data["growth_themes"] == update_payload["growth_themes"]
response = auth_client.delete(f"/api/profiles/{profile_id}")
assert response.status_code == 200
assert response.json()["message"] == "Deleted"
def test_profile_limit_and_duplicate(auth_client):
    # First test a duplicate name (before hitting the limit)
response = auth_client.post(
"/api/profiles",
json={"name": "孩子1", "gender": "female"},
)
assert response.status_code == 201
response = auth_client.post(
"/api/profiles",
json={"name": "孩子1", "gender": "female"},
)
    assert response.status_code == 409  # duplicate name
    # Keep creating profiles up to the limit
for i in range(2, 6):
response = auth_client.post(
"/api/profiles",
json={"name": f"孩子{i}", "gender": "female"},
)
assert response.status_code == 201
    # Test the count limit
response = auth_client.post(
"/api/profiles",
json={"name": "孩子6", "gender": "female"},
)
    assert response.status_code == 400  # exceeds the limit of 5 profiles


@@ -0,0 +1,195 @@
"""Provider router tests: failover and config loading."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.adapters import AdapterConfig
from app.services.adapters.text.models import StoryOutput
class TestProviderFailover:
    """Provider failover tests."""
@pytest.mark.asyncio
async def test_failover_to_second_provider(self):
        """Fail over to the second provider when the first one fails."""
from app.services import provider_router
        # Mock two providers; set all attributes explicitly
mock_provider_1 = MagicMock()
mock_provider_1.configure_mock(
id="provider-1",
type="text",
adapter="text_primary",
api_key="key1",
api_base=None,
model=None,
timeout_ms=60000,
max_retries=3,
config_ref=None,
config_json={},
priority=10,
weight=1.0,
enabled=True,
)
mock_provider_2 = MagicMock()
mock_provider_2.configure_mock(
id="provider-2",
type="text",
adapter="text_primary",
api_key="key2",
api_base=None,
model=None,
timeout_ms=60000,
max_retries=3,
config_ref=None,
config_json={},
priority=5,
weight=1.0,
enabled=True,
)
mock_providers = [mock_provider_1, mock_provider_2]
mock_result = StoryOutput(
mode="generated",
title="测试故事",
story_text="内容",
cover_prompt_suggestion="prompt",
)
call_count = 0
async def mock_execute(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise Exception("First provider failed")
return mock_result
with patch.object(provider_router, "get_providers", return_value=mock_providers):
with patch("app.services.adapters.AdapterRegistry.get") as mock_get:
mock_adapter_class = MagicMock()
mock_adapter_instance = MagicMock()
mock_adapter_instance.execute = mock_execute
mock_adapter_class.return_value = mock_adapter_instance
mock_get.return_value = mock_adapter_class
result = await provider_router.generate_story_content(
input_type="keywords",
data="测试",
)
assert result == mock_result
                assert call_count == 2  # the first fails, the second succeeds
@pytest.mark.asyncio
async def test_all_providers_fail(self):
        """Raise when every provider fails."""
from app.services import provider_router
mock_provider = MagicMock()
mock_provider.configure_mock(
id="provider-1",
type="text",
adapter="text_primary",
api_key="key1",
api_base=None,
model=None,
timeout_ms=60000,
max_retries=3,
config_ref=None,
config_json={},
priority=10,
weight=1.0,
enabled=True,
)
mock_providers = [mock_provider]
async def mock_execute(**kwargs):
raise Exception("Provider failed")
with patch.object(provider_router, "get_providers", return_value=mock_providers):
with patch("app.services.adapters.AdapterRegistry.get") as mock_get:
mock_adapter_class = MagicMock()
mock_adapter_instance = MagicMock()
mock_adapter_instance.execute = mock_execute
mock_adapter_class.return_value = mock_adapter_instance
mock_get.return_value = mock_adapter_class
with pytest.raises(ValueError, match="No text provider succeeded"):
await provider_router.generate_story_content(
input_type="keywords",
data="测试",
)
class TestProviderConfigFromDB:
    """Tests for loading provider config from the DB."""
def test_build_config_from_provider_with_api_key(self):
        """Prefer the provider's own api_key when present."""
from app.services.provider_router import _build_config_from_provider
mock_provider = MagicMock()
mock_provider.adapter = "text_primary"
mock_provider.api_key = "db-api-key"
mock_provider.api_base = "https://custom.api.com"
mock_provider.model = "custom-model"
mock_provider.timeout_ms = 30000
mock_provider.max_retries = 5
mock_provider.config_ref = None
mock_provider.config_json = {}
config = _build_config_from_provider(mock_provider)
assert config.api_key == "db-api-key"
assert config.api_base == "https://custom.api.com"
assert config.model == "custom-model"
assert config.timeout_ms == 30000
assert config.max_retries == 5
def test_build_config_fallback_to_settings(self):
        """Fall back to settings when the provider has no api_key."""
from app.services.provider_router import _build_config_from_provider
mock_provider = MagicMock()
mock_provider.adapter = "text_primary"
mock_provider.api_key = None
mock_provider.api_base = None
mock_provider.model = None
mock_provider.timeout_ms = None
mock_provider.max_retries = None
mock_provider.config_ref = "text_api_key"
mock_provider.config_json = {}
with patch("app.services.provider_router.settings") as mock_settings:
mock_settings.text_api_key = "settings-api-key"
mock_settings.text_model = "gemini-2.0-flash"
config = _build_config_from_provider(mock_provider)
assert config.api_key == "settings-api-key"
class TestProviderCacheStartup:
    """Provider cache startup tests."""
@pytest.mark.asyncio
async def test_cache_loaded_on_startup(self):
        """Load the provider cache at startup."""
from app.main import _load_provider_cache
with patch("app.db.database._get_session_factory") as mock_factory:
mock_session = AsyncMock()
mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_factory.return_value.__aexit__ = AsyncMock()
with patch("app.services.provider_cache.reload_providers", new_callable=AsyncMock) as mock_reload:
mock_reload.return_value = {"text": [], "image": [], "tts": []}
await _load_provider_cache()
mock_reload.assert_called_once()


@@ -0,0 +1,77 @@
"""Push config API tests."""
def _create_profile(auth_client) -> str:
response = auth_client.post(
"/api/profiles",
json={"name": "小明", "gender": "male"},
)
assert response.status_code == 201
return response.json()["id"]
def test_create_list_update_push_config(auth_client):
profile_id = _create_profile(auth_client)
response = auth_client.put(
"/api/push-configs",
json={
"child_profile_id": profile_id,
"push_time": "20:30",
"push_days": [1, 3, 5],
"enabled": True,
},
)
assert response.status_code == 201
data = response.json()
assert data["child_profile_id"] == profile_id
assert data["push_time"].startswith("20:30")
assert data["push_days"] == [1, 3, 5]
assert data["enabled"] is True
response = auth_client.get("/api/push-configs")
assert response.status_code == 200
payload = response.json()
assert payload["total"] == 1
response = auth_client.put(
"/api/push-configs",
json={
"child_profile_id": profile_id,
"enabled": False,
},
)
assert response.status_code == 200
data = response.json()
assert data["enabled"] is False
assert data["push_time"].startswith("20:30")
assert data["push_days"] == [1, 3, 5]
def test_push_config_validation(auth_client):
profile_id = _create_profile(auth_client)
response = auth_client.put(
"/api/push-configs",
json={"child_profile_id": profile_id},
)
assert response.status_code == 400
response = auth_client.put(
"/api/push-configs",
json={
"child_profile_id": profile_id,
"push_time": "19:00",
"push_days": [7],
},
)
assert response.status_code == 400
response = auth_client.put(
"/api/push-configs",
json={
"child_profile_id": profile_id,
"push_time": None,
},
)
assert response.status_code == 400


@@ -0,0 +1,143 @@
"""Reading event API tests."""
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from app.db.database import get_db
from app.db.models import MemoryItem
from app.main import app
pytestmark = pytest.mark.asyncio
async def _create_profile(client: AsyncClient) -> str:
response = await client.post(
"/api/profiles",
json={"name": "小明", "gender": "male"},
)
assert response.status_code == 201
return response.json()["id"]
async def test_create_reading_event_updates_stats_and_memory(
db_session,
test_user,
auth_token,
test_story,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
profile_id = await _create_profile(client)
response = await client.post(
"/api/reading-events",
json={
"child_profile_id": profile_id,
"story_id": test_story.id,
"event_type": "completed",
"reading_time": 120,
},
)
assert response.status_code == 201
data = response.json()
assert data["child_profile_id"] == profile_id
assert data["story_id"] == test_story.id
assert data["event_type"] == "completed"
response = await client.get(f"/api/profiles/{profile_id}")
assert response.status_code == 200
profile = response.json()
assert profile["stories_count"] == 1
assert profile["total_reading_time"] == 120
result = await db_session.execute(
select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
)
items = result.scalars().all()
assert len(items) == 1
assert items[0].type == "recent_story"
assert items[0].value["story_id"] == test_story.id
response = await client.post(
"/api/reading-events",
json={
"child_profile_id": profile_id,
"story_id": test_story.id,
"event_type": "skipped",
"reading_time": 0,
},
)
assert response.status_code == 201
result = await db_session.execute(
select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
)
assert len(result.scalars().all()) == 1
response = await client.post(
"/api/reading-events",
json={
"child_profile_id": profile_id,
"story_id": test_story.id,
"event_type": "completed",
"reading_time": 0,
},
)
assert response.status_code == 201
response = await client.get(f"/api/profiles/{profile_id}")
assert response.status_code == 200
profile = response.json()
assert profile["stories_count"] == 1
assert profile["total_reading_time"] == 120
finally:
app.dependency_overrides.clear()
async def test_reading_event_validation_errors(
db_session,
test_user,
auth_token,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post(
"/api/reading-events",
json={
"child_profile_id": "not-exist",
"event_type": "started",
"reading_time": 0,
},
)
assert response.status_code == 404
profile_id = await _create_profile(client)
response = await client.post(
"/api/reading-events",
json={
"child_profile_id": profile_id,
"story_id": 999999,
"event_type": "completed",
"reading_time": 0,
},
)
assert response.status_code == 404
finally:
app.dependency_overrides.clear()


@@ -0,0 +1,257 @@
"""Story API tests."""
import time
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from app.api.stories import _request_log, RATE_LIMIT_REQUESTS
class TestStoryGenerate:
    """Story generation tests."""
def test_generate_without_auth(self, client: TestClient):
        """Generate a story while logged out."""
response = client.post(
"/api/generate",
json={"type": "keywords", "data": "小兔子, 森林"},
)
assert response.status_code == 401
def test_generate_with_empty_data(self, auth_client: TestClient):
        """Generate with empty data."""
response = auth_client.post(
"/api/generate",
json={"type": "keywords", "data": ""},
)
assert response.status_code == 422
def test_generate_with_invalid_type(self, auth_client: TestClient):
        """Generate with an invalid type."""
response = auth_client.post(
"/api/generate",
json={"type": "invalid", "data": "test"},
)
assert response.status_code == 422
def test_generate_story_success(self, auth_client: TestClient, mock_text_provider):
        """Generate a story successfully."""
response = auth_client.post(
"/api/generate",
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert "title" in data
assert "story_text" in data
assert data["mode"] == "generated"
class TestStoryList:
    """Story list tests."""
def test_list_without_auth(self, client: TestClient):
        """List stories while logged out."""
response = client.get("/api/stories")
assert response.status_code == 401
def test_list_empty(self, auth_client: TestClient):
        """Empty list."""
response = auth_client.get("/api/stories")
assert response.status_code == 200
assert response.json() == []
def test_list_with_stories(self, auth_client: TestClient, test_story):
        """List with existing stories."""
response = auth_client.get("/api/stories")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0]["id"] == test_story.id
assert data[0]["title"] == test_story.title
def test_list_pagination(self, auth_client: TestClient, test_story):
        """Pagination."""
response = auth_client.get("/api/stories?limit=1&offset=0")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
response = auth_client.get("/api/stories?limit=1&offset=1")
assert response.status_code == 200
data = response.json()
assert len(data) == 0
class TestStoryDetail:
    """Story detail tests."""
def test_get_story_without_auth(self, client: TestClient, test_story):
        """Get detail while logged out."""
response = client.get(f"/api/stories/{test_story.id}")
assert response.status_code == 401
def test_get_story_not_found(self, auth_client: TestClient):
        """Story does not exist."""
response = auth_client.get("/api/stories/99999")
assert response.status_code == 404
def test_get_story_success(self, auth_client: TestClient, test_story):
        """Get detail successfully."""
response = auth_client.get(f"/api/stories/{test_story.id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == test_story.id
assert data["title"] == test_story.title
assert data["story_text"] == test_story.story_text
class TestStoryDelete:
    """Story deletion tests."""
def test_delete_without_auth(self, client: TestClient, test_story):
        """Delete while logged out."""
response = client.delete(f"/api/stories/{test_story.id}")
assert response.status_code == 401
def test_delete_not_found(self, auth_client: TestClient):
        """Delete a nonexistent story."""
response = auth_client.delete("/api/stories/99999")
assert response.status_code == 404
def test_delete_success(self, auth_client: TestClient, test_story):
        """Delete a story successfully."""
response = auth_client.delete(f"/api/stories/{test_story.id}")
assert response.status_code == 200
assert response.json()["message"] == "Deleted"
response = auth_client.get(f"/api/stories/{test_story.id}")
assert response.status_code == 404
class TestRateLimit:
    """Rate limit tests."""
def setup_method(self):
        """Clear the rate-limit cache before each test."""
_request_log.clear()
def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, test_story):
        """Normal request volume is not throttled."""
for _ in range(RATE_LIMIT_REQUESTS - 1):
response = auth_client.get(f"/api/stories/{test_story.id}")
assert response.status_code == 200
def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, test_story):
"""超限请求被阻止。"""
for _ in range(RATE_LIMIT_REQUESTS):
auth_client.get(f"/api/stories/{test_story.id}")
response = auth_client.get(f"/api/stories/{test_story.id}")
assert response.status_code == 429
assert "Too many requests" in response.json()["detail"]
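
# The two tests above assume an in-memory sliding-window limiter roughly like
# the sketch below. Only _request_log and RATE_LIMIT_REQUESTS are imported from
# the app; every other name (check_rate_limit, WINDOW_SECONDS) is an assumption
# for illustration, not the real implementation:
#
#   _request_log: dict[str, list[float]] = {}
#
#   def check_rate_limit(user_id: str) -> None:
#       now = time.monotonic()
#       recent = [t for t in _request_log.get(user_id, []) if now - t < WINDOW_SECONDS]
#       if len(recent) >= RATE_LIMIT_REQUESTS:
#           raise HTTPException(status_code=429, detail="Too many requests")
#       recent.append(now)
#       _request_log[user_id] = recent
#
# Clearing _request_log in setup_method keeps the window state from leaking
# between tests.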

class TestImageGenerate:
    """Cover image generation tests."""

    def test_generate_image_without_auth(self, client: TestClient, test_story):
        """Generating an image while unauthenticated returns 401."""
        response = client.post(f"/api/image/generate/{test_story.id}")
        assert response.status_code == 401

    def test_generate_image_not_found(self, auth_client: TestClient):
        """Story does not exist."""
        response = auth_client.post("/api/image/generate/99999")
        assert response.status_code == 404

class TestAudio:
    """Audio narration tests."""

    def test_get_audio_without_auth(self, client: TestClient, test_story):
        """Fetching audio while unauthenticated returns 401."""
        response = client.get(f"/api/audio/{test_story.id}")
        assert response.status_code == 401

    def test_get_audio_not_found(self, auth_client: TestClient):
        """Story does not exist."""
        response = auth_client.get("/api/audio/99999")
        assert response.status_code == 404

    def test_get_audio_success(self, auth_client: TestClient, test_story, mock_tts_provider):
        """Successfully fetch audio."""
        response = auth_client.get(f"/api/audio/{test_story.id}")
        assert response.status_code == 200
        assert response.headers["content-type"] == "audio/mpeg"

class TestGenerateFull:
    """Full story generation tests (/api/generate/full)."""

    def test_generate_full_without_auth(self, client: TestClient):
        """Generating a full story while unauthenticated returns 401."""
        response = client.post(
            "/api/generate/full",
            json={"type": "keywords", "data": "小兔子, 森林"},
        )
        assert response.status_code == 401
    def test_generate_full_success(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
        """Successfully generate a full story, including the image."""
        response = auth_client.post(
            "/api/generate/full",
            json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
        )
        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert "title" in data
        assert "story_text" in data
        assert data["mode"] == "generated"
        assert data["image_url"] == "https://example.com/image.png"
        assert data["audio_ready"] is False  # audio is generated on demand
        assert data["errors"] == {}
    def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider):
        """Image generation failure yields a partial-success response."""
        with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock_img:
            mock_img.side_effect = Exception("Image API error")
            response = auth_client.post(
                "/api/generate/full",
                json={"type": "keywords", "data": "小兔子, 森林"},
            )
        assert response.status_code == 200
        data = response.json()
        assert data["image_url"] is None
        assert "image" in data["errors"]
        assert "Image API error" in data["errors"]["image"]
    def test_generate_full_with_education_theme(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
        """Generate a story with an education theme."""
        response = auth_client.post(
            "/api/generate/full",
            json={
                "type": "keywords",
                "data": "小兔子, 森林",
                "education_theme": "勇气与友谊",
            },
        )
        assert response.status_code == 200
        mock_text_provider.assert_called_once()
        call_kwargs = mock_text_provider.call_args.kwargs
        assert call_kwargs["education_theme"] == "勇气与友谊"

class TestImageGenerateSuccess:
    """Cover image generation success tests."""

    def test_generate_image_success(self, auth_client: TestClient, test_story, mock_image_provider):
        """Successfully generate an image."""
        response = auth_client.post(f"/api/image/generate/{test_story.id}")
        assert response.status_code == 200
        data = response.json()
        assert data["image_url"] == "https://example.com/image.png"
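
# Fixtures used throughout this module (client, auth_client, test_story,
# mock_text_provider, mock_image_provider, mock_tts_provider) are expected to
# come from conftest.py. A minimal sketch of the image-provider mock, assuming
# the same patch target used in test_generate_full_image_failure above (the
# fixture body itself is an illustration, not the project's actual conftest):
#
#   @pytest.fixture
#   def mock_image_provider():
#       with patch("app.api.stories.generate_image", new_callable=AsyncMock) as m:
#           m.return_value = "https://example.com/image.png"
#           yield m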
@@ -0,0 +1,68 @@
"""Story universe API tests."""


def _create_profile(auth_client):
    response = auth_client.post(
        "/api/profiles",
        json={
            "name": "小明",
            "gender": "male",
            "interests": ["太空"],
            "growth_themes": ["勇气"],
        },
    )
    assert response.status_code == 201
    return response.json()["id"]

def test_create_list_update_universe(auth_client):
    profile_id = _create_profile(auth_client)

    # Create
    payload = {
        "name": "星际冒险",
        "protagonist": {"name": "小明", "role": "船长"},
        "recurring_characters": [{"name": "小七", "role": "机器人"}],
        "world_settings": {"world_name": "星际学院"},
    }
    response = auth_client.post(f"/api/profiles/{profile_id}/universes", json=payload)
    assert response.status_code == 201
    data = response.json()
    assert data["name"] == payload["name"]
    universe_id = data["id"]

    # List
    response = auth_client.get(f"/api/profiles/{profile_id}/universes")
    assert response.status_code == 200
    list_data = response.json()
    assert list_data["total"] == 1

    # Update
    response = auth_client.put(
        f"/api/universes/{universe_id}",
        json={"name": "星际冒险·第二季"},
    )
    assert response.status_code == 200
    data = response.json()
    assert data["name"] == "星际冒险·第二季"

def test_add_achievement(auth_client):
    profile_id = _create_profile(auth_client)
    response = auth_client.post(
        f"/api/profiles/{profile_id}/universes",
        json={
            "name": "梦幻森林",
            "protagonist": {"name": "小红"},
        },
    )
    assert response.status_code == 201
    universe_id = response.json()["id"]

    response = auth_client.post(
        f"/api/universes/{universe_id}/achievements",
        json={"type": "勇气", "description": "克服黑暗"},
    )
    assert response.status_code == 200
    data = response.json()
    assert {"type": "勇气", "description": "克服黑暗"} in data["achievements"]