Initial commit: clean project structure
- Backend: FastAPI + SQLAlchemy + Celery (Python 3.11+) - Frontend: Vue 3 + TypeScript + Pinia + Tailwind - Admin Frontend: separate Vue 3 app for management - Docker Compose: 9 services orchestration - Specs: design prototypes, memory system PRD, product roadmap Cleanup performed: - Removed temporary debug scripts from backend root - Removed deprecated admin_app.py (embedded UI) - Removed duplicate docs from admin-frontend - Updated .gitignore for Vite cache and egg-info
This commit is contained in:
115
backend/.env.example
Normal file
115
backend/.env.example
Normal file
@@ -0,0 +1,115 @@
|
||||
# ==============================================
|
||||
# DREAMWEAVER 环境变量配置模板
|
||||
# ==============================================
|
||||
# 使用说明:
|
||||
# 1. 复制此文件为 .env
|
||||
# 2. 填入您的 API Keys
|
||||
# 3. 配合 docker-compose.yml 启动
|
||||
# ==============================================
|
||||
|
||||
# ----------------------------------------------
|
||||
# 1. 基础设施 (Infrastructure) [必填]
|
||||
# ----------------------------------------------
|
||||
# ⚠️ 在 Docker 启动时无需修改这部分,直接使用默认值即可
|
||||
# ⚠️ 仅当您想连接外部数据库时才修改这里
|
||||
POSTGRES_USER=dreamweaver
|
||||
POSTGRES_PASSWORD=dreamweaver_password
|
||||
POSTGRES_DB=dreamweaver_db
|
||||
POSTGRES_PORT=5432
|
||||
REDIS_PORT=6379
|
||||
|
||||
DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
|
||||
CELERY_BROKER_URL=redis://redis:6379/0
|
||||
CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||
|
||||
# Web Security
|
||||
SECRET_KEY=change-me-to-a-secure-random-string-in-production
|
||||
DEBUG=true
|
||||
|
||||
|
||||
# ----------------------------------------------
|
||||
# 2. AI 引擎配置 (AI Engines) [核心]
|
||||
# ----------------------------------------------
|
||||
# [策略配置]
|
||||
# 系统默认使用的供应商列表 (按优先级排序)
|
||||
# 文本生成: 优先 Gemini,其次 OpenAI
|
||||
TEXT_PROVIDERS=["gemini", "openai"]
|
||||
# 图片生成: 优先 CQTAI (Flux/NanoBanana)
|
||||
IMAGE_PROVIDERS=["cqtai"]
|
||||
# 语音生成: 优先 MiniMax,其次 ElevenLabs,最后 EdgeTTS(免费)
|
||||
TTS_PROVIDERS=["minimax", "elevenlabs", "edge_tts"]
|
||||
|
||||
# [模型参数]
|
||||
TEXT_MODEL=gemini-2.0-flash
|
||||
IMAGE_MODEL=nano-banana
|
||||
IMAGE_RESOLUTION=1K
|
||||
# TTS_MODEL=speech-2.6-turbo (MiniMax) / zh-CN-XiaoxiaoNeural (Edge)
|
||||
|
||||
# [API 密钥池]
|
||||
# 请填入您拥有的 Key,没有的留空即可
|
||||
# ⚠️ 注意: 除非您使用国内中转(OneAPI)或企业私有版,否则无需填写 API_BASE (系统会自动使用官方地址)
|
||||
|
||||
# Google Gemini
|
||||
TEXT_API_KEY=
|
||||
TEXT_API_BASE=
|
||||
|
||||
# CQTAI / GoQuantum (Image)
|
||||
CQTAI_API_KEY=
|
||||
# CQTAI_API_BASE=https://api.cqtai.com/v1
|
||||
|
||||
# Antigravity (Image - OpenAI Compatible)
|
||||
ANTIGRAVITY_API_KEY=
|
||||
ANTIGRAVITY_API_BASE=http://127.0.0.1:8045/v1
|
||||
# 模型: gemini-3-pro-image, gemini-3-pro-image-16-9, etc.
|
||||
|
||||
# MiniMax (TTS)
|
||||
MINIMAX_API_KEY=
|
||||
# MINIMAX_GROUP_ID 是 MiniMax v1/v2 接口必须的参数 (通常在 MiniMax 控制台可见)
|
||||
MINIMAX_GROUP_ID=
|
||||
MINIMAX_API_BASE=
|
||||
|
||||
# ElevenLabs (TTS)
|
||||
ELEVENLABS_API_KEY=
|
||||
# ELEVENLABS_API_BASE=https://api.elevenlabs.io/v1
|
||||
|
||||
# OpenAI (如需使用)
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_API_BASE=
|
||||
|
||||
# ----------------------------------------------
|
||||
# 3. 第三方登录 (OAuth Config) [可选]
|
||||
# ----------------------------------------------
|
||||
# 若留空,则无法使用该方式登录
|
||||
GITHUB_CLIENT_ID=
|
||||
GITHUB_CLIENT_SECRET=
|
||||
GOOGLE_CLIENT_ID=
|
||||
GOOGLE_CLIENT_SECRET=
|
||||
|
||||
|
||||
# ----------------------------------------------
|
||||
# 4. 管理后台 (Admin Console)
|
||||
# ----------------------------------------------
|
||||
# 是否开启 /admin 路由与 API (生产环境建议 false)
|
||||
ENABLE_ADMIN_CONSOLE=true
|
||||
|
||||
# 管理员 Basic Auth 账号
|
||||
ADMIN_USERNAME=admin
|
||||
ADMIN_PASSWORD=admin
|
||||
|
||||
|
||||
# ----------------------------------------------
|
||||
# 5. 部署与网络 (Deployment & Network)
|
||||
# ----------------------------------------------
|
||||
# [外部访问地址]
|
||||
# 用于 OAuth 回调验证 (对应 docker-compose 的 52000 端口)
|
||||
BASE_URL=http://localhost:52000
|
||||
|
||||
# [跨域白名单 CORS]
|
||||
# 包含 User Frontend (52080), Admin Frontend (52888) 及本地开发端口
|
||||
CORS_ORIGINS=["http://localhost:52080", "http://localhost:52888", "http://localhost:5173", "http://localhost:5174"]
|
||||
|
||||
# [本地开发覆盖 Local Dev Override]
|
||||
# 如果您不使用 Docker,而是在本机直接运行 `python -m uvicorn ...`
|
||||
# 请取消注释以下行以连接 localhost 数据库:
|
||||
# DATABASE_URL=postgresql+asyncpg://dreamweaver:dreamweaver_password@localhost:52432/dreamweaver_db
|
||||
# CELERY_BROKER_URL=redis://localhost:52379/0
|
||||
27
backend/.gitignore
vendored
Normal file
27
backend/.gitignore
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# 环境变量
|
||||
.env
|
||||
|
||||
# 测试
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
|
||||
# 其他
|
||||
*.log
|
||||
.DS_Store
|
||||
27
backend/Dockerfile
Normal file
27
backend/Dockerfile
Normal file
@@ -0,0 +1,27 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装系统依赖 (如果需要)
|
||||
# RUN apt-get update && apt-get install -y gcc libpq-dev && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 复制项目文件
|
||||
COPY pyproject.toml .
|
||||
# 复制源码
|
||||
COPY app ./app
|
||||
COPY alembic ./alembic
|
||||
COPY alembic.ini .
|
||||
|
||||
# 安装依赖
|
||||
# 使用 pip 安装当前目录 (.),会自动解析 pyproject.toml
|
||||
RUN pip install --no-cache-dir .
|
||||
|
||||
# 创建静态文件目录 (用于存放生成的图片)
|
||||
RUN mkdir -p static/images
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8000
|
||||
|
||||
# 启动命令
|
||||
# 生产环境建议使用 gunicorn 或 uvicorn --workers
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
38
backend/alembic.ini
Normal file
38
backend/alembic.ini
Normal file
@@ -0,0 +1,38 @@
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
sqlalchemy.url = postgresql+asyncpg://user:password@localhost/db
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
20
backend/alembic/README.md
Normal file
20
backend/alembic/README.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Alembic 使用说明
|
||||
|
||||
1. 安装依赖(在后端虚拟环境内)
|
||||
```
|
||||
pip install alembic
|
||||
```
|
||||
|
||||
2. 设置环境变量,确保 `DATABASE_URL` 指向目标数据库。
|
||||
|
||||
3. 运行迁移:
|
||||
```
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
4. 生成新迁移(如有模型变更):
|
||||
```
|
||||
alembic revision -m "message" --autogenerate
|
||||
```
|
||||
|
||||
说明:`alembic/env.py` 会从 `app.core.config` 读取数据库 URL,并包含 admin/provider 模型。
|
||||
68
backend/alembic/env.py
Normal file
68
backend/alembic/env.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
from alembic import context
|
||||
from app.core.config import settings
|
||||
from app.db import models, admin_models # ensure models are imported
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# override sqlalchemy.url from settings
|
||||
config.set_main_option("sqlalchemy.url", settings.database_url)
|
||||
|
||||
target_metadata = models.Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline():
|
||||
"""Run migrations in 'offline' mode."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
compare_type=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection):
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_migrations_online():
|
||||
"""Run migrations in 'online' mode."""
|
||||
connectable = async_engine_from_config(
|
||||
config.get_section(config.config_ini_section),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
connect_args={"statement_cache_size": 0},
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
asyncio.run(run_migrations_online())
|
||||
@@ -0,0 +1,45 @@
|
||||
"""init providers and story mode"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0001_init_providers"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"providers",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("name", sa.String(length=100), nullable=False),
|
||||
sa.Column("type", sa.String(length=50), nullable=False),
|
||||
sa.Column("adapter", sa.String(length=100), nullable=False),
|
||||
sa.Column("model", sa.String(length=200), nullable=True),
|
||||
sa.Column("api_base", sa.String(length=300), nullable=True),
|
||||
sa.Column("timeout_ms", sa.Integer(), server_default="60000", nullable=False),
|
||||
sa.Column("max_retries", sa.Integer(), server_default="1", nullable=False),
|
||||
sa.Column("weight", sa.Integer(), server_default="1", nullable=False),
|
||||
sa.Column("priority", sa.Integer(), server_default="0", nullable=False),
|
||||
sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False),
|
||||
sa.Column("config_ref", sa.String(length=100), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_by", sa.String(length=100), nullable=True),
|
||||
)
|
||||
|
||||
with op.batch_alter_table("stories", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("mode", sa.String(length=20), server_default="generated", nullable=False)
|
||||
)
|
||||
batch_op.alter_column("mode", server_default=None)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table("stories", schema=None) as batch_op:
|
||||
batch_op.drop_column("mode")
|
||||
|
||||
op.drop_table("providers")
|
||||
29
backend/alembic/versions/0002_add_api_key_to_providers.py
Normal file
29
backend/alembic/versions/0002_add_api_key_to_providers.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""add api_key to providers
|
||||
|
||||
Revision ID: 0002_add_api_key_to_providers
|
||||
Revises: 0001_init_providers_and_story_mode
|
||||
Create Date: 2025-01-01
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0002_add_api_key"
|
||||
down_revision = "0001_init_providers"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 添加 api_key 列,可为空,优先于 config_ref 使用
|
||||
with op.batch_alter_table("providers", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("api_key", sa.String(length=500), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table("providers", schema=None) as batch_op:
|
||||
batch_op.drop_column("api_key")
|
||||
100
backend/alembic/versions/0003_add_provider_monitoring_tables.py
Normal file
100
backend/alembic/versions/0003_add_provider_monitoring_tables.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""add provider monitoring tables
|
||||
|
||||
Revision ID: 0003_add_monitoring
|
||||
Revises: 0002_add_api_key
|
||||
Create Date: 2025-01-01
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0003_add_monitoring"
|
||||
down_revision = "0002_add_api_key"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 创建 provider_metrics 表
|
||||
op.create_table(
|
||||
"provider_metrics",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("provider_id", sa.String(length=36), nullable=False),
|
||||
sa.Column(
|
||||
"timestamp",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("success", sa.Boolean(), nullable=False),
|
||||
sa.Column("latency_ms", sa.Integer(), nullable=True),
|
||||
sa.Column("cost_usd", sa.Numeric(precision=10, scale=6), nullable=True),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column("request_id", sa.String(length=100), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["provider_id"],
|
||||
["providers.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_provider_metrics_provider_id",
|
||||
"provider_metrics",
|
||||
["provider_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_provider_metrics_timestamp",
|
||||
"provider_metrics",
|
||||
["timestamp"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# 创建 provider_health 表
|
||||
op.create_table(
|
||||
"provider_health",
|
||||
sa.Column("provider_id", sa.String(length=36), nullable=False),
|
||||
sa.Column("is_healthy", sa.Boolean(), server_default=sa.text("true"), nullable=True),
|
||||
sa.Column("last_check", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("consecutive_failures", sa.Integer(), server_default=sa.text("0"), nullable=True),
|
||||
sa.Column("last_error", sa.Text(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["provider_id"],
|
||||
["providers.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("provider_id"),
|
||||
)
|
||||
|
||||
# 创建 provider_secrets 表
|
||||
op.create_table(
|
||||
"provider_secrets",
|
||||
sa.Column("id", sa.String(length=36), nullable=False),
|
||||
sa.Column("name", sa.String(length=100), nullable=False),
|
||||
sa.Column("encrypted_value", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("provider_secrets")
|
||||
op.drop_table("provider_health")
|
||||
op.drop_index("ix_provider_metrics_timestamp", table_name="provider_metrics")
|
||||
op.drop_index("ix_provider_metrics_provider_id", table_name="provider_metrics")
|
||||
op.drop_table("provider_metrics")
|
||||
42
backend/alembic/versions/0004_add_child_profiles.py
Normal file
42
backend/alembic/versions/0004_add_child_profiles.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""add child profiles
|
||||
|
||||
Revision ID: 0004_add_child_profiles
|
||||
Revises: 0003_add_monitoring
|
||||
Create Date: 2025-12-22
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0004_add_child_profiles"
|
||||
down_revision = "0003_add_monitoring"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"child_profiles",
|
||||
sa.Column("id", sa.String(36), primary_key=True),
|
||||
sa.Column("user_id", sa.String(255), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("name", sa.String(50), nullable=False),
|
||||
sa.Column("avatar_url", sa.String(500)),
|
||||
sa.Column("birth_date", sa.Date()),
|
||||
sa.Column("gender", sa.String(10)),
|
||||
sa.Column("interests", sa.JSON(), server_default="[]", nullable=False),
|
||||
sa.Column("growth_themes", sa.JSON(), server_default="[]", nullable=False),
|
||||
sa.Column("reading_preferences", sa.JSON(), server_default="{}", nullable=False),
|
||||
sa.Column("stories_count", sa.Integer(), server_default="0", nullable=False),
|
||||
sa.Column("total_reading_time", sa.Integer(), server_default="0", nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.UniqueConstraint("user_id", "name", name="uq_child_profile_user_name"),
|
||||
)
|
||||
op.create_index("idx_child_profiles_user_id", "child_profiles", ["user_id"])
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index("idx_child_profiles_user_id", table_name="child_profiles")
|
||||
op.drop_table("child_profiles")
|
||||
@@ -0,0 +1,67 @@
|
||||
"""add story universes and story links
|
||||
|
||||
Revision ID: 0005_add_story_universes_and_story_links
|
||||
Revises: 0004_add_child_profiles
|
||||
Create Date: 2025-12-22
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = "0005_add_story_universes_and_story_links"
|
||||
down_revision = "0004_add_child_profiles"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"story_universes",
|
||||
sa.Column("id", sa.String(36), primary_key=True),
|
||||
sa.Column(
|
||||
"child_profile_id",
|
||||
sa.String(36),
|
||||
sa.ForeignKey("child_profiles.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("name", sa.String(100), nullable=False),
|
||||
sa.Column("protagonist", sa.JSON(), nullable=False),
|
||||
sa.Column("recurring_characters", sa.JSON(), server_default="[]", nullable=False),
|
||||
sa.Column("world_settings", sa.JSON(), server_default="{}", nullable=False),
|
||||
sa.Column("achievements", sa.JSON(), server_default="[]", nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
op.create_index("idx_story_universes_child_id", "story_universes", ["child_profile_id"])
|
||||
op.create_index("idx_story_universes_updated_at", "story_universes", ["updated_at"])
|
||||
|
||||
op.add_column("stories", sa.Column("child_profile_id", sa.String(36), nullable=True))
|
||||
op.add_column("stories", sa.Column("universe_id", sa.String(36), nullable=True))
|
||||
op.create_foreign_key(
|
||||
"fk_stories_child_profile",
|
||||
"stories",
|
||||
"child_profiles",
|
||||
["child_profile_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_stories_universe",
|
||||
"stories",
|
||||
"story_universes",
|
||||
["universe_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_constraint("fk_stories_universe", "stories", type_="foreignkey")
|
||||
op.drop_constraint("fk_stories_child_profile", "stories", type_="foreignkey")
|
||||
op.drop_column("stories", "universe_id")
|
||||
op.drop_column("stories", "child_profile_id")
|
||||
|
||||
op.drop_index("idx_story_universes_updated_at", table_name="story_universes")
|
||||
op.drop_index("idx_story_universes_child_id", table_name="story_universes")
|
||||
op.drop_table("story_universes")
|
||||
@@ -0,0 +1,78 @@
|
||||
"""add reading events and memory items
|
||||
|
||||
Revision ID: 0006_add_reading_events_and_memory_items
|
||||
Revises: 0005_add_story_universes_and_story_links
|
||||
Create Date: 2025-12-22
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = "0006_add_reading_events_and_memory_items"
|
||||
down_revision = "0005_add_story_universes_and_story_links"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"reading_events",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column(
|
||||
"child_profile_id",
|
||||
sa.String(36),
|
||||
sa.ForeignKey("child_profiles.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"story_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("stories.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("event_type", sa.String(20), nullable=False),
|
||||
sa.Column("reading_time", sa.Integer(), server_default="0", nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
op.create_index("idx_reading_events_profile", "reading_events", ["child_profile_id"])
|
||||
op.create_index("idx_reading_events_story", "reading_events", ["story_id"])
|
||||
op.create_index("idx_reading_events_created", "reading_events", ["created_at"])
|
||||
|
||||
op.create_table(
|
||||
"memory_items",
|
||||
sa.Column("id", sa.String(36), primary_key=True),
|
||||
sa.Column(
|
||||
"child_profile_id",
|
||||
sa.String(36),
|
||||
sa.ForeignKey("child_profiles.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"universe_id",
|
||||
sa.String(36),
|
||||
sa.ForeignKey("story_universes.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("type", sa.String(50), nullable=False),
|
||||
sa.Column("value", sa.JSON(), nullable=False),
|
||||
sa.Column("base_weight", sa.Float(), server_default="1.0", nullable=False),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True)),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("ttl_days", sa.Integer()),
|
||||
)
|
||||
op.create_index("idx_memory_items_profile", "memory_items", ["child_profile_id"])
|
||||
op.create_index("idx_memory_items_universe", "memory_items", ["universe_id"])
|
||||
op.create_index("idx_memory_items_last_used", "memory_items", ["last_used_at"])
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index("idx_memory_items_last_used", table_name="memory_items")
|
||||
op.drop_index("idx_memory_items_universe", table_name="memory_items")
|
||||
op.drop_index("idx_memory_items_profile", table_name="memory_items")
|
||||
op.drop_table("memory_items")
|
||||
|
||||
op.drop_index("idx_reading_events_created", table_name="reading_events")
|
||||
op.drop_index("idx_reading_events_story", table_name="reading_events")
|
||||
op.drop_index("idx_reading_events_profile", table_name="reading_events")
|
||||
op.drop_table("reading_events")
|
||||
68
backend/alembic/versions/0007_add_push_configs_and_events.py
Normal file
68
backend/alembic/versions/0007_add_push_configs_and_events.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Add push configs and events.
|
||||
|
||||
Revision ID: 0007_add_push_configs_and_events
|
||||
Revises: 0006_add_reading_events_and_memory_items
|
||||
Create Date: 2025-12-24 16:40:00.000000
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0007_add_push_configs_and_events"
|
||||
down_revision = "0006_add_reading_events_and_memory_items"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"push_configs",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("user_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("child_profile_id", sa.String(length=36), nullable=False),
|
||||
sa.Column("push_time", sa.Time(), nullable=True),
|
||||
sa.Column("push_days", sa.JSON(), nullable=False, server_default="[]"),
|
||||
sa.Column("enabled", sa.Boolean(), nullable=False, server_default=sa.true()),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["child_profile_id"], ["child_profiles.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("child_profile_id", name="uq_push_config_child"),
|
||||
)
|
||||
op.create_index("ix_push_configs_user_id", "push_configs", ["user_id"])
|
||||
op.create_index("ix_push_configs_child_profile_id", "push_configs", ["child_profile_id"])
|
||||
|
||||
op.create_table(
|
||||
"push_events",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("user_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("child_profile_id", sa.String(length=36), nullable=False),
|
||||
sa.Column("trigger_type", sa.String(length=20), nullable=False),
|
||||
sa.Column("status", sa.String(length=20), nullable=False),
|
||||
sa.Column("reason", sa.Text(), nullable=True),
|
||||
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["child_profile_id"], ["child_profiles.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_push_events_user_id", "push_events", ["user_id"])
|
||||
op.create_index("ix_push_events_child_profile_id", "push_events", ["child_profile_id"])
|
||||
op.create_index("ix_push_events_sent_at", "push_events", ["sent_at"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_push_events_sent_at", table_name="push_events")
|
||||
op.drop_index("ix_push_events_child_profile_id", table_name="push_events")
|
||||
op.drop_index("ix_push_events_user_id", table_name="push_events")
|
||||
op.drop_table("push_events")
|
||||
|
||||
op.drop_index("ix_push_configs_child_profile_id", table_name="push_configs")
|
||||
op.drop_index("ix_push_configs_user_id", table_name="push_configs")
|
||||
op.drop_table("push_configs")
|
||||
25
backend/alembic/versions/0008_add_pages_to_stories.py
Normal file
25
backend/alembic/versions/0008_add_pages_to_stories.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""add pages column to stories
|
||||
|
||||
Revision ID: 0008_add_pages_to_stories
|
||||
Revises: 0007_add_push_configs_and_events
|
||||
Create Date: 2026-01-20
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '0008_add_pages_to_stories'
|
||||
down_revision = '0007_add_push_configs_and_events'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('stories', sa.Column('pages', postgresql.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('stories', 'pages')
|
||||
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
61
backend/app/admin_main.py
Normal file
61
backend/app/admin_main.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api import admin_providers, admin_reload
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger, setup_logging
|
||||
from app.db.database import init_db
|
||||
|
||||
setup_logging()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Admin App lifespan manager."""
|
||||
logger.info("admin_app_starting")
|
||||
await init_db()
|
||||
|
||||
# 可以在这里加载特定的 Admin 缓存或预热
|
||||
|
||||
yield
|
||||
logger.info("admin_app_shutdown")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title=f"{settings.app_name} Admin Console",
|
||||
description="Administrative Control Plane for DreamWeaver.",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Admin 后台通常允许更宽松的 CORS,或者特定的管理域名
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins, # 或者专门的 ADMIN_CORS_ORIGINS
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 根据配置开关挂载路由
|
||||
if settings.enable_admin_console:
|
||||
app.include_router(admin_providers.router, prefix="/admin", tags=["admin-providers"])
|
||||
app.include_router(admin_reload.router, prefix="/admin", tags=["admin-reload"])
|
||||
else:
|
||||
@app.get("/admin/{path:path}")
|
||||
@app.post("/admin/{path:path}")
|
||||
@app.put("/admin/{path:path}")
|
||||
@app.delete("/admin/{path:path}")
|
||||
async def admin_disabled(path: str):
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Admin console is disabled in environment configuration."
|
||||
)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "ok", "service": "admin-backend"}
|
||||
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
307
backend/app/api/admin_providers.py
Normal file
307
backend/app/api/admin_providers.py
Normal file
@@ -0,0 +1,307 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.admin_auth import admin_guard
|
||||
from app.db.admin_models import Provider
|
||||
from app.db.database import get_db
|
||||
from app.services.cost_tracker import cost_tracker
|
||||
from app.services.secret_service import SecretService
|
||||
|
||||
router = APIRouter(dependencies=[Depends(admin_guard)])
|
||||
|
||||
|
||||
class ProviderCreate(BaseModel):
|
||||
name: str
|
||||
type: str = Field(..., pattern="^(text|image|tts|storybook)$")
|
||||
adapter: str
|
||||
model: str | None = None
|
||||
api_base: str | None = None
|
||||
api_key: str | None = None # 可选,优先于 config_ref
|
||||
timeout_ms: int = 60000
|
||||
max_retries: int = 1
|
||||
weight: int = 1
|
||||
priority: int = 0
|
||||
enabled: bool = True
|
||||
config_json: dict | None = None
|
||||
config_ref: str | None = None # 环境变量 key 名称(回退)
|
||||
updated_by: str | None = None
|
||||
|
||||
|
||||
class ProviderUpdate(ProviderCreate):
|
||||
enabled: bool | None = None
|
||||
api_key: str | None = None
|
||||
config_json: dict | None = None
|
||||
|
||||
|
||||
class ProviderResponse(BaseModel):
|
||||
"""Provider 响应模型,隐藏敏感字段。"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
adapter: str
|
||||
model: str | None = None
|
||||
api_base: str | None = None
|
||||
has_api_key: bool = False # 仅标识是否配置了 api_key,不返回明文
|
||||
timeout_ms: int = 60000
|
||||
max_retries: int = 1
|
||||
weight: int = 1
|
||||
priority: int = 0
|
||||
enabled: bool = True
|
||||
config_ref: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
from app.services.provider_router import DEFAULT_PROVIDERS
|
||||
|
||||
|
||||
@router.get("/providers/adapters")
|
||||
async def list_available_adapters():
|
||||
"""获取所有可用的适配器类型 (定义的类)。"""
|
||||
return AdapterRegistry.list_adapters()
|
||||
|
||||
|
||||
@router.get("/providers/defaults")
|
||||
async def get_env_defaults():
|
||||
"""获取当前环境变量定义的默认策略 (Read-Only)。"""
|
||||
return DEFAULT_PROVIDERS
|
||||
|
||||
|
||||
@router.get("/providers", response_model=list[ProviderResponse])
|
||||
async def list_providers(db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Provider))
|
||||
providers = result.scalars().all()
|
||||
# 转换为响应模型,隐藏 api_key 明文
|
||||
return [
|
||||
ProviderResponse(
|
||||
id=p.id,
|
||||
name=p.name,
|
||||
type=p.type,
|
||||
adapter=p.adapter,
|
||||
model=p.model,
|
||||
api_base=p.api_base,
|
||||
has_api_key=bool(p.api_key), # 仅标识是否有 key
|
||||
timeout_ms=p.timeout_ms,
|
||||
max_retries=p.max_retries,
|
||||
weight=p.weight,
|
||||
priority=p.priority,
|
||||
enabled=p.enabled,
|
||||
config_ref=p.config_ref,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
|
||||
|
||||
def _to_response(provider: Provider) -> ProviderResponse:
|
||||
"""将 Provider 转换为响应模型,隐藏敏感字段。"""
|
||||
return ProviderResponse(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
type=provider.type,
|
||||
adapter=provider.adapter,
|
||||
model=provider.model,
|
||||
api_base=provider.api_base,
|
||||
has_api_key=bool(provider.api_key),
|
||||
timeout_ms=provider.timeout_ms,
|
||||
max_retries=provider.max_retries,
|
||||
weight=provider.weight,
|
||||
priority=provider.priority,
|
||||
enabled=provider.enabled,
|
||||
config_ref=provider.config_ref,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/providers", response_model=ProviderResponse)
|
||||
async def create_provider(payload: ProviderCreate, db: AsyncSession = Depends(get_db)):
|
||||
data = payload.model_dump()
|
||||
# 加密 API Key
|
||||
if data.get("api_key"):
|
||||
data["api_key"] = SecretService.encrypt(data["api_key"])
|
||||
provider = Provider(**data)
|
||||
db.add(provider)
|
||||
await db.commit()
|
||||
await db.refresh(provider)
|
||||
return _to_response(provider)
|
||||
|
||||
|
||||
@router.put("/providers/{provider_id}", response_model=ProviderResponse)
|
||||
async def update_provider(
|
||||
provider_id: str, payload: ProviderUpdate, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
data = payload.model_dump(exclude_unset=True)
|
||||
# 加密 API Key
|
||||
if "api_key" in data and data["api_key"]:
|
||||
data["api_key"] = SecretService.encrypt(data["api_key"])
|
||||
for k, v in data.items():
|
||||
setattr(provider, k, v)
|
||||
await db.commit()
|
||||
await db.refresh(provider)
|
||||
return _to_response(provider)
|
||||
|
||||
|
||||
@router.delete("/providers/{provider_id}")
|
||||
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
await db.delete(provider)
|
||||
await db.commit()
|
||||
return {"message": "deleted"}
|
||||
|
||||
|
||||
# ==================== 密钥管理 API ====================
|
||||
|
||||
|
||||
class SecretCreate(BaseModel):
|
||||
"""密钥创建请求。"""
|
||||
|
||||
name: str = Field(..., description="密钥名称,如 CQTAI_API_KEY")
|
||||
value: str = Field(..., description="密钥明文值")
|
||||
|
||||
|
||||
class SecretResponse(BaseModel):
|
||||
"""密钥响应,不返回明文。"""
|
||||
|
||||
name: str
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
|
||||
|
||||
@router.get("/secrets", response_model=list[str])
|
||||
async def list_secrets(db: AsyncSession = Depends(get_db)):
|
||||
"""列出所有密钥名称(不返回值)。"""
|
||||
return await SecretService.list_secrets(db)
|
||||
|
||||
|
||||
@router.post("/secrets", response_model=SecretResponse)
|
||||
async def create_or_update_secret(payload: SecretCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""创建或更新密钥。"""
|
||||
secret = await SecretService.set_secret(db, payload.name, payload.value)
|
||||
return SecretResponse(
|
||||
name=secret.name,
|
||||
created_at=secret.created_at.isoformat() if secret.created_at else None,
|
||||
updated_at=secret.updated_at.isoformat() if secret.updated_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/secrets/{name}")
|
||||
async def delete_secret(name: str, db: AsyncSession = Depends(get_db)):
|
||||
"""删除密钥。"""
|
||||
deleted = await SecretService.delete_secret(db, name)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Secret not found")
|
||||
return {"message": "deleted"}
|
||||
|
||||
|
||||
@router.get("/secrets/{name}/verify")
|
||||
async def verify_secret(name: str, db: AsyncSession = Depends(get_db)):
|
||||
"""验证密钥是否存在且可解密(不返回明文)。"""
|
||||
value = await SecretService.get_secret(db, name)
|
||||
if value is None:
|
||||
raise HTTPException(status_code=404, detail="Secret not found")
|
||||
return {"name": name, "valid": True, "length": len(value)}
|
||||
|
||||
|
||||
# ==================== 成本追踪 API ====================
|
||||
|
||||
|
||||
class BudgetUpdate(BaseModel):
|
||||
"""预算更新请求。"""
|
||||
|
||||
daily_limit_usd: float | None = None
|
||||
monthly_limit_usd: float | None = None
|
||||
alert_threshold: float | None = Field(default=None, ge=0, le=1)
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
@router.get("/costs/summary/{user_id}")
|
||||
async def get_user_cost_summary(user_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""获取用户成本摘要。"""
|
||||
return await cost_tracker.get_cost_summary(db, user_id)
|
||||
|
||||
|
||||
@router.get("/costs/all")
|
||||
async def get_all_costs_summary(db: AsyncSession = Depends(get_db)):
|
||||
"""获取所有用户成本汇总(管理员)。"""
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.db.admin_models import CostRecord
|
||||
|
||||
# 按用户汇总
|
||||
result = await db.execute(
|
||||
select(
|
||||
CostRecord.user_id,
|
||||
func.sum(CostRecord.estimated_cost).label("total_cost"),
|
||||
func.count().label("call_count"),
|
||||
).group_by(CostRecord.user_id)
|
||||
)
|
||||
users = [
|
||||
{"user_id": row[0], "total_cost_usd": float(row[1]), "call_count": row[2]}
|
||||
for row in result.all()
|
||||
]
|
||||
|
||||
# 按能力汇总
|
||||
result = await db.execute(
|
||||
select(
|
||||
CostRecord.capability,
|
||||
func.sum(CostRecord.estimated_cost).label("total_cost"),
|
||||
func.count().label("call_count"),
|
||||
).group_by(CostRecord.capability)
|
||||
)
|
||||
capabilities = [
|
||||
{"capability": row[0], "total_cost_usd": float(row[1]), "call_count": row[2]}
|
||||
for row in result.all()
|
||||
]
|
||||
|
||||
return {"by_user": users, "by_capability": capabilities}
|
||||
|
||||
|
||||
@router.get("/budgets/{user_id}")
|
||||
async def get_user_budget(user_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""获取用户预算配置。"""
|
||||
budget = await cost_tracker.get_user_budget(db, user_id)
|
||||
if not budget:
|
||||
return {"user_id": user_id, "budget": None}
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"budget": {
|
||||
"daily_limit_usd": float(budget.daily_limit_usd),
|
||||
"monthly_limit_usd": float(budget.monthly_limit_usd),
|
||||
"alert_threshold": float(budget.alert_threshold),
|
||||
"enabled": budget.enabled,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/budgets/{user_id}")
|
||||
async def set_user_budget(
|
||||
user_id: str, payload: BudgetUpdate, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""设置用户预算。"""
|
||||
budget = await cost_tracker.set_user_budget(
|
||||
db,
|
||||
user_id,
|
||||
daily_limit=payload.daily_limit_usd,
|
||||
monthly_limit=payload.monthly_limit_usd,
|
||||
alert_threshold=payload.alert_threshold,
|
||||
enabled=payload.enabled,
|
||||
)
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"budget": {
|
||||
"daily_limit_usd": float(budget.daily_limit_usd),
|
||||
"monthly_limit_usd": float(budget.monthly_limit_usd),
|
||||
"alert_threshold": float(budget.alert_threshold),
|
||||
"enabled": budget.enabled,
|
||||
},
|
||||
}
|
||||
14
backend/app/api/admin_reload.py
Normal file
14
backend/app/api/admin_reload.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.admin_auth import admin_guard
|
||||
from app.db.database import get_db
|
||||
from app.services.provider_cache import reload_providers
|
||||
|
||||
router = APIRouter(dependencies=[Depends(admin_guard)])
|
||||
|
||||
|
||||
@router.post("/providers/reload")
|
||||
async def reload(db: AsyncSession = Depends(get_db)):
|
||||
cache = await reload_providers(db)
|
||||
return {k: len(v) for k, v in cache.items()}
|
||||
272
backend/app/api/auth.py
Normal file
272
backend/app/api/auth.py
Normal file
@@ -0,0 +1,272 @@
|
||||
import secrets
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.deps import get_current_user
|
||||
from app.core.security import create_access_token
|
||||
from app.db.database import get_db
|
||||
from app.db.models import User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# OAuth endpoints
|
||||
GITHUB_AUTHORIZE_URL = "https://github.com/login/oauth/authorize"
|
||||
GITHUB_TOKEN_URL = "https://github.com/login/oauth/access_token"
|
||||
GITHUB_USER_URL = "https://api.github.com/user"
|
||||
|
||||
GOOGLE_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
GOOGLE_USER_URL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
STATE_COOKIE = "oauth_state"
|
||||
STATE_MAX_AGE = 600 # 10 minutes
|
||||
|
||||
|
||||
def _set_state_cookie(response: RedirectResponse, provider: str, state: str) -> None:
|
||||
response.set_cookie(
|
||||
key=STATE_COOKIE,
|
||||
value=f"{provider}:{state}",
|
||||
httponly=True,
|
||||
secure=not settings.debug,
|
||||
samesite="lax",
|
||||
max_age=STATE_MAX_AGE,
|
||||
)
|
||||
|
||||
|
||||
def _validate_state(state_from_query: str | None, state_cookie: str | None, provider: str):
|
||||
if not state_from_query or not state_cookie:
|
||||
raise HTTPException(status_code=400, detail="Missing OAuth state")
|
||||
expected_prefix = f"{provider}:"
|
||||
if not state_cookie.startswith(expected_prefix):
|
||||
raise HTTPException(status_code=400, detail="OAuth state mismatch")
|
||||
expected_state = state_cookie.removeprefix(expected_prefix)
|
||||
if not secrets.compare_digest(state_from_query, expected_state):
|
||||
raise HTTPException(status_code=400, detail="OAuth state mismatch")
|
||||
|
||||
|
||||
@router.get("/github/signin")
|
||||
async def github_signin():
|
||||
"""Start GitHub OAuth with state protection."""
|
||||
state = secrets.token_urlsafe(16)
|
||||
params = {
|
||||
"client_id": settings.github_client_id,
|
||||
"redirect_uri": f"{settings.base_url}/auth/github/callback",
|
||||
"scope": "read:user user:email",
|
||||
"state": state,
|
||||
}
|
||||
url = f"{GITHUB_AUTHORIZE_URL}?{urlencode(params)}"
|
||||
response = RedirectResponse(url=url)
|
||||
_set_state_cookie(response, "github", state)
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/github/callback")
|
||||
async def github_callback(
|
||||
code: str,
|
||||
state: str | None = Query(default=None),
|
||||
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Handle GitHub OAuth callback."""
|
||||
_validate_state(state, state_cookie, "github")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_resp = await client.post(
|
||||
GITHUB_TOKEN_URL,
|
||||
data={
|
||||
"client_id": settings.github_client_id,
|
||||
"client_secret": settings.github_client_secret,
|
||||
"code": code,
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
token_resp.raise_for_status()
|
||||
token_data = token_resp.json()
|
||||
access_token = token_data.get("access_token")
|
||||
if not access_token:
|
||||
raise HTTPException(status_code=502, detail="GitHub login failed")
|
||||
|
||||
user_resp = await client.get(
|
||||
GITHUB_USER_URL,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
user_resp.raise_for_status()
|
||||
user_data = user_resp.json()
|
||||
except httpx.HTTPStatusError:
|
||||
raise HTTPException(status_code=502, detail="GitHub login failed")
|
||||
|
||||
github_id = user_data.get("id")
|
||||
if github_id is None:
|
||||
raise HTTPException(status_code=502, detail="GitHub login failed")
|
||||
|
||||
return await _handle_oauth_user(
|
||||
db=db,
|
||||
provider="github",
|
||||
user_id=str(github_id),
|
||||
name=user_data.get("name") or user_data.get("login") or "GitHub User",
|
||||
avatar_url=user_data.get("avatar_url"),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/google/signin")
|
||||
async def google_signin():
|
||||
"""Start Google OAuth with state protection."""
|
||||
state = secrets.token_urlsafe(16)
|
||||
params = {
|
||||
"client_id": settings.google_client_id,
|
||||
"redirect_uri": f"{settings.base_url}/auth/google/callback",
|
||||
"response_type": "code",
|
||||
"scope": "openid email profile",
|
||||
"state": state,
|
||||
}
|
||||
url = f"{GOOGLE_AUTHORIZE_URL}?{urlencode(params)}"
|
||||
response = RedirectResponse(url=url)
|
||||
_set_state_cookie(response, "google", state)
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/google/callback")
|
||||
async def google_callback(
|
||||
code: str,
|
||||
state: str | None = Query(default=None),
|
||||
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Handle Google OAuth callback."""
|
||||
_validate_state(state, state_cookie, "google")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_resp = await client.post(
|
||||
GOOGLE_TOKEN_URL,
|
||||
data={
|
||||
"client_id": settings.google_client_id,
|
||||
"client_secret": settings.google_client_secret,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": f"{settings.base_url}/auth/google/callback",
|
||||
},
|
||||
)
|
||||
token_resp.raise_for_status()
|
||||
token_data = token_resp.json()
|
||||
access_token = token_data.get("access_token")
|
||||
if not access_token:
|
||||
raise HTTPException(status_code=502, detail="Google login failed")
|
||||
|
||||
user_resp = await client.get(
|
||||
GOOGLE_USER_URL,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
user_resp.raise_for_status()
|
||||
user_data = user_resp.json()
|
||||
except httpx.HTTPStatusError:
|
||||
raise HTTPException(status_code=502, detail="Google login failed")
|
||||
|
||||
google_id = user_data.get("id")
|
||||
if google_id is None:
|
||||
raise HTTPException(status_code=502, detail="Google login failed")
|
||||
|
||||
return await _handle_oauth_user(
|
||||
db=db,
|
||||
provider="google",
|
||||
user_id=str(google_id),
|
||||
name=user_data.get("name") or user_data.get("email") or "Google User",
|
||||
avatar_url=user_data.get("picture"),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_oauth_user(
|
||||
db: AsyncSession,
|
||||
provider: str,
|
||||
user_id: str,
|
||||
name: str,
|
||||
avatar_url: str | None,
|
||||
) -> RedirectResponse:
|
||||
"""Create/update user and issue session cookie."""
|
||||
full_id = f"{provider}:{user_id}"
|
||||
|
||||
result = await db.execute(select(User).where(User.id == full_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
user = User(
|
||||
id=full_id,
|
||||
name=name,
|
||||
avatar_url=avatar_url,
|
||||
provider=provider,
|
||||
)
|
||||
db.add(user)
|
||||
else:
|
||||
user.name = name
|
||||
user.avatar_url = avatar_url
|
||||
|
||||
await db.commit()
|
||||
|
||||
token = create_access_token({"sub": user.id})
|
||||
|
||||
frontend_url = "http://localhost:5173"
|
||||
if settings.cors_origins and len(settings.cors_origins) > 0:
|
||||
frontend_url = settings.cors_origins[0]
|
||||
|
||||
response = RedirectResponse(url=f"{frontend_url}/my-stories", status_code=302)
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=not settings.debug,
|
||||
samesite="lax",
|
||||
max_age=60 * 60 * 24 * 7, # align with ACCESS_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
response.delete_cookie(STATE_COOKIE)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/signout")
|
||||
async def signout():
|
||||
"""Sign out and clear cookies."""
|
||||
response = RedirectResponse(url=settings.cors_origins[0], status_code=302)
|
||||
response.delete_cookie("access_token", samesite="lax", secure=not settings.debug)
|
||||
response.delete_cookie(STATE_COOKIE, samesite="lax", secure=not settings.debug)
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/session")
|
||||
async def get_session(user: User | None = Depends(get_current_user)):
|
||||
"""Fetch current session info."""
|
||||
if not user:
|
||||
return {"user": None}
|
||||
return {
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"avatar_url": user.avatar_url,
|
||||
"provider": user.provider,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/dev/signin")
|
||||
async def dev_signin(db: AsyncSession = Depends(get_db)):
|
||||
"""Developer backdoor login. Only works in DEBUG mode."""
|
||||
# if not settings.debug:
|
||||
# raise HTTPException(status_code=403, detail="Developer login disabled")
|
||||
|
||||
try:
|
||||
return await _handle_oauth_user(
|
||||
db=db,
|
||||
provider="github",
|
||||
user_id="dev_user_001",
|
||||
name="Developer",
|
||||
avatar_url="https://api.dicebear.com/7.x/avataaars/svg?seed=Developer"
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"Dev login failed: {str(e)}")
|
||||
268
backend/app/api/memories.py
Normal file
268
backend/app/api/memories.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Memory management APIs."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.db.database import get_db
|
||||
from app.db.models import ChildProfile, User
|
||||
from app.services import memory_service
|
||||
from app.services.memory_service import MemoryType
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class MemoryItemResponse(BaseModel):
|
||||
"""Memory item response."""
|
||||
|
||||
id: str
|
||||
type: str
|
||||
value: dict
|
||||
base_weight: float
|
||||
ttl_days: int | None
|
||||
created_at: str
|
||||
last_used_at: str | None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class MemoryListResponse(BaseModel):
|
||||
"""Memory list response."""
|
||||
|
||||
memories: list[MemoryItemResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class CreateMemoryRequest(BaseModel):
|
||||
"""Create memory request."""
|
||||
|
||||
type: str = Field(..., description="记忆类型")
|
||||
value: dict = Field(..., description="记忆内容")
|
||||
universe_id: str | None = Field(default=None, description="关联的故事宇宙 ID")
|
||||
weight: float | None = Field(default=None, description="权重")
|
||||
ttl_days: int | None = Field(default=None, description="过期天数")
|
||||
|
||||
|
||||
class CreateCharacterMemoryRequest(BaseModel):
|
||||
"""Create character memory request."""
|
||||
|
||||
name: str = Field(..., description="角色名称")
|
||||
description: str | None = Field(default=None, description="角色描述")
|
||||
source_story_id: int | None = Field(default=None, description="来源故事 ID")
|
||||
affinity_score: float = Field(default=1.0, ge=0.0, le=1.0, description="喜爱程度")
|
||||
universe_id: str | None = Field(default=None, description="关联的故事宇宙 ID")
|
||||
|
||||
|
||||
class CreateScaryElementRequest(BaseModel):
|
||||
"""Create scary element memory request."""
|
||||
|
||||
keyword: str = Field(..., description="回避的关键词")
|
||||
category: str = Field(default="other", description="分类")
|
||||
source_story_id: int | None = Field(default=None, description="来源故事 ID")
|
||||
|
||||
|
||||
async def _verify_profile_ownership(
|
||||
profile_id: str, user: User, db: AsyncSession
|
||||
) -> ChildProfile:
|
||||
"""验证档案所有权。"""
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="档案不存在")
|
||||
return profile
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/memories", response_model=MemoryListResponse)
|
||||
async def list_memories(
|
||||
profile_id: str,
|
||||
memory_type: str | None = None,
|
||||
universe_id: str | None = None,
|
||||
limit: int = 50,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取档案的记忆列表。"""
|
||||
await _verify_profile_ownership(profile_id, user, db)
|
||||
|
||||
memories = await memory_service.get_profile_memories(
|
||||
db=db,
|
||||
profile_id=profile_id,
|
||||
memory_type=memory_type,
|
||||
universe_id=universe_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return MemoryListResponse(
|
||||
memories=[
|
||||
MemoryItemResponse(
|
||||
id=m.id,
|
||||
type=m.type,
|
||||
value=m.value,
|
||||
base_weight=m.base_weight,
|
||||
ttl_days=m.ttl_days,
|
||||
created_at=m.created_at.isoformat() if m.created_at else "",
|
||||
last_used_at=m.last_used_at.isoformat() if m.last_used_at else None,
|
||||
)
|
||||
for m in memories
|
||||
],
|
||||
total=len(memories),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/profiles/{profile_id}/memories", response_model=MemoryItemResponse)
|
||||
async def create_memory(
|
||||
profile_id: str,
|
||||
payload: CreateMemoryRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建新的记忆项。"""
|
||||
await _verify_profile_ownership(profile_id, user, db)
|
||||
|
||||
# 验证类型
|
||||
valid_types = [
|
||||
MemoryType.RECENT_STORY,
|
||||
MemoryType.FAVORITE_CHARACTER,
|
||||
MemoryType.SCARY_ELEMENT,
|
||||
MemoryType.VOCABULARY_GROWTH,
|
||||
MemoryType.EMOTIONAL_HIGHLIGHT,
|
||||
MemoryType.READING_PREFERENCE,
|
||||
MemoryType.MILESTONE,
|
||||
MemoryType.SKILL_MASTERED,
|
||||
]
|
||||
if payload.type not in valid_types:
|
||||
raise HTTPException(status_code=400, detail=f"无效的记忆类型: {payload.type}")
|
||||
|
||||
memory = await memory_service.create_memory(
|
||||
db=db,
|
||||
profile_id=profile_id,
|
||||
memory_type=payload.type,
|
||||
value=payload.value,
|
||||
universe_id=payload.universe_id,
|
||||
weight=payload.weight,
|
||||
ttl_days=payload.ttl_days,
|
||||
)
|
||||
|
||||
return MemoryItemResponse(
|
||||
id=memory.id,
|
||||
type=memory.type,
|
||||
value=memory.value,
|
||||
base_weight=memory.base_weight,
|
||||
ttl_days=memory.ttl_days,
|
||||
created_at=memory.created_at.isoformat() if memory.created_at else "",
|
||||
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/profiles/{profile_id}/memories/character", response_model=MemoryItemResponse)
|
||||
async def create_character_memory(
|
||||
profile_id: str,
|
||||
payload: CreateCharacterMemoryRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""添加喜欢的角色。"""
|
||||
await _verify_profile_ownership(profile_id, user, db)
|
||||
|
||||
memory = await memory_service.create_character_memory(
|
||||
db=db,
|
||||
profile_id=profile_id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
source_story_id=payload.source_story_id,
|
||||
affinity_score=payload.affinity_score,
|
||||
universe_id=payload.universe_id,
|
||||
)
|
||||
|
||||
return MemoryItemResponse(
|
||||
id=memory.id,
|
||||
type=memory.type,
|
||||
value=memory.value,
|
||||
base_weight=memory.base_weight,
|
||||
ttl_days=memory.ttl_days,
|
||||
created_at=memory.created_at.isoformat() if memory.created_at else "",
|
||||
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/profiles/{profile_id}/memories/scary", response_model=MemoryItemResponse)
|
||||
async def create_scary_element_memory(
|
||||
profile_id: str,
|
||||
payload: CreateScaryElementRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""添加回避元素。"""
|
||||
await _verify_profile_ownership(profile_id, user, db)
|
||||
|
||||
memory = await memory_service.create_scary_element_memory(
|
||||
db=db,
|
||||
profile_id=profile_id,
|
||||
keyword=payload.keyword,
|
||||
category=payload.category,
|
||||
source_story_id=payload.source_story_id,
|
||||
)
|
||||
|
||||
return MemoryItemResponse(
|
||||
id=memory.id,
|
||||
type=memory.type,
|
||||
value=memory.value,
|
||||
base_weight=memory.base_weight,
|
||||
ttl_days=memory.ttl_days,
|
||||
created_at=memory.created_at.isoformat() if memory.created_at else "",
|
||||
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/profiles/{profile_id}/memories/{memory_id}")
|
||||
async def delete_memory(
|
||||
profile_id: str,
|
||||
memory_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除记忆项。"""
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db.models import MemoryItem
|
||||
|
||||
await _verify_profile_ownership(profile_id, user, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(MemoryItem).where(
|
||||
MemoryItem.id == memory_id,
|
||||
MemoryItem.child_profile_id == profile_id,
|
||||
)
|
||||
)
|
||||
memory = result.scalar_one_or_none()
|
||||
|
||||
if not memory:
|
||||
raise HTTPException(status_code=404, detail="记忆不存在")
|
||||
|
||||
await db.delete(memory)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Deleted"}
|
||||
|
||||
|
||||
@router.get("/memory-types")
|
||||
async def list_memory_types():
|
||||
"""获取所有可用的记忆类型及其配置。"""
|
||||
types = []
|
||||
for type_name, config in MemoryType.CONFIG.items():
|
||||
types.append({
|
||||
"type": type_name,
|
||||
"default_weight": config[0],
|
||||
"default_ttl_days": config[1],
|
||||
"description": config[2],
|
||||
})
|
||||
return {"types": types}
|
||||
280
backend/app/api/profiles.py
Normal file
280
backend/app/api/profiles.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""Child profile APIs."""
|
||||
|
||||
from datetime import date
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.db.database import get_db
|
||||
from app.db.models import ChildProfile, Story, StoryUniverse, User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
MAX_PROFILES_PER_USER = 5
|
||||
|
||||
|
||||
class ChildProfileCreate(BaseModel):
|
||||
"""Create profile payload."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=50)
|
||||
birth_date: date | None = None
|
||||
gender: str | None = Field(default=None, pattern="^(male|female|other)$")
|
||||
interests: list[str] = Field(default_factory=list)
|
||||
growth_themes: list[str] = Field(default_factory=list)
|
||||
avatar_url: str | None = None
|
||||
|
||||
|
||||
class ChildProfileUpdate(BaseModel):
|
||||
"""Update profile payload."""
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=50)
|
||||
birth_date: date | None = None
|
||||
gender: str | None = Field(default=None, pattern="^(male|female|other)$")
|
||||
interests: list[str] | None = None
|
||||
growth_themes: list[str] | None = None
|
||||
avatar_url: str | None = None
|
||||
|
||||
|
||||
class ChildProfileResponse(BaseModel):
|
||||
"""Profile response."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
avatar_url: str | None
|
||||
birth_date: date | None
|
||||
gender: str | None
|
||||
age: int | None
|
||||
interests: list[str]
|
||||
growth_themes: list[str]
|
||||
stories_count: int
|
||||
total_reading_time: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ChildProfileListResponse(BaseModel):
|
||||
"""Profile list response."""
|
||||
|
||||
profiles: list[ChildProfileResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class TimelineEvent(BaseModel):
|
||||
"""Timeline event item."""
|
||||
|
||||
date: str
|
||||
type: Literal["story", "achievement", "milestone"]
|
||||
title: str
|
||||
description: str | None = None
|
||||
image_url: str | None = None
|
||||
metadata: dict | None = None
|
||||
|
||||
|
||||
class TimelineResponse(BaseModel):
|
||||
"""Timeline response."""
|
||||
|
||||
events: list[TimelineEvent]
|
||||
|
||||
|
||||
@router.get("/profiles", response_model=ChildProfileListResponse)
|
||||
async def list_profiles(
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List child profiles for current user."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile)
|
||||
.where(ChildProfile.user_id == user.id)
|
||||
.order_by(ChildProfile.created_at.desc())
|
||||
)
|
||||
profiles = result.scalars().all()
|
||||
|
||||
return ChildProfileListResponse(profiles=profiles, total=len(profiles))
|
||||
|
||||
|
||||
@router.post("/profiles", response_model=ChildProfileResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_profile(
|
||||
payload: ChildProfileCreate,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a new child profile."""
|
||||
count = await db.scalar(
|
||||
select(func.count(ChildProfile.id)).where(ChildProfile.user_id == user.id)
|
||||
)
|
||||
if count and count >= MAX_PROFILES_PER_USER:
|
||||
raise HTTPException(status_code=400, detail="最多只能创建 5 个孩子档案")
|
||||
|
||||
existing = await db.scalar(
|
||||
select(ChildProfile.id).where(
|
||||
ChildProfile.user_id == user.id,
|
||||
ChildProfile.name == payload.name,
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="该档案名称已存在")
|
||||
|
||||
profile = ChildProfile(user_id=user.id, **payload.model_dump())
|
||||
db.add(profile)
|
||||
await db.commit()
|
||||
await db.refresh(profile)
|
||||
|
||||
return profile
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}", response_model=ChildProfileResponse)
|
||||
async def get_profile(
|
||||
profile_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get one child profile."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="档案不存在")
|
||||
|
||||
return profile
|
||||
|
||||
|
||||
@router.put("/profiles/{profile_id}", response_model=ChildProfileResponse)
|
||||
async def update_profile(
|
||||
profile_id: str,
|
||||
payload: ChildProfileUpdate,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update a child profile."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="档案不存在")
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
if "name" in updates:
|
||||
existing = await db.scalar(
|
||||
select(ChildProfile.id).where(
|
||||
ChildProfile.user_id == user.id,
|
||||
ChildProfile.name == updates["name"],
|
||||
ChildProfile.id != profile_id,
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="该档案名称已存在")
|
||||
|
||||
for key, value in updates.items():
|
||||
setattr(profile, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(profile)
|
||||
|
||||
return profile
|
||||
|
||||
|
||||
@router.delete("/profiles/{profile_id}")
|
||||
async def delete_profile(
|
||||
profile_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Delete a child profile."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="档案不存在")
|
||||
|
||||
await db.delete(profile)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Deleted"}
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/timeline", response_model=TimelineResponse)
|
||||
async def get_profile_timeline(
|
||||
profile_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get profile growth timeline."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="档案不存在")
|
||||
|
||||
events: list[TimelineEvent] = []
|
||||
|
||||
# 1. Milestone: Profile Created
|
||||
events.append(TimelineEvent(
|
||||
date=profile.created_at.isoformat(),
|
||||
type="milestone",
|
||||
title="初次相遇",
|
||||
description=f"创建了档案 {profile.name}"
|
||||
))
|
||||
|
||||
# 2. Stories
|
||||
stories_result = await db.execute(
|
||||
select(Story).where(Story.child_profile_id == profile_id)
|
||||
)
|
||||
for s in stories_result.scalars():
|
||||
events.append(TimelineEvent(
|
||||
date=s.created_at.isoformat(),
|
||||
type="story",
|
||||
title=s.title,
|
||||
image_url=s.image_url,
|
||||
metadata={"story_id": s.id, "mode": s.mode}
|
||||
))
|
||||
|
||||
# 3. Achievements (from Universe)
|
||||
universes_result = await db.execute(
|
||||
select(StoryUniverse).where(StoryUniverse.child_profile_id == profile_id)
|
||||
)
|
||||
for u in universes_result.scalars():
|
||||
if u.achievements:
|
||||
for ach in u.achievements:
|
||||
if isinstance(ach, dict):
|
||||
obt_at = ach.get("obtained_at")
|
||||
# Fallback
|
||||
if not obt_at:
|
||||
obt_at = u.updated_at.isoformat()
|
||||
|
||||
events.append(TimelineEvent(
|
||||
date=obt_at,
|
||||
type="achievement",
|
||||
title=f"获得成就:{ach.get('type')}",
|
||||
description=ach.get('description'),
|
||||
metadata={"universe_id": u.id, "source_story_id": ach.get("source_story_id")}
|
||||
))
|
||||
|
||||
# Sort by date desc
|
||||
events.sort(key=lambda x: x.date, reverse=True)
|
||||
|
||||
return TimelineResponse(events=events)
|
||||
120
backend/app/api/push_configs.py
Normal file
120
backend/app/api/push_configs.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Push configuration APIs."""
|
||||
|
||||
from datetime import time
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.db.database import get_db
|
||||
from app.db.models import ChildProfile, PushConfig, User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class PushConfigUpsert(BaseModel):
|
||||
"""Upsert push config payload."""
|
||||
|
||||
child_profile_id: str
|
||||
push_time: time | None = None
|
||||
push_days: list[int] | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class PushConfigResponse(BaseModel):
|
||||
"""Push config response."""
|
||||
|
||||
id: str
|
||||
child_profile_id: str
|
||||
push_time: time | None
|
||||
push_days: list[int]
|
||||
enabled: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PushConfigListResponse(BaseModel):
|
||||
"""Push config list response."""
|
||||
|
||||
configs: list[PushConfigResponse]
|
||||
total: int
|
||||
|
||||
|
||||
def _validate_push_days(push_days: list[int]) -> list[int]:
|
||||
invalid = [day for day in push_days if day < 0 or day > 6]
|
||||
if invalid:
|
||||
raise HTTPException(status_code=400, detail="推送日期必须在 0-6 之间")
|
||||
return list(dict.fromkeys(push_days))
|
||||
|
||||
|
||||
@router.get("/push-configs", response_model=PushConfigListResponse)
|
||||
async def list_push_configs(
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List push configs for current user."""
|
||||
result = await db.execute(
|
||||
select(PushConfig).where(PushConfig.user_id == user.id)
|
||||
)
|
||||
configs = result.scalars().all()
|
||||
return PushConfigListResponse(configs=configs, total=len(configs))
|
||||
|
||||
|
||||
@router.put("/push-configs", response_model=PushConfigResponse)
|
||||
async def upsert_push_config(
|
||||
payload: PushConfigUpsert,
|
||||
response: Response,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create or update push config for a child profile."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == payload.child_profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
||||
|
||||
result = await db.execute(
|
||||
select(PushConfig).where(PushConfig.child_profile_id == payload.child_profile_id)
|
||||
)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config is None:
|
||||
if payload.push_time is None or payload.push_days is None:
|
||||
raise HTTPException(status_code=400, detail="创建配置需要提供推送时间和日期")
|
||||
push_days = _validate_push_days(payload.push_days)
|
||||
config = PushConfig(
|
||||
user_id=user.id,
|
||||
child_profile_id=payload.child_profile_id,
|
||||
push_time=payload.push_time,
|
||||
push_days=push_days,
|
||||
enabled=True if payload.enabled is None else payload.enabled,
|
||||
)
|
||||
db.add(config)
|
||||
await db.commit()
|
||||
await db.refresh(config)
|
||||
response.status_code = status.HTTP_201_CREATED
|
||||
return config
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
if "push_days" in updates and updates["push_days"] is not None:
|
||||
updates["push_days"] = _validate_push_days(updates["push_days"])
|
||||
if "push_time" in updates and updates["push_time"] is None:
|
||||
raise HTTPException(status_code=400, detail="推送时间不能为空")
|
||||
|
||||
for key, value in updates.items():
|
||||
if key == "child_profile_id":
|
||||
continue
|
||||
if value is not None:
|
||||
setattr(config, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(config)
|
||||
return config
|
||||
120
backend/app/api/reading_events.py
Normal file
120
backend/app/api/reading_events.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Reading event APIs."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.db.database import get_db
|
||||
from app.db.models import ChildProfile, MemoryItem, ReadingEvent, Story, User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
EVENT_WEIGHTS: dict[str, float] = {
|
||||
"completed": 1.0,
|
||||
"replayed": 1.5,
|
||||
"started": 0.1,
|
||||
"skipped": -0.5,
|
||||
}
|
||||
|
||||
|
||||
class ReadingEventCreate(BaseModel):
|
||||
"""Reading event payload."""
|
||||
|
||||
child_profile_id: str
|
||||
story_id: int | None = None
|
||||
event_type: Literal["started", "completed", "skipped", "replayed"]
|
||||
reading_time: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class ReadingEventResponse(BaseModel):
|
||||
"""Reading event response."""
|
||||
|
||||
id: int
|
||||
child_profile_id: str
|
||||
story_id: int | None
|
||||
event_type: str
|
||||
reading_time: int
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
@router.post("/reading-events", response_model=ReadingEventResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_reading_event(
|
||||
payload: ReadingEventCreate,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a reading event and update profile stats/memory."""
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == payload.child_profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
||||
|
||||
story = None
|
||||
if payload.story_id is not None:
|
||||
result = await db.execute(
|
||||
select(Story).where(
|
||||
Story.id == payload.story_id,
|
||||
Story.user_id == user.id,
|
||||
)
|
||||
)
|
||||
story = result.scalar_one_or_none()
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="故事不存在")
|
||||
|
||||
if payload.reading_time:
|
||||
profile.total_reading_time = (profile.total_reading_time or 0) + payload.reading_time
|
||||
|
||||
if payload.event_type in {"completed", "replayed"} and payload.story_id is not None:
|
||||
existing = await db.scalar(
|
||||
select(ReadingEvent.id).where(
|
||||
ReadingEvent.child_profile_id == payload.child_profile_id,
|
||||
ReadingEvent.story_id == payload.story_id,
|
||||
ReadingEvent.event_type.in_(["completed", "replayed"]),
|
||||
)
|
||||
)
|
||||
if existing is None:
|
||||
profile.stories_count = (profile.stories_count or 0) + 1
|
||||
|
||||
event = ReadingEvent(
|
||||
child_profile_id=payload.child_profile_id,
|
||||
story_id=payload.story_id,
|
||||
event_type=payload.event_type,
|
||||
reading_time=payload.reading_time,
|
||||
)
|
||||
db.add(event)
|
||||
|
||||
weight = EVENT_WEIGHTS.get(payload.event_type, 0.0)
|
||||
if story and weight > 0:
|
||||
db.add(
|
||||
MemoryItem(
|
||||
child_profile_id=payload.child_profile_id,
|
||||
universe_id=story.universe_id,
|
||||
type="recent_story",
|
||||
value={
|
||||
"story_id": story.id,
|
||||
"title": story.title,
|
||||
"event_type": payload.event_type,
|
||||
},
|
||||
base_weight=weight,
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
ttl_days=90,
|
||||
)
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(event)
|
||||
|
||||
return event
|
||||
605
backend/app/api/stories.py
Normal file
605
backend/app/api/stories.py
Normal file
@@ -0,0 +1,605 @@
|
||||
"""Story related APIs."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Literal
|
||||
|
||||
from cachetools import TTLCache
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.core.logging import get_logger
|
||||
from app.db.database import get_db
|
||||
from app.db.models import ChildProfile, Story, StoryUniverse, User
|
||||
from app.services.provider_router import (
|
||||
generate_image,
|
||||
generate_story_content,
|
||||
generate_storybook,
|
||||
text_to_speech,
|
||||
)
|
||||
from app.tasks.achievements import extract_story_achievements
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
MAX_DATA_LENGTH = 2000
|
||||
MAX_EDU_THEME_LENGTH = 200
|
||||
MAX_TTS_LENGTH = 4000
|
||||
|
||||
RATE_LIMIT_WINDOW = 60 # seconds
|
||||
RATE_LIMIT_REQUESTS = 10
|
||||
RATE_LIMIT_CACHE_SIZE = 10000 # 最大跟踪用户数
|
||||
|
||||
_request_log: TTLCache[str, list[float]] = TTLCache(
|
||||
maxsize=RATE_LIMIT_CACHE_SIZE, ttl=RATE_LIMIT_WINDOW * 2
|
||||
)
|
||||
|
||||
|
||||
def _check_rate_limit(user_id: str):
|
||||
now = time.time()
|
||||
timestamps = _request_log.get(user_id, [])
|
||||
timestamps = [t for t in timestamps if now - t <= RATE_LIMIT_WINDOW]
|
||||
if len(timestamps) >= RATE_LIMIT_REQUESTS:
|
||||
raise HTTPException(status_code=429, detail="Too many requests, please slow down.")
|
||||
timestamps.append(now)
|
||||
_request_log[user_id] = timestamps
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
"""Story generation request."""
|
||||
|
||||
type: Literal["keywords", "full_story"]
|
||||
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
|
||||
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
|
||||
child_profile_id: str | None = None
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class StoryResponse(BaseModel):
|
||||
"""Story response."""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
story_text: str
|
||||
cover_prompt: str | None
|
||||
image_url: str | None
|
||||
mode: str
|
||||
child_profile_id: str | None = None
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class StoryListItem(BaseModel):
|
||||
"""Story list item."""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
image_url: str | None
|
||||
created_at: str
|
||||
mode: str
|
||||
|
||||
|
||||
class FullStoryResponse(BaseModel):
|
||||
"""完整故事响应(含图片和音频状态)。"""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
story_text: str
|
||||
cover_prompt: str | None
|
||||
image_url: str | None
|
||||
audio_ready: bool
|
||||
mode: str
|
||||
errors: dict[str, str | None] = Field(default_factory=dict)
|
||||
child_profile_id: str | None = None
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
from app.services.memory_service import build_enhanced_memory_context
|
||||
|
||||
|
||||
async def _validate_profile_and_universe(
|
||||
request: GenerateRequest,
|
||||
user: User,
|
||||
db: AsyncSession,
|
||||
) -> tuple[str | None, str | None]:
|
||||
if not request.child_profile_id and not request.universe_id:
|
||||
return None, None
|
||||
|
||||
profile_id = request.child_profile_id
|
||||
universe_id = request.universe_id
|
||||
|
||||
if profile_id:
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
||||
|
||||
if universe_id:
|
||||
result = await db.execute(
|
||||
select(StoryUniverse)
|
||||
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
|
||||
.where(
|
||||
StoryUniverse.id == universe_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
universe = result.scalar_one_or_none()
|
||||
if not universe:
|
||||
raise HTTPException(status_code=404, detail="故事宇宙不存在")
|
||||
if profile_id and universe.child_profile_id != profile_id:
|
||||
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
|
||||
if not profile_id:
|
||||
profile_id = universe.child_profile_id
|
||||
|
||||
return profile_id, universe_id
|
||||
|
||||
|
||||
@router.post("/stories/generate", response_model=StoryResponse)
|
||||
async def generate_story(
|
||||
request: GenerateRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Generate or enhance a story."""
|
||||
_check_rate_limit(user.id)
|
||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
try:
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
|
||||
|
||||
story = Story(
|
||||
user_id=user.id,
|
||||
child_profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
title=result.title,
|
||||
story_text=result.story_text,
|
||||
cover_prompt=result.cover_prompt_suggestion,
|
||||
mode=result.mode,
|
||||
)
|
||||
db.add(story)
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
return StoryResponse(
|
||||
id=story.id,
|
||||
title=story.title,
|
||||
story_text=story.story_text,
|
||||
cover_prompt=story.cover_prompt,
|
||||
image_url=story.image_url,
|
||||
mode=story.mode,
|
||||
child_profile_id=story.child_profile_id,
|
||||
universe_id=story.universe_id,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stories/generate/full", response_model=FullStoryResponse)
|
||||
async def generate_story_full(
|
||||
request: GenerateRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""生成完整故事(故事 + 并行生成图片和音频)。
|
||||
|
||||
部分成功策略:故事必须成功,图片/音频失败不影响整体。
|
||||
"""
|
||||
_check_rate_limit(user.id)
|
||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
# Step 1: 故事生成(必须成功)
|
||||
try:
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("story_generation_failed", error=str(exc))
|
||||
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
|
||||
|
||||
# 保存故事
|
||||
story = Story(
|
||||
user_id=user.id,
|
||||
child_profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
title=result.title,
|
||||
story_text=result.story_text,
|
||||
cover_prompt=result.cover_prompt_suggestion,
|
||||
mode=result.mode,
|
||||
)
|
||||
db.add(story)
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
# Step 2: 生成封面图片(音频按需生成,避免浪费)
|
||||
errors: dict[str, str | None] = {}
|
||||
image_url: str | None = None
|
||||
|
||||
if story.cover_prompt:
|
||||
try:
|
||||
image_url = await generate_image(story.cover_prompt)
|
||||
story.image_url = image_url
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
errors["image"] = str(exc)
|
||||
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
|
||||
|
||||
# 注意:音频不在此处预生成,用户通过 /api/audio/{id} 按需获取
|
||||
# 这样避免生成后丢弃造成的成本浪费
|
||||
|
||||
return FullStoryResponse(
|
||||
id=story.id,
|
||||
title=story.title,
|
||||
story_text=story.story_text,
|
||||
cover_prompt=story.cover_prompt,
|
||||
image_url=image_url,
|
||||
audio_ready=False, # 音频需要用户主动请求
|
||||
mode=story.mode,
|
||||
errors=errors,
|
||||
child_profile_id=story.child_profile_id,
|
||||
universe_id=story.universe_id,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stories/generate/stream")
|
||||
async def generate_story_stream(
|
||||
request: GenerateRequest,
|
||||
req: Request,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""流式生成故事(SSE)。
|
||||
|
||||
事件流程:
|
||||
- started: 返回 story_id
|
||||
- story_ready: 返回 title, content
|
||||
- story_failed: 返回 error
|
||||
- image_ready: 返回 image_url
|
||||
- image_failed: 返回 error
|
||||
- complete: 结束流
|
||||
"""
|
||||
_check_rate_limit(user.id)
|
||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[dict, None]:
|
||||
story_id = str(uuid.uuid4())
|
||||
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
|
||||
|
||||
# Step 1: 生成故事
|
||||
try:
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("sse_story_generation_failed", error=str(e))
|
||||
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
|
||||
return
|
||||
|
||||
# 保存故事
|
||||
story = Story(
|
||||
user_id=user.id,
|
||||
child_profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
title=result.title,
|
||||
story_text=result.story_text,
|
||||
cover_prompt=result.cover_prompt_suggestion,
|
||||
mode=result.mode,
|
||||
)
|
||||
db.add(story)
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
yield {
|
||||
"event": "story_ready",
|
||||
"data": json.dumps({
|
||||
"id": story.id,
|
||||
"title": story.title,
|
||||
"content": story.story_text,
|
||||
"cover_prompt": story.cover_prompt,
|
||||
"mode": story.mode,
|
||||
"child_profile_id": story.child_profile_id,
|
||||
"universe_id": story.universe_id,
|
||||
}),
|
||||
}
|
||||
|
||||
# Step 2: 并行生成图片(音频按需)
|
||||
if story.cover_prompt:
|
||||
try:
|
||||
image_url = await generate_image(story.cover_prompt)
|
||||
story.image_url = image_url
|
||||
await db.commit()
|
||||
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
|
||||
except Exception as e:
|
||||
logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e))
|
||||
yield {"event": "image_failed", "data": json.dumps({"error": str(e)})}
|
||||
|
||||
yield {"event": "complete", "data": json.dumps({"story_id": story.id})}
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
|
||||
# ==================== Storybook API ====================
|
||||
|
||||
|
||||
class StorybookRequest(BaseModel):
|
||||
"""Storybook 生成请求。"""
|
||||
|
||||
keywords: str = Field(..., min_length=1, max_length=200)
|
||||
page_count: int = Field(default=6, ge=4, le=12)
|
||||
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
|
||||
generate_images: bool = Field(default=False, description="是否同时生成插图")
|
||||
child_profile_id: str | None = None
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class StorybookPageResponse(BaseModel):
|
||||
"""故事书单页响应。"""
|
||||
|
||||
page_number: int
|
||||
text: str
|
||||
image_prompt: str
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
class StorybookResponse(BaseModel):
|
||||
"""故事书响应。"""
|
||||
|
||||
id: int | None = None
|
||||
title: str
|
||||
main_character: str
|
||||
art_style: str
|
||||
pages: list[StorybookPageResponse]
|
||||
cover_prompt: str
|
||||
cover_url: str | None = None
|
||||
|
||||
|
||||
@router.post("/storybook/generate", response_model=StorybookResponse)
|
||||
async def generate_storybook_api(
|
||||
request: StorybookRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""生成分页故事书并保存。
|
||||
|
||||
返回故事书结构,包含每页文字和图像提示词。
|
||||
"""
|
||||
_check_rate_limit(user.id)
|
||||
|
||||
# 验证档案和宇宙
|
||||
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数
|
||||
# 这里我们直接手动验证,或重构验证函数。为了简单,手动调用部分逻辑。
|
||||
|
||||
# 构建临时的 GenerateRequest 用于验证验证函数签名(或者直接手动查库更好)
|
||||
profile_id = request.child_profile_id
|
||||
universe_id = request.universe_id
|
||||
|
||||
if profile_id:
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
if not result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
||||
|
||||
if universe_id:
|
||||
result = await db.execute(
|
||||
select(StoryUniverse)
|
||||
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
|
||||
.where(
|
||||
StoryUniverse.id == universe_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
universe = result.scalar_one_or_none()
|
||||
if not universe:
|
||||
raise HTTPException(status_code=404, detail="故事宇宙不存在")
|
||||
if profile_id and universe.child_profile_id != profile_id:
|
||||
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
|
||||
if not profile_id:
|
||||
profile_id = universe.child_profile_id
|
||||
|
||||
logger.info(
|
||||
"storybook_request",
|
||||
user_id=user.id,
|
||||
keywords=request.keywords,
|
||||
page_count=request.page_count,
|
||||
profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
)
|
||||
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
try:
|
||||
# 注意:generate_storybook 目前可能不支持记忆上下文注入
|
||||
# 我们需要看看 generate_storybook 的签名
|
||||
# 如果不支持,记忆功能在绘本模式下暂不可用,但基本参数传递是支持的
|
||||
storybook = await generate_storybook(
|
||||
keywords=request.keywords,
|
||||
page_count=request.page_count,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("storybook_generation_failed", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
|
||||
|
||||
# ==============================================================================
|
||||
# 核心升级: 并行全量生成 (Parallel Full Rendering)
|
||||
# ==============================================================================
|
||||
final_cover_url = storybook.cover_url
|
||||
|
||||
if request.generate_images:
|
||||
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
|
||||
|
||||
# 1. 准备所有生图任务 (封面 + 所有内页)
|
||||
tasks = []
|
||||
|
||||
# 封面任务
|
||||
async def _gen_cover():
|
||||
if storybook.cover_prompt and not storybook.cover_url:
|
||||
try:
|
||||
return await generate_image(storybook.cover_prompt, db=db)
|
||||
except Exception as e:
|
||||
logger.warning("cover_gen_failed", error=str(e))
|
||||
return storybook.cover_url
|
||||
|
||||
tasks.append(_gen_cover())
|
||||
|
||||
# 内页任务
|
||||
async def _gen_page(page):
|
||||
if page.image_prompt and not page.image_url:
|
||||
try:
|
||||
url = await generate_image(page.image_prompt, db=db)
|
||||
page.image_url = url
|
||||
except Exception as e:
|
||||
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
|
||||
|
||||
for page in storybook.pages:
|
||||
tasks.append(_gen_page(page))
|
||||
|
||||
# 2. 并发执行所有任务
|
||||
# 使用 return_exceptions=True 防止单张失败影响整体
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 3. 更新封面结果 (results[0] 是封面任务的返回值)
|
||||
cover_res = results[0]
|
||||
if isinstance(cover_res, str):
|
||||
final_cover_url = cover_res
|
||||
|
||||
logger.info("storybook_parallel_generation_complete")
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
# 构建并保存 Story 对象
|
||||
# 将 pages 对象转换为字典列表以存入 JSON 字段
|
||||
pages_data = [
|
||||
{
|
||||
"page_number": p.page_number,
|
||||
"text": p.text,
|
||||
"image_prompt": p.image_prompt,
|
||||
"image_url": p.image_url,
|
||||
}
|
||||
for p in storybook.pages
|
||||
]
|
||||
|
||||
story = Story(
|
||||
user_id=user.id,
|
||||
child_profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
title=storybook.title,
|
||||
mode="storybook",
|
||||
pages=pages_data, # 存入 JSON 字段
|
||||
story_text=None, # 绘本模式下,主文本可为空,或者可以存个摘要
|
||||
cover_prompt=storybook.cover_prompt,
|
||||
image_url=final_cover_url,
|
||||
)
|
||||
db.add(story)
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
# 构建响应 (使用更新后的 pages_data)
|
||||
response_pages = [
|
||||
StorybookPageResponse(
|
||||
page_number=p["page_number"],
|
||||
text=p["text"],
|
||||
image_prompt=p["image_prompt"],
|
||||
image_url=p.get("image_url"),
|
||||
)
|
||||
for p in pages_data
|
||||
]
|
||||
|
||||
return StorybookResponse(
|
||||
id=story.id,
|
||||
title=storybook.title,
|
||||
main_character=storybook.main_character,
|
||||
art_style=storybook.art_style,
|
||||
pages=response_pages,
|
||||
cover_prompt=storybook.cover_prompt,
|
||||
cover_url=final_cover_url,
|
||||
)
|
||||
|
||||
|
||||
class AchievementItem(BaseModel):
|
||||
type: str
|
||||
description: str
|
||||
obtained_at: str | None = None
|
||||
|
||||
|
||||
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
|
||||
async def get_story_achievements(
|
||||
story_id: int,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get achievements unlocked by a specific story."""
|
||||
# 使用 joinedload 避免 N+1 查询
|
||||
result = await db.execute(
|
||||
select(Story)
|
||||
.options(joinedload(Story.story_universe))
|
||||
.where(Story.id == story_id, Story.user_id == user.id)
|
||||
)
|
||||
story = result.scalar_one_or_none()
|
||||
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
|
||||
if not story.universe_id or not story.story_universe:
|
||||
return []
|
||||
|
||||
universe = story.story_universe
|
||||
if not universe.achievements:
|
||||
return []
|
||||
|
||||
results = []
|
||||
for ach in universe.achievements:
|
||||
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
|
||||
results.append(AchievementItem(
|
||||
type=ach.get("type", "Unknown"),
|
||||
description=ach.get("description", ""),
|
||||
obtained_at=ach.get("obtained_at")
|
||||
))
|
||||
|
||||
return results
|
||||
201
backend/app/api/universes.py
Normal file
201
backend/app/api/universes.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Story universe APIs."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.db.database import get_db
|
||||
from app.db.models import ChildProfile, StoryUniverse, User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class StoryUniverseCreate(BaseModel):
|
||||
"""Create universe payload."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
protagonist: dict[str, Any]
|
||||
recurring_characters: list[dict[str, Any]] = Field(default_factory=list)
|
||||
world_settings: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class StoryUniverseUpdate(BaseModel):
|
||||
"""Update universe payload."""
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=100)
|
||||
protagonist: dict[str, Any] | None = None
|
||||
recurring_characters: list[dict[str, Any]] | None = None
|
||||
world_settings: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AchievementCreate(BaseModel):
|
||||
"""Achievement payload."""
|
||||
|
||||
type: str = Field(..., min_length=1, max_length=50)
|
||||
description: str = Field(..., min_length=1, max_length=200)
|
||||
|
||||
|
||||
class StoryUniverseResponse(BaseModel):
|
||||
"""Universe response."""
|
||||
|
||||
id: str
|
||||
child_profile_id: str
|
||||
name: str
|
||||
protagonist: dict[str, Any]
|
||||
recurring_characters: list[dict[str, Any]]
|
||||
world_settings: dict[str, Any]
|
||||
achievements: list[dict[str, Any]]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class StoryUniverseListResponse(BaseModel):
|
||||
"""Universe list response."""
|
||||
|
||||
universes: list[StoryUniverseResponse]
|
||||
total: int
|
||||
|
||||
|
||||
async def _get_profile_or_404(
|
||||
profile_id: str,
|
||||
user: User,
|
||||
db: AsyncSession,
|
||||
) -> ChildProfile:
|
||||
result = await db.execute(
|
||||
select(ChildProfile).where(
|
||||
ChildProfile.id == profile_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
profile = result.scalar_one_or_none()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="档案不存在")
|
||||
return profile
|
||||
|
||||
|
||||
async def _get_universe_or_404(
|
||||
universe_id: str,
|
||||
user: User,
|
||||
db: AsyncSession,
|
||||
) -> StoryUniverse:
|
||||
result = await db.execute(
|
||||
select(StoryUniverse)
|
||||
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
|
||||
.where(
|
||||
StoryUniverse.id == universe_id,
|
||||
ChildProfile.user_id == user.id,
|
||||
)
|
||||
)
|
||||
universe = result.scalar_one_or_none()
|
||||
if not universe:
|
||||
raise HTTPException(status_code=404, detail="宇宙不存在")
|
||||
return universe
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/universes", response_model=StoryUniverseListResponse)
|
||||
async def list_universes(
|
||||
profile_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List universes for a child profile."""
|
||||
await _get_profile_or_404(profile_id, user, db)
|
||||
result = await db.execute(
|
||||
select(StoryUniverse)
|
||||
.where(StoryUniverse.child_profile_id == profile_id)
|
||||
.order_by(StoryUniverse.updated_at.desc())
|
||||
)
|
||||
universes = result.scalars().all()
|
||||
return StoryUniverseListResponse(universes=universes, total=len(universes))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/profiles/{profile_id}/universes",
|
||||
response_model=StoryUniverseResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_universe(
|
||||
profile_id: str,
|
||||
payload: StoryUniverseCreate,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a story universe."""
|
||||
await _get_profile_or_404(profile_id, user, db)
|
||||
universe = StoryUniverse(child_profile_id=profile_id, **payload.model_dump())
|
||||
db.add(universe)
|
||||
await db.commit()
|
||||
await db.refresh(universe)
|
||||
return universe
|
||||
|
||||
|
||||
@router.get("/universes/{universe_id}", response_model=StoryUniverseResponse)
|
||||
async def get_universe(
|
||||
universe_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get one universe."""
|
||||
universe = await _get_universe_or_404(universe_id, user, db)
|
||||
return universe
|
||||
|
||||
|
||||
@router.put("/universes/{universe_id}", response_model=StoryUniverseResponse)
|
||||
async def update_universe(
|
||||
universe_id: str,
|
||||
payload: StoryUniverseUpdate,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update a story universe."""
|
||||
universe = await _get_universe_or_404(universe_id, user, db)
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
for key, value in updates.items():
|
||||
setattr(universe, key, value)
|
||||
await db.commit()
|
||||
await db.refresh(universe)
|
||||
return universe
|
||||
|
||||
|
||||
@router.delete("/universes/{universe_id}")
|
||||
async def delete_universe(
|
||||
universe_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Delete a story universe."""
|
||||
universe = await _get_universe_or_404(universe_id, user, db)
|
||||
await db.delete(universe)
|
||||
await db.commit()
|
||||
return {"message": "Deleted"}
|
||||
|
||||
|
||||
@router.post("/universes/{universe_id}/achievements", response_model=StoryUniverseResponse)
|
||||
async def add_achievement(
|
||||
universe_id: str,
|
||||
payload: AchievementCreate,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Add an achievement to a universe."""
|
||||
universe = await _get_universe_or_404(universe_id, user, db)
|
||||
|
||||
achievements = list(universe.achievements or [])
|
||||
key = (payload.type.strip(), payload.description.strip())
|
||||
existing = {
|
||||
(str(item.get("type", "")).strip(), str(item.get("description", "")).strip())
|
||||
for item in achievements
|
||||
if isinstance(item, dict)
|
||||
}
|
||||
if key not in existing:
|
||||
achievements.append({"type": key[0], "description": key[1]})
|
||||
universe.achievements = achievements
|
||||
await db.commit()
|
||||
await db.refresh(universe)
|
||||
|
||||
return universe
|
||||
0
backend/app/core/__init__.py
Normal file
0
backend/app/core/__init__.py
Normal file
72
backend/app/core/admin_auth.py
Normal file
72
backend/app/core/admin_auth.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import secrets
|
||||
import time
|
||||
|
||||
from cachetools import TTLCache
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
security = HTTPBasic()
|
||||
|
||||
# 登录失败记录:IP -> (失败次数, 首次失败时间)
|
||||
_failed_attempts: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900) # 15分钟
|
||||
|
||||
MAX_ATTEMPTS = 5
|
||||
LOCKOUT_SECONDS = 900 # 15分钟
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
if request.client and request.client.host:
|
||||
return request.client.host
|
||||
return "unknown"
|
||||
|
||||
|
||||
def admin_guard(
|
||||
request: Request,
|
||||
credentials: HTTPBasicCredentials = Depends(security),
|
||||
):
|
||||
client_ip = _get_client_ip(request)
|
||||
|
||||
# 检查是否被锁定
|
||||
if client_ip in _failed_attempts:
|
||||
attempts, first_fail = _failed_attempts[client_ip]
|
||||
if attempts >= MAX_ATTEMPTS:
|
||||
remaining = int(LOCKOUT_SECONDS - (time.time() - first_fail))
|
||||
if remaining > 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"登录尝试过多,请 {remaining} 秒后重试",
|
||||
)
|
||||
else:
|
||||
del _failed_attempts[client_ip]
|
||||
|
||||
# 使用 secrets.compare_digest 防止时序攻击
|
||||
username_ok = secrets.compare_digest(
|
||||
credentials.username.encode(), settings.admin_username.encode()
|
||||
)
|
||||
password_ok = secrets.compare_digest(
|
||||
credentials.password.encode(), settings.admin_password.encode()
|
||||
)
|
||||
|
||||
if not (username_ok and password_ok):
|
||||
# 记录失败
|
||||
if client_ip in _failed_attempts:
|
||||
attempts, first_fail = _failed_attempts[client_ip]
|
||||
_failed_attempts[client_ip] = (attempts + 1, first_fail)
|
||||
else:
|
||||
_failed_attempts[client_ip] = (1, time.time())
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户名或密码错误",
|
||||
)
|
||||
|
||||
# 登录成功,清除失败记录
|
||||
if client_ip in _failed_attempts:
|
||||
del _failed_attempts[client_ip]
|
||||
|
||||
return True
|
||||
33
backend/app/core/celery_app.py
Normal file
33
backend/app/core/celery_app.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Celery application setup."""
|
||||
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
celery_app = Celery(
|
||||
"dreamweaver",
|
||||
broker=settings.celery_broker_url,
|
||||
backend=settings.celery_result_backend,
|
||||
)
|
||||
|
||||
celery_app.conf.update(
|
||||
task_track_started=True,
|
||||
task_serializer="json",
|
||||
accept_content=["json"],
|
||||
result_serializer="json",
|
||||
timezone="Asia/Shanghai",
|
||||
enable_utc=True,
|
||||
beat_schedule={
|
||||
"check_push_notifications": {
|
||||
"task": "app.tasks.push_notifications.check_push_notifications",
|
||||
"schedule": crontab(minute="*/15"),
|
||||
},
|
||||
"prune_expired_memories": {
|
||||
"task": "app.tasks.memory.prune_memories_task",
|
||||
"schedule": crontab(minute="0", hour="3"), # Daily at 03:00
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
celery_app.autodiscover_tasks(["app.tasks"])
|
||||
76
backend/app/core/config.py
Normal file
76
backend/app/core/config.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用全局配置"""
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
|
||||
# 应用基础配置
|
||||
app_name: str = "DreamWeaver"
|
||||
debug: bool = False
|
||||
secret_key: str = Field(..., description="JWT 签名密钥")
|
||||
base_url: str = Field("http://localhost:8000", description="后端对外回调地址")
|
||||
|
||||
# 数据库
|
||||
database_url: str = Field(..., description="SQLAlchemy async URL")
|
||||
|
||||
# OAuth - GitHub
|
||||
github_client_id: str = ""
|
||||
github_client_secret: str = ""
|
||||
|
||||
# OAuth - Google
|
||||
google_client_id: str = ""
|
||||
google_client_secret: str = ""
|
||||
|
||||
# AI Capability Keys
|
||||
text_api_key: str = ""
|
||||
tts_api_base: str = ""
|
||||
tts_api_key: str = ""
|
||||
image_api_key: str = ""
|
||||
|
||||
# Additional Provider API Keys
|
||||
openai_api_key: str = ""
|
||||
elevenlabs_api_key: str = ""
|
||||
cqtai_api_key: str = ""
|
||||
minimax_api_key: str = ""
|
||||
minimax_group_id: str = ""
|
||||
antigravity_api_key: str = ""
|
||||
antigravity_api_base: str = ""
|
||||
|
||||
# AI Model Configuration
|
||||
text_model: str = "gemini-2.0-flash"
|
||||
tts_model: str = ""
|
||||
image_model: str = ""
|
||||
|
||||
# Provider routing (ordered lists)
|
||||
text_providers: list[str] = Field(default_factory=lambda: ["gemini"])
|
||||
image_providers: list[str] = Field(default_factory=lambda: ["cqtai"])
|
||||
tts_providers: list[str] = Field(default_factory=lambda: ["minimax", "elevenlabs", "edge_tts"])
|
||||
|
||||
# Celery (Redis)
|
||||
celery_broker_url: str = Field("redis://localhost:6379/0")
|
||||
celery_result_backend: str = Field("redis://localhost:6379/0")
|
||||
|
||||
# Admin console
|
||||
enable_admin_console: bool = False
|
||||
admin_username: str = "admin"
|
||||
admin_password: str = "admin123" # 建议通过环境变量覆盖
|
||||
|
||||
# CORS
|
||||
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173"])
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _require_core_settings(self) -> "Settings": # type: ignore[override]
|
||||
missing = []
|
||||
if not self.secret_key or self.secret_key == "change-me-in-production":
|
||||
missing.append("SECRET_KEY")
|
||||
if not self.database_url:
|
||||
missing.append("DATABASE_URL")
|
||||
if missing:
|
||||
raise ValueError(f"Missing required settings: {', '.join(missing)}")
|
||||
return self
|
||||
|
||||
|
||||
settings = Settings()
|
||||
39
backend/app/core/deps.py
Normal file
39
backend/app/core/deps.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from fastapi import Cookie, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.security import decode_access_token
|
||||
from app.db.database import get_db
|
||||
from app.db.models import User
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
access_token: str | None = Cookie(default=None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User | None:
|
||||
"""获取当前用户(可选)。"""
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
payload = decode_access_token(access_token)
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def require_user(
|
||||
user: User | None = Depends(get_current_user),
|
||||
) -> User:
|
||||
"""要求用户登录,否则抛 401。"""
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="未登录",
|
||||
)
|
||||
return user
|
||||
48
backend/app/core/logging.py
Normal file
48
backend/app/core/logging.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""结构化日志配置。"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import structlog
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def setup_logging():
|
||||
"""配置 structlog 结构化日志。"""
|
||||
shared_processors = [
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.stdlib.add_log_level,
|
||||
structlog.stdlib.add_logger_name,
|
||||
structlog.processors.TimeStamper(fmt="iso"),
|
||||
structlog.processors.StackInfoRenderer(),
|
||||
]
|
||||
|
||||
if settings.debug:
|
||||
processors = shared_processors + [
|
||||
structlog.dev.ConsoleRenderer(colors=True),
|
||||
]
|
||||
else:
|
||||
processors = shared_processors + [
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.JSONRenderer(),
|
||||
]
|
||||
|
||||
structlog.configure(
|
||||
processors=processors,
|
||||
wrapper_class=structlog.stdlib.BoundLogger,
|
||||
context_class=dict,
|
||||
logger_factory=structlog.stdlib.LoggerFactory(),
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(message)s",
|
||||
stream=sys.stdout,
|
||||
level=logging.DEBUG if settings.debug else logging.INFO,
|
||||
)
|
||||
|
||||
|
||||
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
|
||||
"""获取结构化日志器。"""
|
||||
return structlog.get_logger(name)
|
||||
190
backend/app/core/prompts.py
Normal file
190
backend/app/core/prompts.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# ruff: noqa: E501
|
||||
"""AI 提示词模板 (Modernized)"""
|
||||
|
||||
# 随机元素列表:为故事注入不可预测的魔法
|
||||
RANDOM_ELEMENTS = [
|
||||
"一个会打喷嚏的云朵",
|
||||
"一本地图上找不到的神秘图书馆",
|
||||
"一只能实现小愿望的彩色蜗牛",
|
||||
"一扇通往颠倒世界的门",
|
||||
"一顶能听懂动物说话的旧帽子",
|
||||
"一个装着星星的玻璃罐",
|
||||
"一棵结满笑声果实的树",
|
||||
"一只能在水上画画的画笔",
|
||||
"一个怕黑的影子",
|
||||
"一只收集回声的瓶子",
|
||||
"一双会自己跳舞的红鞋子",
|
||||
"一个只能在月光下看见的邮筒",
|
||||
"一张会改变模样的全家福",
|
||||
"一把可以打开梦境的钥匙",
|
||||
"一个喜欢讲冷笑话的冰箱",
|
||||
"一条通往星期八的秘密小径"
|
||||
]
|
||||
|
||||
# ==============================================================================
|
||||
# Model A: 故事生成 (Story Generation)
|
||||
# ==============================================================================
|
||||
|
||||
SYSTEM_INSTRUCTION_STORYTELLER = """
|
||||
# Role
|
||||
You are "**Dream Weaver**", a world-class children's storyteller with the imagination of Pixar and the warmth of Miyazaki.
|
||||
Your mission is to create engaging, safe, and educational stories for children (ages 3-8).
|
||||
|
||||
# Core Philosophy
|
||||
1. **Show, Don't Tell**: Don't preach the lesson. Let the character's actions and the plot demonstrate the theme.
|
||||
2. **Safety First**: No violence, horror, or scary elements. Conflict should be emotional or situational, not physical.
|
||||
3. **Vivid Imagery**: Use sensory details (colors, sounds, smells) that appeal to children.
|
||||
4. **Empowerment**: The child protagonist should solve the problem using wit, kindness, or courage, not just luck.
|
||||
|
||||
# Continuity & Memory (CRITICAL)
|
||||
- **Universal Context**: The story takes place in the child's established "Story Universe". Respect existing world rules.
|
||||
- **Character Consistency**: If "Child Profile" or "Sidekicks" are provided, you MUST use their specific names and traits. Do NOT invent new main characters unless asked.
|
||||
- **Callback**: If "Past Memories" are provided, try to make a natural, one-sentence reference to a past adventure to build a sense of continuity (e.g., "Just like when we found the lost star...").
|
||||
|
||||
# Output Format
|
||||
You MUST return a pure JSON object with NO markdown formatting (no ```json code blocks).
|
||||
The JSON object must have the following schema:
|
||||
{
|
||||
"mode": "generated",
|
||||
"title": "A catchy, imaginative title",
|
||||
"story_text": "The full story text. Use \\n\\n for paragraph breaks.",
|
||||
"cover_prompt_suggestion": "A detailed English image generation prompt for the story cover. Style: whimsical, children's book illustration, soft lighting, vibrant colors."
|
||||
}
|
||||
"""
|
||||
|
||||
USER_PROMPT_GENERATION = """
|
||||
# Task: Write a Children's Story
|
||||
|
||||
## Contextual Memory (Use these if provided)
|
||||
{memory_context}
|
||||
|
||||
## Inputs
|
||||
- **Keywords/Topic**: {keywords}
|
||||
- **Educational Theme**: {education_theme}
|
||||
- **Magic Element (Must Incorporate)**: {random_element}
|
||||
|
||||
## Constraints
|
||||
- Length: 300-600 words.
|
||||
- Structure: Beginning (Hook) -> Middle (Challenge) -> End (Resolution & Growth).
|
||||
"""
|
||||
|
||||
# ==============================================================================
|
||||
# Model B: 故事润色 (Story Enhancement)
|
||||
# ==============================================================================
|
||||
|
||||
SYSTEM_INSTRUCTION_ENHANCER = """
|
||||
# Role
|
||||
You are "**Dream Weaver Editor**", an expert children's book editor who turns rough drafts into polished gems.
|
||||
|
||||
# Mission
|
||||
Analyze the user's input story and rewrite it to be:
|
||||
1. **More Engaging**: Enhance the plot with a "Magic Element" to add surprise.
|
||||
2. **More Educational**: Weave the "Educational Theme" deeper into the narrative arc.
|
||||
3. **Better Written**: Polish the sentences for rhythm and flow (suitable for reading aloud).
|
||||
4. **Safe**: Remove any inappropriate content (violence, scary interaction) and replace it with constructive solutions.
|
||||
|
||||
# Output Format
|
||||
You MUST return a pure JSON object with NO markdown formatting (no ```json code blocks).
|
||||
The JSON object must have the following schema:
|
||||
{
|
||||
"mode": "enhanced",
|
||||
"title": "An improved title (or the original if perfect)",
|
||||
"story_text": "The rewritten story text. Use \\n\\n for paragraph breaks.",
|
||||
"cover_prompt_suggestion": "A detailed English image generation prompt for the cover."
|
||||
}
|
||||
"""
|
||||
|
||||
USER_PROMPT_ENHANCEMENT = """
|
||||
# Task: Enhance This Story
|
||||
|
||||
## Contextual Memory
|
||||
{memory_context}
|
||||
|
||||
## Inputs
|
||||
- **Original Story**: {full_story}
|
||||
- **Target Theme**: {education_theme}
|
||||
- **Magic Element to Add**: {random_element}
|
||||
|
||||
## Constraints
|
||||
- Length: 300-600 words.
|
||||
- Keep the original character names if possible, but feel free to upgrade the plot.
|
||||
"""
|
||||
|
||||
# ==============================================================================
|
||||
# Model C: 成就提取 (Achievement Extraction)
|
||||
# ==============================================================================
|
||||
|
||||
# 保持简单,暂不使用 System Instruction,沿用单次提示
|
||||
ACHIEVEMENT_EXTRACTION_PROMPT = """
|
||||
Analyze the story and extract key growth moments or achievements for the child protagonist.
|
||||
|
||||
# Story
|
||||
{story_text}
|
||||
|
||||
# Target Categories (Examples)
|
||||
- **Courage**: Overcoming fear, trying something new.
|
||||
- **Kindness**: Helping others, sharing, empathy.
|
||||
- **Curiosity**: Asking questions, exploring, learning.
|
||||
- **Resilience**: Not giving up, handling failure.
|
||||
- **Wisdom**: Problem-solving, honesty, patience.
|
||||
|
||||
# Output Format
|
||||
Return a pure JSON object (no markdown):
|
||||
{{
|
||||
"achievements": [
|
||||
{{
|
||||
"type": "Category Name",
|
||||
"description": "Brief reason (max 10 words)",
|
||||
"score": 8 // 1-10 intensity
|
||||
}}
|
||||
]
|
||||
}}
|
||||
"""
|
||||
|
||||
# ==============================================================================
|
||||
# Model D: 绘本生成 (Storybook Generation)
|
||||
# ==============================================================================
|
||||
|
||||
SYSTEM_INSTRUCTION_STORYBOOK = """
|
||||
# Role
|
||||
You are "**Dream Weaver Illustrator**", a creative children's book author and visual director.
|
||||
Your mission is to create a paginated picture book for children (ages 3-8), where each page has text and a matching illustration prompt.
|
||||
|
||||
# Core Philosophy
|
||||
1. **Pacing**: The story must flow logically across the specified number of pages.
|
||||
2. **Visual Consistency**: Define the "Main Character" and "Art Style" once, and ensure all image prompts adhere to them.
|
||||
3. **Language**: The story text MUST be in **Chinese (Simplified)**. The image prompts MUST be in **English**.
|
||||
4. **Memory**: If a memory context is provided, incorporate known characters or references naturally.
|
||||
|
||||
# Output Format
|
||||
You MUST return a pure JSON object using the following schema (no markdown):
|
||||
{
|
||||
"title": "Story Title (Chinese)",
|
||||
"main_character": "Description of the protagonist (e.g., 'A small blue robot with rusty gears')",
|
||||
"art_style": "Visual style description (e.g., 'Watercolor, soft pastel colors, whimsical')",
|
||||
"pages": [
|
||||
{
|
||||
"page_number": 1,
|
||||
"text": "Page text in Chinese (30-60 chars).",
|
||||
"image_prompt": "Detailed English image prompt describing the scene. Include 'main_character' reference."
|
||||
}
|
||||
],
|
||||
"cover_prompt": "English image prompt for the book cover."
|
||||
}
|
||||
"""
|
||||
|
||||
USER_PROMPT_STORYBOOK = """
|
||||
# Task: Create a {page_count}-Page Storybook
|
||||
|
||||
## Contextual Memory
|
||||
{memory_context}
|
||||
|
||||
## Inputs
|
||||
- **Keywords/Topic**: {keywords}
|
||||
- **Educational Theme**: {education_theme}
|
||||
- **Magic Element**: {random_element}
|
||||
|
||||
## Constraints
|
||||
- Pages: Exactly {page_count} pages.
|
||||
- Structure: Intro -> Development -> Climax -> Resolution.
|
||||
"""
|
||||
25
backend/app/core/security.py
Normal file
25
backend/app/core/security.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_DAYS = 7
|
||||
|
||||
|
||||
def create_access_token(data: dict) -> str:
|
||||
"""创建 JWT token"""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=ACCESS_TOKEN_EXPIRE_DAYS)
|
||||
to_encode.update({"exp": expire})
|
||||
return jwt.encode(to_encode, settings.secret_key, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> dict | None:
|
||||
"""解码 JWT token"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=[ALGORITHM])
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
0
backend/app/db/__init__.py
Normal file
0
backend/app/db/__init__.py
Normal file
119
backend/app/db/admin_models.py
Normal file
119
backend/app/db/admin_models.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, Integer, Numeric, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.models import Base
|
||||
|
||||
|
||||
def _uuid() -> str:
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
class Provider(Base):
|
||||
"""Model provider registry."""
|
||||
|
||||
__tablename__ = "providers"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(50), nullable=False) # text/image/tts/storybook
|
||||
adapter: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
model: Mapped[str] = mapped_column(String(200), nullable=True)
|
||||
api_base: Mapped[str] = mapped_column(String(300), nullable=True)
|
||||
api_key: Mapped[str] = mapped_column(String(500), nullable=True) # 可选,优先于 config_ref
|
||||
timeout_ms: Mapped[int] = mapped_column(Integer, default=60000)
|
||||
max_retries: Mapped[int] = mapped_column(Integer, default=1)
|
||||
weight: Mapped[int] = mapped_column(Integer, default=1)
|
||||
priority: Mapped[int] = mapped_column(Integer, default=0)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
config_json: Mapped[dict | None] = mapped_column(JSON, nullable=True) # 存储额外配置(speed, vol, etc)
|
||||
config_ref: Mapped[str] = mapped_column(String(100), nullable=True) # 环境变量 key 名称(回退)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
updated_by: Mapped[str] = mapped_column(String(100), nullable=True)
|
||||
|
||||
|
||||
class ProviderMetrics(Base):
|
||||
"""供应商调用指标记录。"""
|
||||
|
||||
__tablename__ = "provider_metrics"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
provider_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("providers.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
timestamp: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow, index=True
|
||||
)
|
||||
success: Mapped[bool] = mapped_column(Boolean, nullable=False)
|
||||
latency_ms: Mapped[int] = mapped_column(Integer, nullable=True)
|
||||
cost_usd: Mapped[Decimal] = mapped_column(Numeric(10, 6), nullable=True)
|
||||
error_message: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
request_id: Mapped[str] = mapped_column(String(100), nullable=True)
|
||||
|
||||
|
||||
class ProviderHealth(Base):
|
||||
"""供应商健康状态。"""
|
||||
|
||||
__tablename__ = "provider_health"
|
||||
|
||||
provider_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("providers.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
is_healthy: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
last_check: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
consecutive_failures: Mapped[int] = mapped_column(Integer, default=0)
|
||||
last_error: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
|
||||
|
||||
class ProviderSecret(Base):
|
||||
"""供应商密钥加密存储。"""
|
||||
|
||||
__tablename__ = "provider_secrets"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
|
||||
encrypted_value: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
|
||||
|
||||
class CostRecord(Base):
|
||||
"""成本记录表 - 记录每次 API 调用的成本。"""
|
||||
|
||||
__tablename__ = "cost_records"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[str] = mapped_column(String(36), nullable=False, index=True)
|
||||
provider_id: Mapped[str] = mapped_column(String(36), nullable=True) # 可能是环境变量配置
|
||||
provider_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
capability: Mapped[str] = mapped_column(String(50), nullable=False) # text/image/tts/storybook
|
||||
estimated_cost: Mapped[Decimal] = mapped_column(Numeric(10, 6), nullable=False)
|
||||
timestamp: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow, index=True
|
||||
)
|
||||
|
||||
|
||||
class UserBudget(Base):
|
||||
"""用户预算配置。"""
|
||||
|
||||
__tablename__ = "user_budgets"
|
||||
|
||||
user_id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
daily_limit_usd: Mapped[Decimal] = mapped_column(Numeric(10, 4), default=Decimal("1.0"))
|
||||
monthly_limit_usd: Mapped[Decimal] = mapped_column(Numeric(10, 4), default=Decimal("10.0"))
|
||||
alert_threshold: Mapped[Decimal] = mapped_column(
|
||||
Numeric(3, 2), default=Decimal("0.8")
|
||||
) # 80% 时告警
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
50
backend/app/db/database.py
Normal file
50
backend/app/db/database.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import threading
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
_engine = None
|
||||
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_engine():
|
||||
global _engine
|
||||
if _engine is None:
|
||||
with _lock:
|
||||
if _engine is None:
|
||||
_engine = create_async_engine(
|
||||
settings.database_url,
|
||||
echo=settings.debug,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=300,
|
||||
)
|
||||
return _engine
|
||||
|
||||
|
||||
def _get_session_factory():
|
||||
global _session_factory
|
||||
if _session_factory is None:
|
||||
with _lock:
|
||||
if _session_factory is None:
|
||||
_session_factory = async_sessionmaker(
|
||||
_get_engine(), class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
return _session_factory
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""Create tables if they do not exist."""
|
||||
from app.db.models import Base # main models
|
||||
|
||||
engine = _get_engine()
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def get_db():
|
||||
"""Yield a DB session with proper cleanup."""
|
||||
session_factory = _get_session_factory()
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
232
backend/app/db/models.py
Normal file
232
backend/app/db/models.py
Normal file
@@ -0,0 +1,232 @@
|
||||
from datetime import date, datetime, time
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
Date,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
Time,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Declarative base."""
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""User entity."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(255), primary_key=True) # OAuth provider user ID
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
avatar_url: Mapped[str | None] = mapped_column(String(500))
|
||||
provider: Mapped[str] = mapped_column(String(50), nullable=False) # github / google
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
stories: Mapped[list["Story"]] = relationship(
|
||||
"Story", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
child_profiles: Mapped[list["ChildProfile"]] = relationship(
|
||||
"ChildProfile", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class Story(Base):
|
||||
"""Story entity."""
|
||||
|
||||
__tablename__ = "stories"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
child_profile_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("child_profiles.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
universe_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("story_universes.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
story_text: Mapped[str] = mapped_column(Text, nullable=True) # 允许为空(绘本模式下)
|
||||
pages: Mapped[list[dict] | None] = mapped_column(JSON, default=list) # 绘本分页数据
|
||||
cover_prompt: Mapped[str | None] = mapped_column(Text)
|
||||
image_url: Mapped[str | None] = mapped_column(String(500))
|
||||
mode: Mapped[str] = mapped_column(String(20), nullable=False, default="generated")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped["User"] = relationship("User", back_populates="stories")
|
||||
child_profile: Mapped["ChildProfile | None"] = relationship("ChildProfile")
|
||||
story_universe: Mapped["StoryUniverse | None"] = relationship("StoryUniverse")
|
||||
|
||||
|
||||
def _uuid() -> str:
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
class ChildProfile(Base):
|
||||
"""Child profile entity."""
|
||||
|
||||
__tablename__ = "child_profiles"
|
||||
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_child_profile_user_name"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
avatar_url: Mapped[str | None] = mapped_column(String(500))
|
||||
birth_date: Mapped[date | None] = mapped_column(Date)
|
||||
gender: Mapped[str | None] = mapped_column(String(10))
|
||||
|
||||
interests: Mapped[list[str]] = mapped_column(JSON, default=list)
|
||||
growth_themes: Mapped[list[str]] = mapped_column(JSON, default=list)
|
||||
reading_preferences: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
|
||||
stories_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
total_reading_time: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
user: Mapped["User"] = relationship("User", back_populates="child_profiles")
|
||||
story_universes: Mapped[list["StoryUniverse"]] = relationship(
|
||||
"StoryUniverse", back_populates="child_profile", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@property
|
||||
def age(self) -> int | None:
|
||||
if not self.birth_date:
|
||||
return None
|
||||
today = date.today()
|
||||
return today.year - self.birth_date.year - (
|
||||
(today.month, today.day) < (self.birth_date.month, self.birth_date.day)
|
||||
)
|
||||
|
||||
|
||||
class StoryUniverse(Base):
|
||||
"""Story universe entity."""
|
||||
__tablename__ = "story_universes"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
child_profile_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
protagonist: Mapped[dict] = mapped_column(JSON, nullable=False)
|
||||
recurring_characters: Mapped[list] = mapped_column(JSON, default=list)
|
||||
world_settings: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
achievements: Mapped[list] = mapped_column(JSON, default=list)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
child_profile: Mapped["ChildProfile"] = relationship("ChildProfile", back_populates="story_universes")
|
||||
|
||||
|
||||
class ReadingEvent(Base):
|
||||
"""Reading event entity."""
|
||||
|
||||
__tablename__ = "reading_events"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
child_profile_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
story_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("stories.id", ondelete="SET NULL"), nullable=True, index=True
|
||||
)
|
||||
event_type: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
reading_time: Mapped[int] = mapped_column(Integer, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), index=True
|
||||
)
|
||||
|
||||
class PushConfig(Base):
|
||||
"""Push configuration entity."""
|
||||
|
||||
__tablename__ = "push_configs"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("child_profile_id", name="uq_push_config_child"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
child_profile_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
push_time: Mapped[time | None] = mapped_column(Time)
|
||||
push_days: Mapped[list[int]] = mapped_column(JSON, default=list)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class PushEvent(Base):
|
||||
"""Push event entity."""
|
||||
|
||||
__tablename__ = "push_events"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
child_profile_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
trigger_type: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
reason: Mapped[str | None] = mapped_column(Text)
|
||||
sent_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class MemoryItem(Base):
|
||||
"""Memory item entity with time decay metadata."""
|
||||
|
||||
__tablename__ = "memory_items"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
child_profile_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("child_profiles.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
universe_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("story_universes.id", ondelete="SET NULL"), nullable=True, index=True
|
||||
)
|
||||
type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
value: Mapped[dict] = mapped_column(JSON, nullable=False)
|
||||
base_weight: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
last_used_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
ttl_days: Mapped[int | None] = mapped_column(Integer)
|
||||
80
backend/app/main.py
Normal file
80
backend/app/main.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api import (
|
||||
auth,
|
||||
memories,
|
||||
profiles,
|
||||
push_configs,
|
||||
reading_events,
|
||||
stories,
|
||||
universes,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger, setup_logging
|
||||
from app.db.database import init_db
|
||||
|
||||
setup_logging()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""App lifespan manager."""
|
||||
logger.info("app_starting", app_name=settings.app_name)
|
||||
await init_db()
|
||||
logger.info("database_initialized")
|
||||
|
||||
# 加载 provider 缓存
|
||||
await _load_provider_cache()
|
||||
|
||||
yield
|
||||
logger.info("app_shutdown")
|
||||
|
||||
|
||||
async def _load_provider_cache():
|
||||
"""启动时加载 provider 缓存。"""
|
||||
from app.db.database import _get_session_factory
|
||||
from app.services.provider_cache import reload_providers
|
||||
|
||||
try:
|
||||
session_factory = _get_session_factory()
|
||||
async with session_factory() as session:
|
||||
cache = await reload_providers(session)
|
||||
provider_count = sum(len(v) for v in cache.values())
|
||||
logger.info("provider_cache_loaded", provider_count=provider_count)
|
||||
except Exception as e:
|
||||
logger.warning("provider_cache_load_failed", error=str(e))
|
||||
# 不阻止启动,使用 settings 中的默认配置
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
description="AI-driven story generator for kids.",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
app.include_router(stories.router, prefix="/api", tags=["stories"])
|
||||
app.include_router(profiles.router, prefix="/api", tags=["profiles"])
|
||||
app.include_router(universes.router, prefix="/api", tags=["universes"])
|
||||
app.include_router(push_configs.router, prefix="/api", tags=["push-configs"])
|
||||
app.include_router(reading_events.router, prefix="/api", tags=["reading-events"])
|
||||
app.include_router(memories.router, prefix="/api", tags=["memories"])
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Simple liveness check."""
|
||||
return {"status": "ok"}
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
85
backend/app/services/achievement_extractor.py
Normal file
85
backend/app/services/achievement_extractor.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Achievement extraction service."""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.core.prompts import ACHIEVEMENT_EXTRACTION_PROMPT
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
|
||||
async def extract_achievements(story_text: str) -> list[dict]:
|
||||
"""Extract achievements from story text using LLM."""
|
||||
if not settings.text_api_key:
|
||||
logger.warning("achievement_extraction_skipped", reason="missing_text_api_key")
|
||||
return []
|
||||
|
||||
model = settings.text_model or "gemini-2.0-flash"
|
||||
url = f"{TEXT_API_BASE}/{model}:generateContent"
|
||||
|
||||
prompt = ACHIEVEMENT_EXTRACTION_PROMPT.format(story_text=story_text)
|
||||
payload = {
|
||||
"contents": [{"parts": [{"text": prompt}]}],
|
||||
"generationConfig": {
|
||||
"responseMimeType": "application/json",
|
||||
"temperature": 0.2,
|
||||
"topP": 0.9,
|
||||
},
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers={"x-goog-api-key": settings.text_api_key},
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
candidates = result.get("candidates") or []
|
||||
if not candidates:
|
||||
logger.warning("achievement_extraction_empty")
|
||||
return []
|
||||
|
||||
parts = candidates[0].get("content", {}).get("parts") or []
|
||||
if not parts or "text" not in parts[0]:
|
||||
logger.warning("achievement_extraction_missing_text")
|
||||
return []
|
||||
|
||||
response_text = parts[0]["text"]
|
||||
clean_json = response_text
|
||||
if response_text.startswith("```json"):
|
||||
clean_json = re.sub(r"^```json\n|```$", "", response_text)
|
||||
|
||||
try:
|
||||
parsed = json.loads(clean_json)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("achievement_extraction_parse_failed")
|
||||
return []
|
||||
|
||||
achievements = parsed.get("achievements")
|
||||
if not isinstance(achievements, list):
|
||||
return []
|
||||
|
||||
normalized: list[dict] = []
|
||||
for item in achievements:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
a_type = str(item.get("type", "")).strip()
|
||||
description = str(item.get("description", "")).strip()
|
||||
score = item.get("score", 0)
|
||||
if not a_type or not description:
|
||||
continue
|
||||
normalized.append({
|
||||
"type": a_type,
|
||||
"description": description,
|
||||
"score": score
|
||||
})
|
||||
|
||||
return normalized
|
||||
21
backend/app/services/adapters/__init__.py
Normal file
21
backend/app/services/adapters/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""适配器模块 - 供应商平台化架构核心。"""
|
||||
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
|
||||
# Image adapters
|
||||
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
# Storybook adapters
|
||||
from app.services.adapters.storybook import primary as _storybook_primary # noqa: F401
|
||||
from app.services.adapters.text import gemini as _text_gemini_adapter # noqa: F401
|
||||
|
||||
# 导入所有适配器以触发注册
|
||||
# Text adapters
|
||||
from app.services.adapters.text import openai as _text_openai_adapter # noqa: F401
|
||||
|
||||
# TTS adapters
|
||||
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
|
||||
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
|
||||
|
||||
__all__ = ["AdapterConfig", "BaseAdapter", "AdapterRegistry"]
|
||||
46
backend/app/services/adapters/base.py
Normal file
46
backend/app/services/adapters/base.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""适配器基类定义。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AdapterConfig(BaseModel):
|
||||
"""适配器配置基类。"""
|
||||
|
||||
api_key: str
|
||||
api_base: str | None = None
|
||||
model: str | None = None
|
||||
timeout_ms: int = 60000
|
||||
max_retries: int = 3
|
||||
extra_config: dict = {}
|
||||
|
||||
|
||||
class BaseAdapter(ABC, Generic[T]):
|
||||
"""适配器基类,所有供应商适配器必须继承此类。"""
|
||||
|
||||
# 子类必须定义
|
||||
adapter_type: str # text / image / tts
|
||||
adapter_name: str # text_primary / image_primary / tts_primary
|
||||
|
||||
def __init__(self, config: AdapterConfig):
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> T:
|
||||
"""执行适配器逻辑,返回结果。"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self) -> bool:
|
||||
"""健康检查,返回是否可用。"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def estimated_cost(self) -> float:
|
||||
"""预估单次调用成本 (USD)。"""
|
||||
pass
|
||||
3
backend/app/services/adapters/image/__init__.py
Normal file
3
backend/app/services/adapters/image/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""图像生成适配器。"""# Image adapters
|
||||
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
|
||||
from app.services.adapters.image import antigravity as _image_antigravity_adapter # noqa: F401
|
||||
214
backend/app/services/adapters/image/antigravity.py
Normal file
214
backend/app/services/adapters/image/antigravity.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Antigravity 图像生成适配器。
|
||||
|
||||
使用 OpenAI 兼容 API 生成图像。
|
||||
支持 gemini-3-pro-image 等模型。
|
||||
"""
|
||||
|
||||
import base64
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_API_BASE = "http://127.0.0.1:8045/v1"
|
||||
DEFAULT_MODEL = "gemini-3-pro-image"
|
||||
DEFAULT_SIZE = "1024x1024"
|
||||
|
||||
# 支持的尺寸映射
|
||||
SUPPORTED_SIZES = {
|
||||
"1024x1024": "1:1",
|
||||
"1280x720": "16:9",
|
||||
"720x1280": "9:16",
|
||||
"1216x896": "4:3",
|
||||
}
|
||||
|
||||
|
||||
@AdapterRegistry.register("image", "antigravity")
|
||||
class AntigravityImageAdapter(BaseAdapter[str]):
|
||||
"""Antigravity 图像生成适配器 (OpenAI 兼容 API)。
|
||||
|
||||
特点:
|
||||
- 使用 OpenAI 兼容的 chat.completions 端点
|
||||
- 通过 extra_body.size 指定图像尺寸
|
||||
- 支持 gemini-3-pro-image 等模型
|
||||
- 返回图片 URL 或 base64
|
||||
"""
|
||||
|
||||
adapter_type = "image"
|
||||
adapter_name = "antigravity"
|
||||
|
||||
def __init__(self, config: AdapterConfig):
|
||||
super().__init__(config)
|
||||
self.api_base = config.api_base or DEFAULT_API_BASE
|
||||
self.client = AsyncOpenAI(
|
||||
base_url=self.api_base,
|
||||
api_key=config.api_key,
|
||||
timeout=config.timeout_ms / 1000,
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str | None = None,
|
||||
size: str | None = None,
|
||||
num_images: int = 1,
|
||||
**kwargs,
|
||||
) -> str | list[str]:
|
||||
"""根据提示词生成图片,返回 URL 或 base64。
|
||||
|
||||
Args:
|
||||
prompt: 图片描述提示词
|
||||
model: 模型名称 (gemini-3-pro-image / gemini-3-pro-image-16-9 等)
|
||||
size: 图像尺寸 (1024x1024, 1280x720, 720x1280, 1216x896)
|
||||
num_images: 生成图片数量 (暂只支持 1)
|
||||
|
||||
Returns:
|
||||
图片 URL 或 base64 字符串
|
||||
"""
|
||||
# 优先使用传入参数,其次使用 Adapter 配置,最后使用默认值
|
||||
model = model or self.config.model or DEFAULT_MODEL
|
||||
|
||||
cfg = self.config.extra_config or {}
|
||||
size = size or cfg.get("size") or DEFAULT_SIZE
|
||||
|
||||
start_time = time.time()
|
||||
logger.info(
|
||||
"antigravity_generate_start",
|
||||
prompt_length=len(prompt),
|
||||
model=model,
|
||||
size=size,
|
||||
)
|
||||
|
||||
# 调用 API
|
||||
image_url = await self._generate_image(prompt, model, size)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
"antigravity_generate_success",
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
model=model,
|
||||
)
|
||||
|
||||
return image_url
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""检查 Antigravity API 是否可用。"""
|
||||
try:
|
||||
# 简单测试连通性
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.config.model or DEFAULT_MODEL,
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
max_tokens=1,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning("antigravity_health_check_failed", error=str(e))
|
||||
return False
|
||||
|
||||
@property
|
||||
def estimated_cost(self) -> float:
|
||||
"""预估每张图片成本 (USD)。
|
||||
|
||||
Antigravity 使用 Gemini 模型,成本约 $0.02/张。
|
||||
"""
|
||||
return 0.02
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=10),
|
||||
retry=retry_if_exception_type((Exception,)),
|
||||
reraise=True,
|
||||
)
|
||||
async def _generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
size: str,
|
||||
) -> str:
|
||||
"""调用 Antigravity API 生成图像。
|
||||
|
||||
Returns:
|
||||
图片 URL 或 base64 data URI
|
||||
"""
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
extra_body={"size": size},
|
||||
)
|
||||
|
||||
# 解析响应
|
||||
content = response.choices[0].message.content
|
||||
if not content:
|
||||
raise ValueError("Antigravity 未返回内容")
|
||||
|
||||
# 尝试解析为图片 URL 或 base64
|
||||
# 响应可能是纯 URL、base64 或 markdown 格式的图片
|
||||
image_url = self._extract_image_url(content)
|
||||
if image_url:
|
||||
return image_url
|
||||
|
||||
raise ValueError(f"Antigravity 响应无法解析为图片: {content[:200]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"antigravity_generate_error",
|
||||
error=str(e),
|
||||
model=model,
|
||||
)
|
||||
raise
|
||||
|
||||
def _extract_image_url(self, content: str) -> str | None:
|
||||
"""从响应内容中提取图片 URL。
|
||||
|
||||
支持多种格式:
|
||||
- 纯 URL: https://...
|
||||
- Markdown: 
|
||||
- Base64 data URI: data:image/...
|
||||
- 纯 base64 字符串
|
||||
"""
|
||||
content = content.strip()
|
||||
|
||||
# 1. 检查是否为 data URI
|
||||
if content.startswith("data:image/"):
|
||||
return content
|
||||
|
||||
# 2. 检查是否为纯 URL
|
||||
if content.startswith("http://") or content.startswith("https://"):
|
||||
# 可能有多行,取第一行
|
||||
return content.split("\n")[0].strip()
|
||||
|
||||
# 3. 检查 Markdown 图片格式 
|
||||
import re
|
||||
md_match = re.search(r"!\[.*?\]\((https?://[^\)]+)\)", content)
|
||||
if md_match:
|
||||
return md_match.group(1)
|
||||
|
||||
# 4. 检查是否像 base64 编码的图片数据
|
||||
if self._looks_like_base64(content):
|
||||
# 假设是 PNG
|
||||
return f"data:image/png;base64,{content}"
|
||||
|
||||
return None
|
||||
|
||||
def _looks_like_base64(self, s: str) -> bool:
|
||||
"""判断字符串是否看起来像 base64 编码。"""
|
||||
# Base64 只包含 A-Z, a-z, 0-9, +, /, =
|
||||
# 且长度通常较长
|
||||
if len(s) < 100:
|
||||
return False
|
||||
import re
|
||||
return bool(re.match(r"^[A-Za-z0-9+/=]+$", s.replace("\n", "")))
|
||||
252
backend/app/services/adapters/image/cqtai.py
Normal file
252
backend/app/services/adapters/image/cqtai.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""CQTAI nano 图像生成适配器。
|
||||
|
||||
支持异步生成 + 轮询获取结果。
|
||||
API 文档: https://api.cqtai.com
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_API_BASE = "https://api.cqtai.com"
|
||||
DEFAULT_MODEL = "nano-banana"
|
||||
DEFAULT_RESOLUTION = "2K"
|
||||
DEFAULT_ASPECT_RATIO = "1:1"
|
||||
POLL_INTERVAL_SECONDS = 2
|
||||
MAX_POLL_ATTEMPTS = 60 # 最多轮询 2 分钟
|
||||
|
||||
|
||||
@AdapterRegistry.register("image", "cqtai")
|
||||
class CQTAIImageAdapter(BaseAdapter[str]):
|
||||
"""CQTAI nano 图像生成适配器,返回图片 URL。
|
||||
|
||||
特点:
|
||||
- 异步生成 + 轮询获取结果
|
||||
- 支持 nano-banana (标准) 和 nano-banana-pro (高画质)
|
||||
- 支持多种分辨率和画面比例
|
||||
- 支持图生图 (filesUrl)
|
||||
"""
|
||||
|
||||
adapter_type = "image"
|
||||
adapter_name = "cqtai"
|
||||
|
||||
def __init__(self, config: AdapterConfig):
|
||||
super().__init__(config)
|
||||
self.api_base = config.api_base or DEFAULT_API_BASE
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str | None = None,
|
||||
resolution: str | None = None,
|
||||
aspect_ratio: str | None = None,
|
||||
num_images: int = 1,
|
||||
files_url: list[str] | None = None,
|
||||
**kwargs,
|
||||
) -> str | list[str]:
|
||||
"""根据提示词生成图片,返回 URL 或 URL 列表。
|
||||
|
||||
Args:
|
||||
prompt: 图片描述提示词
|
||||
model: 模型名称 (nano-banana / nano-banana-pro)
|
||||
resolution: 分辨率 (1K / 2K / 4K)
|
||||
aspect_ratio: 画面比例 (1:1, 16:9, 9:16, 4:3, 3:4 等)
|
||||
num_images: 生成图片数量 (1-4)
|
||||
files_url: 输入图片 URL 列表 (图生图)
|
||||
|
||||
Returns:
|
||||
单张图片返回 str,多张返回 list[str]
|
||||
"""
|
||||
# 1. 优先使用传入参数
|
||||
# 2. 其次使用 Adapter 配置里的 default (extra_config)
|
||||
# 3. 最后使用系统默认值
|
||||
model = model or self.config.model or DEFAULT_MODEL
|
||||
|
||||
cfg = self.config.extra_config or {}
|
||||
resolution = resolution or cfg.get("resolution") or DEFAULT_RESOLUTION
|
||||
aspect_ratio = aspect_ratio or cfg.get("aspect_ratio") or DEFAULT_ASPECT_RATIO
|
||||
num_images = min(max(num_images, 1), 4) # 限制 1-4
|
||||
|
||||
start_time = time.time()
|
||||
logger.info(
|
||||
"cqtai_generate_start",
|
||||
prompt_length=len(prompt),
|
||||
model=model,
|
||||
resolution=resolution,
|
||||
aspect_ratio=aspect_ratio,
|
||||
num_images=num_images,
|
||||
)
|
||||
|
||||
# 1. 提交生成任务
|
||||
task_id = await self._submit_task(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
resolution=resolution,
|
||||
aspect_ratio=aspect_ratio,
|
||||
num_images=num_images,
|
||||
files_url=files_url or [],
|
||||
)
|
||||
|
||||
logger.info("cqtai_task_submitted", task_id=task_id)
|
||||
|
||||
# 2. 轮询获取结果
|
||||
result = await self._poll_result(task_id)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
"cqtai_generate_success",
|
||||
task_id=task_id,
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
image_count=len(result) if isinstance(result, list) else 1,
|
||||
)
|
||||
|
||||
# 单张图片返回字符串,多张返回列表
|
||||
if num_images == 1 and isinstance(result, list) and len(result) == 1:
|
||||
return result[0]
|
||||
return result
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""检查 CQTAI API 是否可用。"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
# 简单的连通性测试
|
||||
response = await client.get(
|
||||
f"{self.api_base}/api/cqt/info/nano",
|
||||
params={"id": "health_check_test"},
|
||||
headers={"Authorization": self.config.api_key},
|
||||
)
|
||||
# 即使返回错误也说明服务可达
|
||||
return response.status_code in (200, 400, 401, 403, 404)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def estimated_cost(self) -> float:
|
||||
"""预估每张图片成本 (USD)。
|
||||
|
||||
nano-banana: ¥0.1 ≈ $0.014
|
||||
nano-banana-pro: ¥0.2 ≈ $0.028
|
||||
"""
|
||||
model = self.config.model or DEFAULT_MODEL
|
||||
if model == "nano-banana-pro":
|
||||
return 0.028
|
||||
return 0.014
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=10),
|
||||
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
|
||||
reraise=True,
|
||||
)
|
||||
async def _submit_task(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
num_images: int,
|
||||
files_url: list[str],
|
||||
) -> str:
|
||||
"""提交图像生成任务,返回任务 ID。"""
|
||||
timeout = self.config.timeout_ms / 1000
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"numImages": num_images,
|
||||
"aspectRatio": aspect_ratio,
|
||||
"filesUrl": files_url,
|
||||
}
|
||||
|
||||
# 可选参数,不传则使用默认值
|
||||
if model != DEFAULT_MODEL:
|
||||
payload["model"] = model
|
||||
if resolution != DEFAULT_RESOLUTION:
|
||||
payload["resolution"] = resolution
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.api_base}/api/cqt/generator/nano",
|
||||
json=payload,
|
||||
headers={
|
||||
"Authorization": self.config.api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if data.get("code") != 200:
|
||||
raise ValueError(f"CQTAI 任务提交失败: {data.get('msg', '未知错误')}")
|
||||
|
||||
task_id = data.get("data")
|
||||
if not task_id:
|
||||
raise ValueError("CQTAI 未返回任务 ID")
|
||||
|
||||
return task_id
|
||||
|
||||
async def _poll_result(self, task_id: str) -> list[str]:
|
||||
"""轮询获取生成结果。
|
||||
|
||||
Returns:
|
||||
图片 URL 列表
|
||||
"""
|
||||
timeout = self.config.timeout_ms / 1000
|
||||
|
||||
for attempt in range(MAX_POLL_ATTEMPTS):
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.api_base}/api/cqt/info/nano",
|
||||
params={"id": task_id},
|
||||
headers={"Authorization": self.config.api_key},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if data.get("code") != 200:
|
||||
raise ValueError(f"CQTAI 查询失败: {data.get('msg', '未知错误')}")
|
||||
|
||||
result_data = data.get("data", {})
|
||||
status = result_data.get("status")
|
||||
|
||||
if status == "completed":
|
||||
# 提取图片 URL
|
||||
images = result_data.get("images", [])
|
||||
if not images:
|
||||
# 兼容不同返回格式
|
||||
image_url = result_data.get("imageUrl") or result_data.get("url")
|
||||
if image_url:
|
||||
images = [image_url]
|
||||
|
||||
if not images:
|
||||
raise ValueError("CQTAI 未返回图片 URL")
|
||||
|
||||
return images
|
||||
|
||||
elif status == "failed":
|
||||
error_msg = result_data.get("error", "生成失败")
|
||||
raise ValueError(f"CQTAI 图像生成失败: {error_msg}")
|
||||
|
||||
# 继续等待
|
||||
logger.debug(
|
||||
"cqtai_poll_waiting",
|
||||
task_id=task_id,
|
||||
attempt=attempt + 1,
|
||||
status=status,
|
||||
)
|
||||
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||
|
||||
raise TimeoutError(f"CQTAI 任务超时: {task_id}")
|
||||
73
backend/app/services/adapters/registry.py
Normal file
73
backend/app/services/adapters/registry.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""适配器注册表 - 支持动态注册和工厂创建。"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
|
||||
|
||||
class AdapterRegistry:
|
||||
"""适配器注册表,管理所有已注册的适配器类。"""
|
||||
|
||||
_adapters: dict[str, type["BaseAdapter"]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, adapter_type: str, adapter_name: str):
|
||||
"""装饰器:注册适配器类。
|
||||
|
||||
用法:
|
||||
@AdapterRegistry.register("text", "text_primary")
|
||||
class TextPrimaryAdapter(BaseAdapter[StoryOutput]):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(adapter_class: type["BaseAdapter"]):
|
||||
key = f"{adapter_type}:{adapter_name}"
|
||||
cls._adapters[key] = adapter_class
|
||||
# 自动设置类属性
|
||||
adapter_class.adapter_type = adapter_type
|
||||
adapter_class.adapter_name = adapter_name
|
||||
return adapter_class
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def get(cls, adapter_type: str, adapter_name: str) -> type["BaseAdapter"] | None:
|
||||
"""获取已注册的适配器类。"""
|
||||
key = f"{adapter_type}:{adapter_name}"
|
||||
return cls._adapters.get(key)
|
||||
|
||||
@classmethod
|
||||
def list_adapters(cls, adapter_type: str | None = None) -> list[str]:
|
||||
"""列出所有已注册的适配器。
|
||||
|
||||
Args:
|
||||
adapter_type: 可选,筛选特定类型 (text/image/tts)
|
||||
|
||||
Returns:
|
||||
适配器键列表,格式为 "type:name"
|
||||
"""
|
||||
if adapter_type:
|
||||
return [k for k in cls._adapters if k.startswith(f"{adapter_type}:")]
|
||||
return list(cls._adapters.keys())
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
adapter_type: str,
|
||||
adapter_name: str,
|
||||
config: "AdapterConfig",
|
||||
) -> "BaseAdapter":
|
||||
"""工厂方法:创建适配器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 适配器未注册
|
||||
"""
|
||||
adapter_class = cls.get(adapter_type, adapter_name)
|
||||
if not adapter_class:
|
||||
available = cls.list_adapters(adapter_type)
|
||||
raise ValueError(
|
||||
f"适配器 '{adapter_type}:{adapter_name}' 未注册。"
|
||||
f"可用: {available}"
|
||||
)
|
||||
return adapter_class(config)
|
||||
1
backend/app/services/adapters/storybook/__init__.py
Normal file
1
backend/app/services/adapters/storybook/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Storybook 适配器模块。"""
|
||||
195
backend/app/services/adapters/storybook/primary.py
Normal file
195
backend/app/services/adapters/storybook/primary.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Storybook 适配器 - 生成可翻页的分页故事书。"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.prompts import (
|
||||
RANDOM_ELEMENTS,
|
||||
SYSTEM_INSTRUCTION_STORYBOOK,
|
||||
USER_PROMPT_STORYBOOK,
|
||||
)
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StorybookPage:
|
||||
"""故事书单页。"""
|
||||
|
||||
page_number: int
|
||||
text: str
|
||||
image_prompt: str
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Storybook:
|
||||
"""故事书输出。"""
|
||||
|
||||
title: str
|
||||
main_character: str
|
||||
art_style: str
|
||||
pages: list[StorybookPage] = field(default_factory=list)
|
||||
cover_prompt: str = ""
|
||||
cover_url: str | None = None
|
||||
|
||||
|
||||
@AdapterRegistry.register("storybook", "storybook_primary")
|
||||
class StorybookPrimaryAdapter(BaseAdapter[Storybook]):
|
||||
"""Storybook 生成适配器(默认)。
|
||||
|
||||
生成分页故事书结构,包含每页文字和图像提示词。
|
||||
图像生成需要单独调用 image adapter。
|
||||
"""
|
||||
|
||||
adapter_type = "storybook"
|
||||
adapter_name = "storybook_primary"
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
keywords: str,
|
||||
page_count: int = 6,
|
||||
education_theme: str | None = None,
|
||||
memory_context: str | None = None,
|
||||
**kwargs,
|
||||
) -> Storybook:
|
||||
"""生成分页故事书。
|
||||
|
||||
Args:
|
||||
keywords: 故事关键词
|
||||
page_count: 页数 (4-12)
|
||||
education_theme: 教育主题
|
||||
memory_context: 记忆上下文
|
||||
|
||||
Returns:
|
||||
Storybook 对象,包含标题、页面列表和封面提示词
|
||||
"""
|
||||
start_time = time.time()
|
||||
page_count = max(4, min(page_count, 12)) # 限制 4-12 页
|
||||
|
||||
logger.info(
|
||||
"storybook_generate_start",
|
||||
keywords=keywords,
|
||||
page_count=page_count,
|
||||
has_memory=bool(memory_context),
|
||||
)
|
||||
|
||||
theme = education_theme or "成长"
|
||||
random_element = random.choice(RANDOM_ELEMENTS)
|
||||
|
||||
prompt = USER_PROMPT_STORYBOOK.format(
|
||||
keywords=keywords,
|
||||
education_theme=theme,
|
||||
random_element=random_element,
|
||||
page_count=page_count,
|
||||
memory_context=memory_context or "",
|
||||
)
|
||||
|
||||
payload = {
|
||||
"system_instruction": {"parts": [{"text": SYSTEM_INSTRUCTION_STORYBOOK}]},
|
||||
"contents": [{"parts": [{"text": prompt}]}],
|
||||
"generationConfig": {
|
||||
"responseMimeType": "application/json",
|
||||
"temperature": 0.95,
|
||||
"topP": 0.9,
|
||||
},
|
||||
}
|
||||
|
||||
result = await self._call_api(payload)
|
||||
|
||||
candidates = result.get("candidates") or []
|
||||
if not candidates:
|
||||
raise ValueError("Storybook 服务未返回内容")
|
||||
|
||||
parts = candidates[0].get("content", {}).get("parts") or []
|
||||
if not parts or "text" not in parts[0]:
|
||||
raise ValueError("Storybook 服务响应缺少文本")
|
||||
|
||||
response_text = parts[0]["text"]
|
||||
clean_json = response_text
|
||||
if response_text.startswith("```json"):
|
||||
clean_json = re.sub(r"^```json\n|```$", "", response_text)
|
||||
|
||||
try:
|
||||
parsed = json.loads(clean_json)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(f"Storybook JSON 解析失败: {exc}")
|
||||
|
||||
# 构建 Storybook 对象
|
||||
pages = [
|
||||
StorybookPage(
|
||||
page_number=p.get("page_number", i + 1),
|
||||
text=p.get("text", ""),
|
||||
image_prompt=p.get("image_prompt", ""),
|
||||
)
|
||||
for i, p in enumerate(parsed.get("pages", []))
|
||||
]
|
||||
|
||||
storybook = Storybook(
|
||||
title=parsed.get("title", "未命名故事"),
|
||||
main_character=parsed.get("main_character", ""),
|
||||
art_style=parsed.get("art_style", ""),
|
||||
pages=pages,
|
||||
cover_prompt=parsed.get("cover_prompt", ""),
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
"storybook_generate_success",
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
title=storybook.title,
|
||||
page_count=len(pages),
|
||||
)
|
||||
|
||||
return storybook
|
||||
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""检查 API 是否可用。"""
|
||||
try:
|
||||
payload = {
|
||||
"contents": [{"parts": [{"text": "Hi"}]}],
|
||||
"generationConfig": {"maxOutputTokens": 10},
|
||||
}
|
||||
await self._call_api(payload)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def estimated_cost(self) -> float:
|
||||
"""预估成本(仅文本生成,不含图像)。"""
|
||||
return 0.002 # 比普通故事稍贵,因为输出更长
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=10),
|
||||
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
|
||||
reraise=True,
|
||||
)
|
||||
async def _call_api(self, payload: dict) -> dict:
|
||||
"""调用 API,带重试机制。"""
|
||||
model = self.config.model or "gemini-2.0-flash"
|
||||
url = f"{TEXT_API_BASE}/{model}:generateContent?key={self.config.api_key}"
|
||||
timeout = self.config.timeout_ms / 1000
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
1
backend/app/services/adapters/text/__init__.py
Normal file
1
backend/app/services/adapters/text/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""文本生成适配器。"""
|
||||
164
backend/app/services/adapters/text/gemini.py
Normal file
164
backend/app/services/adapters/text/gemini.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""文本生成适配器 (Google Gemini)。"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.prompts import (
|
||||
RANDOM_ELEMENTS,
|
||||
SYSTEM_INSTRUCTION_ENHANCER,
|
||||
SYSTEM_INSTRUCTION_STORYTELLER,
|
||||
USER_PROMPT_ENHANCEMENT,
|
||||
USER_PROMPT_GENERATION,
|
||||
)
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
|
||||
@AdapterRegistry.register("text", "gemini")
|
||||
class GeminiTextAdapter(BaseAdapter[StoryOutput]):
|
||||
"""Google Gemini 文本生成适配器。"""
|
||||
|
||||
adapter_type = "text"
|
||||
adapter_name = "gemini"
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
input_type: Literal["keywords", "full_story"],
|
||||
data: str,
|
||||
education_theme: str | None = None,
|
||||
memory_context: str | None = None,
|
||||
**kwargs,
|
||||
) -> StoryOutput:
|
||||
"""生成或润色故事。"""
|
||||
start_time = time.time()
|
||||
logger.info("request_start", adapter="gemini", input_type=input_type, data_length=len(data))
|
||||
|
||||
theme = education_theme or "成长"
|
||||
random_element = random.choice(RANDOM_ELEMENTS)
|
||||
|
||||
if input_type == "keywords":
|
||||
system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
|
||||
prompt = USER_PROMPT_GENERATION.format(
|
||||
keywords=data,
|
||||
education_theme=theme,
|
||||
random_element=random_element,
|
||||
memory_context=memory_context or "",
|
||||
)
|
||||
else:
|
||||
system_instruction = SYSTEM_INSTRUCTION_ENHANCER
|
||||
prompt = USER_PROMPT_ENHANCEMENT.format(
|
||||
full_story=data,
|
||||
education_theme=theme,
|
||||
random_element=random_element,
|
||||
memory_context=memory_context or "",
|
||||
)
|
||||
|
||||
# Gemini API Payload supports 'system_instruction'
|
||||
payload = {
|
||||
"system_instruction": {"parts": [{"text": system_instruction}]},
|
||||
"contents": [{"parts": [{"text": prompt}]}],
|
||||
"generationConfig": {
|
||||
"responseMimeType": "application/json",
|
||||
"temperature": 0.95,
|
||||
"topP": 0.9,
|
||||
},
|
||||
}
|
||||
|
||||
result = await self._call_api(payload)
|
||||
|
||||
candidates = result.get("candidates") or []
|
||||
if not candidates:
|
||||
raise ValueError("Gemini 未返回内容")
|
||||
|
||||
parts = candidates[0].get("content", {}).get("parts") or []
|
||||
if not parts or "text" not in parts[0]:
|
||||
raise ValueError("Gemini 响应缺少文本")
|
||||
|
||||
response_text = parts[0]["text"]
|
||||
clean_json = response_text
|
||||
if response_text.startswith("```json"):
|
||||
clean_json = re.sub(r"^```json\n|```$", "", response_text)
|
||||
|
||||
try:
|
||||
parsed = json.loads(clean_json)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(f"Gemini 输出 JSON 解析失败: {exc}")
|
||||
|
||||
required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
|
||||
if any(field not in parsed for field in required_fields):
|
||||
raise ValueError("Gemini 输出缺少必要字段")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
"request_success",
|
||||
adapter="gemini",
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
title=parsed["title"],
|
||||
)
|
||||
|
||||
return StoryOutput(
|
||||
mode=parsed["mode"],
|
||||
title=parsed["title"],
|
||||
story_text=parsed["story_text"],
|
||||
cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
|
||||
)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""检查 Gemini API 是否可用。"""
|
||||
try:
|
||||
payload = {
|
||||
"contents": [{"parts": [{"text": "Hi"}]}],
|
||||
"generationConfig": {"maxOutputTokens": 10},
|
||||
}
|
||||
await self._call_api(payload)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def estimated_cost(self) -> float:
|
||||
return 0.001
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=10),
|
||||
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
|
||||
reraise=True,
|
||||
)
|
||||
async def _call_api(self, payload: dict) -> dict:
|
||||
"""调用 Gemini API。"""
|
||||
model = self.config.model or "gemini-2.0-flash"
|
||||
base_url = self.config.api_base or TEXT_API_BASE
|
||||
|
||||
# 智能补全:
|
||||
# 1. 如果用户填了完整路径 (以 /models 结尾),就直接用 (支持 v1 或 v1beta)
|
||||
if self.config.api_base and base_url.rstrip("/").endswith("/models"):
|
||||
pass
|
||||
# 2. 如果没填路径 (只是域名),默认补全代码适配的 /v1beta/models
|
||||
elif self.config.api_base:
|
||||
base_url = f"{base_url.rstrip('/')}/v1beta/models"
|
||||
|
||||
url = f"{base_url}/{model}:generateContent?key={self.config.api_key}"
|
||||
timeout = self.config.timeout_ms / 1000
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
11
backend/app/services/adapters/text/models.py
Normal file
11
backend/app/services/adapters/text/models.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class StoryOutput:
|
||||
"""故事生成输出。"""
|
||||
mode: Literal["generated", "enhanced"]
|
||||
title: str
|
||||
story_text: str
|
||||
cover_prompt_suggestion: str
|
||||
172
backend/app/services/adapters/text/openai.py
Normal file
172
backend/app/services/adapters/text/openai.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""OpenAI 文本生成适配器。"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.prompts import (
|
||||
RANDOM_ELEMENTS,
|
||||
SYSTEM_INSTRUCTION_ENHANCER,
|
||||
SYSTEM_INSTRUCTION_STORYTELLER,
|
||||
USER_PROMPT_ENHANCEMENT,
|
||||
USER_PROMPT_GENERATION,
|
||||
)
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
OPENAI_API_BASE = "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
|
||||
|
||||
|
||||
@AdapterRegistry.register("text", "openai")
|
||||
class OpenAITextAdapter(BaseAdapter[StoryOutput]):
|
||||
"""OpenAI 文本生成适配器。"""
|
||||
|
||||
adapter_type = "text"
|
||||
adapter_name = "openai"
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
input_type: Literal["keywords", "full_story"],
|
||||
data: str,
|
||||
education_theme: str | None = None,
|
||||
memory_context: str | None = None,
|
||||
**kwargs,
|
||||
) -> StoryOutput:
|
||||
"""生成或润色故事。"""
|
||||
start_time = time.time()
|
||||
logger.info("openai_text_request_start", input_type=input_type, data_length=len(data))
|
||||
|
||||
theme = education_theme or "成长"
|
||||
random_element = random.choice(RANDOM_ELEMENTS)
|
||||
|
||||
if input_type == "keywords":
|
||||
system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
|
||||
prompt = USER_PROMPT_GENERATION.format(
|
||||
keywords=data,
|
||||
education_theme=theme,
|
||||
random_element=random_element,
|
||||
memory_context=memory_context or "",
|
||||
)
|
||||
else:
|
||||
system_instruction = SYSTEM_INSTRUCTION_ENHANCER
|
||||
prompt = USER_PROMPT_ENHANCEMENT.format(
|
||||
full_story=data,
|
||||
education_theme=theme,
|
||||
random_element=random_element,
|
||||
memory_context=memory_context or "",
|
||||
)
|
||||
|
||||
model = self.config.model or "gpt-4o-mini"
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_instruction,
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
"response_format": {"type": "json_object"},
|
||||
"temperature": 0.95,
|
||||
"top_p": 0.9,
|
||||
}
|
||||
|
||||
result = await self._call_api(payload)
|
||||
|
||||
choices = result.get("choices") or []
|
||||
if not choices:
|
||||
raise ValueError("OpenAI 未返回内容")
|
||||
|
||||
response_text = choices[0].get("message", {}).get("content", "")
|
||||
if not response_text:
|
||||
raise ValueError("OpenAI 响应缺少文本")
|
||||
|
||||
clean_json = response_text
|
||||
if response_text.startswith("```json"):
|
||||
clean_json = re.sub(r"^```json\n|```$", "", response_text)
|
||||
|
||||
try:
|
||||
parsed = json.loads(clean_json)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(f"OpenAI 输出 JSON 解析失败: {exc}")
|
||||
|
||||
required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
|
||||
if any(field not in parsed for field in required_fields):
|
||||
raise ValueError("OpenAI 输出缺少必要字段")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
"openai_text_request_success",
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
title=parsed["title"],
|
||||
mode=parsed["mode"],
|
||||
)
|
||||
|
||||
return StoryOutput(
|
||||
mode=parsed["mode"],
|
||||
title=parsed["title"],
|
||||
story_text=parsed["story_text"],
|
||||
cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
|
||||
)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(httpx.HTTPStatusError),
|
||||
)
|
||||
async def _call_api(self, payload: dict) -> dict:
|
||||
"""调用 OpenAI API,带重试机制。"""
|
||||
url = self.config.api_base or OPENAI_API_BASE
|
||||
|
||||
# 智能补全: 如果用户只填了 Base URL,自动补全路径
|
||||
if self.config.api_base and not url.endswith("/chat/completions"):
|
||||
base = url.rstrip("/")
|
||||
url = f"{base}/chat/completions"
|
||||
|
||||
timeout = self.config.timeout_ms / 1000
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.config.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""检查 OpenAI API 是否可用。"""
|
||||
try:
|
||||
payload = {
|
||||
"model": self.config.model or "gpt-4o-mini",
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"max_tokens": 5,
|
||||
}
|
||||
await self._call_api(payload)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def estimated_cost(self) -> float:
|
||||
"""预估文本生成成本 (USD)。"""
|
||||
return 0.01
|
||||
5
backend/app/services/adapters/tts/__init__.py
Normal file
5
backend/app/services/adapters/tts/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""TTS 语音合成适配器。"""
|
||||
|
||||
from app.services.adapters.tts import edge_tts as _tts_edge_tts_adapter # noqa: F401
|
||||
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
|
||||
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
|
||||
66
backend/app/services/adapters/tts/edge_tts.py
Normal file
66
backend/app/services/adapters/tts/edge_tts.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""EdgeTTS 免费语音生成适配器。"""
|
||||
|
||||
import time
|
||||
|
||||
import edge_tts
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 默认中文女声 (晓晓)
|
||||
DEFAULT_VOICE = "zh-CN-XiaoxiaoNeural"
|
||||
|
||||
|
||||
@AdapterRegistry.register("tts", "edge_tts")
|
||||
class EdgeTTSAdapter(BaseAdapter[bytes]):
|
||||
"""EdgeTTS 语音生成适配器 (Free)。
|
||||
|
||||
不需要 API Key。
|
||||
"""
|
||||
|
||||
adapter_type = "tts"
|
||||
adapter_name = "edge_tts"
|
||||
|
||||
async def execute(self, text: str, **kwargs) -> bytes:
|
||||
"""生成语音。"""
|
||||
# 支持动态指定音色
|
||||
voice = kwargs.get("voice") or self.config.model or DEFAULT_VOICE
|
||||
|
||||
start_time = time.time()
|
||||
logger.info("edge_tts_generate_start", text_length=len(text), voice=voice)
|
||||
|
||||
# EdgeTTS 只能输出到文件,我们需要用临时文件周转一下
|
||||
# 或者直接 capture stream (communicate) 但 edge-tts 库主要面向文件
|
||||
|
||||
# 优化: 使用 communicate 直接获取 bytes,无需磁盘IO
|
||||
communicate = edge_tts.Communicate(text, voice)
|
||||
|
||||
audio_data = b""
|
||||
async for chunk in communicate.stream():
|
||||
if chunk["type"] == "audio":
|
||||
audio_data += chunk["data"]
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
"edge_tts_generate_success",
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
audio_size_bytes=len(audio_data),
|
||||
)
|
||||
|
||||
return audio_data
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""检查 EdgeTTS 是否可用 (网络连通性)。"""
|
||||
try:
|
||||
# 简单生成一个词
|
||||
await self.execute("Hi")
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def estimated_cost(self) -> float:
|
||||
return 0.0 # Free!
|
||||
104
backend/app/services/adapters/tts/elevenlabs.py
Normal file
104
backend/app/services/adapters/tts/elevenlabs.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""ElevenLabs TTS 语音合成适配器。"""
|
||||
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
ELEVENLABS_API_BASE = "https://api.elevenlabs.io/v1"
|
||||
DEFAULT_VOICE_ID = "21m00Tcm4TlvDq8ikWAM" # Rachel
|
||||
|
||||
|
||||
@AdapterRegistry.register("tts", "elevenlabs")
|
||||
class ElevenLabsTtsAdapter(BaseAdapter[bytes]):
|
||||
"""ElevenLabs TTS 语音合成适配器,返回 MP3 bytes。"""
|
||||
|
||||
adapter_type = "tts"
|
||||
adapter_name = "elevenlabs"
|
||||
|
||||
def __init__(self, config: AdapterConfig):
|
||||
super().__init__(config)
|
||||
self.api_base = config.api_base or ELEVENLABS_API_BASE
|
||||
|
||||
async def execute(self, text: str, **kwargs) -> bytes:
|
||||
"""将文本转换为语音 MP3 bytes。"""
|
||||
start_time = time.time()
|
||||
logger.info("elevenlabs_tts_start", text_length=len(text))
|
||||
|
||||
voice_id = kwargs.get("voice_id") or DEFAULT_VOICE_ID
|
||||
model_id = kwargs.get("model") or self.config.model or "eleven_multilingual_v2"
|
||||
stability = kwargs.get("stability", 0.5)
|
||||
similarity_boost = kwargs.get("similarity_boost", 0.75)
|
||||
|
||||
url = f"{self.api_base}/text-to-speech/{voice_id}"
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
"model_id": model_id,
|
||||
"voice_settings": {
|
||||
"stability": stability,
|
||||
"similarity_boost": similarity_boost,
|
||||
},
|
||||
}
|
||||
|
||||
audio_bytes = await self._call_api(url, payload)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
"elevenlabs_tts_success",
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
audio_size_bytes=len(audio_bytes),
|
||||
)
|
||||
|
||||
return audio_bytes
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""检查 ElevenLabs API 是否可用。"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.get(
|
||||
f"{self.api_base}/voices",
|
||||
headers={"xi-api-key": self.config.api_key},
|
||||
)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def estimated_cost(self) -> float:
|
||||
"""预估每千字符成本 (USD)。"""
|
||||
return 0.03
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=10),
|
||||
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
|
||||
reraise=True,
|
||||
)
|
||||
async def _call_api(self, url: str, payload: dict) -> bytes:
|
||||
"""调用 ElevenLabs API,带重试机制。"""
|
||||
timeout = self.config.timeout_ms / 1000
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers={
|
||||
"xi-api-key": self.config.api_key,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "audio/mpeg",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
149
backend/app/services/adapters/tts/minimax.py
Normal file
149
backend/app/services/adapters/tts/minimax.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""MiniMax 语音生成适配器 (T2A V2)。"""
|
||||
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# MiniMax API 配置
|
||||
DEFAULT_API_URL = "https://api.minimaxi.com/v1/t2a_v2"
|
||||
DEFAULT_MODEL = "speech-2.6-turbo"
|
||||
|
||||
@AdapterRegistry.register("tts", "minimax")
|
||||
class MiniMaxTTSAdapter(BaseAdapter[bytes]):
|
||||
"""MiniMax 语音生成适配器。
|
||||
|
||||
需要配置:
|
||||
- api_key: MiniMax API Key
|
||||
- minimax_group_id: 可选 (取决于使用的模型/账户类型)
|
||||
"""
|
||||
|
||||
adapter_type = "tts"
|
||||
adapter_name = "minimax"
|
||||
|
||||
def __init__(self, config: AdapterConfig):
|
||||
super().__init__(config)
|
||||
self.api_url = DEFAULT_API_URL
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
text: str,
|
||||
voice_id: str | None = None,
|
||||
model: str | None = None,
|
||||
speed: float | None = None,
|
||||
vol: float | None = None,
|
||||
pitch: int | None = None,
|
||||
emotion: str | None = None,
|
||||
**kwargs,
|
||||
) -> bytes:
|
||||
"""生成语音。"""
|
||||
# 1. 优先使用传入参数
|
||||
# 2. 其次使用 Adapter 配置里的 default
|
||||
# 3. 最后使用系统默认值
|
||||
model = model or self.config.model or DEFAULT_MODEL
|
||||
|
||||
cfg = self.config.extra_config or {}
|
||||
|
||||
voice_id = voice_id or cfg.get("voice_id") or "male-qn-qingse"
|
||||
speed = speed if speed is not None else (cfg.get("speed") or 1.0)
|
||||
vol = vol if vol is not None else (cfg.get("vol") or 1.0)
|
||||
pitch = pitch if pitch is not None else (cfg.get("pitch") or 0)
|
||||
emotion = emotion or cfg.get("emotion")
|
||||
group_id = kwargs.get("group_id") or settings.minimax_group_id
|
||||
|
||||
url = self.api_url
|
||||
if group_id:
|
||||
url = f"{self.api_url}?GroupId={group_id}"
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"text": text,
|
||||
"stream": False,
|
||||
"voice_setting": {
|
||||
"voice_id": voice_id,
|
||||
"speed": speed,
|
||||
"vol": vol,
|
||||
"pitch": pitch,
|
||||
},
|
||||
"audio_setting": {
|
||||
"sample_rate": 32000,
|
||||
"bitrate": 128000,
|
||||
"format": "mp3",
|
||||
"channel": 1
|
||||
}
|
||||
}
|
||||
|
||||
if emotion:
|
||||
payload["voice_setting"]["emotion"] = emotion
|
||||
|
||||
start_time = time.time()
|
||||
logger.info("minimax_generate_start", text_length=len(text), model=model)
|
||||
|
||||
result = await self._call_api(url, payload)
|
||||
|
||||
# 错误处理
|
||||
if result.get("base_resp", {}).get("status_code") != 0:
|
||||
error_msg = result.get("base_resp", {}).get("status_msg", "未知错误")
|
||||
raise ValueError(f"MiniMax API 错误: {error_msg}")
|
||||
|
||||
# Hex 解码 (关键逻辑,从 primary.py 迁移)
|
||||
hex_audio = result.get("data", {}).get("audio")
|
||||
if not hex_audio:
|
||||
raise ValueError("API 响应中未找到音频数据 (data.audio)")
|
||||
|
||||
try:
|
||||
audio_bytes = bytes.fromhex(hex_audio)
|
||||
except ValueError:
|
||||
raise ValueError("MiniMax 返回的音频数据不是有效的 Hex 字符串")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
"minimax_generate_success",
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
audio_size_bytes=len(audio_bytes),
|
||||
)
|
||||
|
||||
return audio_bytes
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""检查 Minimax API 是否可用。"""
|
||||
try:
|
||||
# 尝试生成极短文本
|
||||
await self.execute("Hi")
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=10),
|
||||
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
|
||||
reraise=True,
|
||||
)
|
||||
async def _call_api(self, url: str, payload: dict) -> dict:
|
||||
"""调用 API,带重试机制。"""
|
||||
timeout = self.config.timeout_ms / 1000
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.config.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
196
backend/app/services/cost_tracker.py
Normal file
196
backend/app/services/cost_tracker.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""成本追踪服务。
|
||||
|
||||
记录 API 调用成本,支持预算控制。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.admin_models import CostRecord, UserBudget
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BudgetExceededError(Exception):
|
||||
"""预算超限错误。"""
|
||||
|
||||
def __init__(self, limit_type: str, used: Decimal, limit: Decimal):
|
||||
self.limit_type = limit_type
|
||||
self.used = used
|
||||
self.limit = limit
|
||||
super().__init__(f"{limit_type} 预算已超限: {used}/{limit} USD")
|
||||
|
||||
|
||||
class CostTracker:
|
||||
"""成本追踪器。"""
|
||||
|
||||
async def record_cost(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
provider_name: str,
|
||||
capability: str,
|
||||
estimated_cost: float,
|
||||
provider_id: str | None = None,
|
||||
) -> CostRecord:
|
||||
"""记录一次 API 调用成本。"""
|
||||
record = CostRecord(
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
provider_name=provider_name,
|
||||
capability=capability,
|
||||
estimated_cost=Decimal(str(estimated_cost)),
|
||||
)
|
||||
db.add(record)
|
||||
await db.commit()
|
||||
|
||||
logger.debug(
|
||||
"cost_recorded",
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
capability=capability,
|
||||
cost=estimated_cost,
|
||||
)
|
||||
return record
|
||||
|
||||
async def get_daily_cost(self, db: AsyncSession, user_id: str) -> Decimal:
|
||||
"""获取用户今日成本。"""
|
||||
today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
result = await db.execute(
|
||||
select(func.sum(CostRecord.estimated_cost)).where(
|
||||
CostRecord.user_id == user_id,
|
||||
CostRecord.timestamp >= today_start,
|
||||
)
|
||||
)
|
||||
total = result.scalar()
|
||||
return Decimal(str(total)) if total else Decimal("0")
|
||||
|
||||
async def get_monthly_cost(self, db: AsyncSession, user_id: str) -> Decimal:
|
||||
"""获取用户本月成本。"""
|
||||
now = datetime.utcnow()
|
||||
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
result = await db.execute(
|
||||
select(func.sum(CostRecord.estimated_cost)).where(
|
||||
CostRecord.user_id == user_id,
|
||||
CostRecord.timestamp >= month_start,
|
||||
)
|
||||
)
|
||||
total = result.scalar()
|
||||
return Decimal(str(total)) if total else Decimal("0")
|
||||
|
||||
async def get_cost_by_capability(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
days: int = 30,
|
||||
) -> dict[str, Decimal]:
|
||||
"""按能力类型统计成本。"""
|
||||
since = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
result = await db.execute(
|
||||
select(CostRecord.capability, func.sum(CostRecord.estimated_cost))
|
||||
.where(CostRecord.user_id == user_id, CostRecord.timestamp >= since)
|
||||
.group_by(CostRecord.capability)
|
||||
)
|
||||
return {row[0]: Decimal(str(row[1])) for row in result.all()}
|
||||
|
||||
async def check_budget(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
estimated_cost: float,
|
||||
) -> bool:
|
||||
"""检查预算是否允许此次调用。
|
||||
|
||||
Returns:
|
||||
True 如果允许,否则抛出 BudgetExceededError
|
||||
"""
|
||||
budget = await self.get_user_budget(db, user_id)
|
||||
if not budget or not budget.enabled:
|
||||
return True
|
||||
|
||||
# 检查日预算
|
||||
daily_cost = await self.get_daily_cost(db, user_id)
|
||||
if daily_cost + Decimal(str(estimated_cost)) > budget.daily_limit_usd:
|
||||
raise BudgetExceededError("日", daily_cost, budget.daily_limit_usd)
|
||||
|
||||
# 检查月预算
|
||||
monthly_cost = await self.get_monthly_cost(db, user_id)
|
||||
if monthly_cost + Decimal(str(estimated_cost)) > budget.monthly_limit_usd:
|
||||
raise BudgetExceededError("月", monthly_cost, budget.monthly_limit_usd)
|
||||
|
||||
return True
|
||||
|
||||
async def get_user_budget(self, db: AsyncSession, user_id: str) -> UserBudget | None:
|
||||
"""获取用户预算配置。"""
|
||||
result = await db.execute(
|
||||
select(UserBudget).where(UserBudget.user_id == user_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def set_user_budget(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
daily_limit: float | None = None,
|
||||
monthly_limit: float | None = None,
|
||||
alert_threshold: float | None = None,
|
||||
enabled: bool | None = None,
|
||||
) -> UserBudget:
|
||||
"""设置用户预算。"""
|
||||
budget = await self.get_user_budget(db, user_id)
|
||||
|
||||
if budget is None:
|
||||
budget = UserBudget(user_id=user_id)
|
||||
db.add(budget)
|
||||
|
||||
if daily_limit is not None:
|
||||
budget.daily_limit_usd = Decimal(str(daily_limit))
|
||||
if monthly_limit is not None:
|
||||
budget.monthly_limit_usd = Decimal(str(monthly_limit))
|
||||
if alert_threshold is not None:
|
||||
budget.alert_threshold = Decimal(str(alert_threshold))
|
||||
if enabled is not None:
|
||||
budget.enabled = enabled
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(budget)
|
||||
return budget
|
||||
|
||||
async def get_cost_summary(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
) -> dict:
|
||||
"""获取用户成本摘要。"""
|
||||
daily = await self.get_daily_cost(db, user_id)
|
||||
monthly = await self.get_monthly_cost(db, user_id)
|
||||
by_capability = await self.get_cost_by_capability(db, user_id)
|
||||
budget = await self.get_user_budget(db, user_id)
|
||||
|
||||
return {
|
||||
"daily_cost_usd": float(daily),
|
||||
"monthly_cost_usd": float(monthly),
|
||||
"by_capability": {k: float(v) for k, v in by_capability.items()},
|
||||
"budget": {
|
||||
"daily_limit_usd": float(budget.daily_limit_usd) if budget else None,
|
||||
"monthly_limit_usd": float(budget.monthly_limit_usd) if budget else None,
|
||||
"daily_usage_percent": float(daily / budget.daily_limit_usd * 100)
|
||||
if budget and budget.daily_limit_usd
|
||||
else None,
|
||||
"monthly_usage_percent": float(monthly / budget.monthly_limit_usd * 100)
|
||||
if budget and budget.monthly_limit_usd
|
||||
else None,
|
||||
"enabled": budget.enabled if budget else False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
cost_tracker = CostTracker()
|
||||
471
backend/app/services/memory_service.py
Normal file
471
backend/app/services/memory_service.py
Normal file
@@ -0,0 +1,471 @@
|
||||
"""Memory service handles memory retrieval, scoring, and prompt injection."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import ChildProfile, MemoryItem, StoryUniverse
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MemoryType:
|
||||
"""记忆类型常量及配置。"""
|
||||
|
||||
# 基础类型
|
||||
RECENT_STORY = "recent_story"
|
||||
FAVORITE_CHARACTER = "favorite_character"
|
||||
SCARY_ELEMENT = "scary_element"
|
||||
VOCABULARY_GROWTH = "vocabulary_growth"
|
||||
EMOTIONAL_HIGHLIGHT = "emotional_highlight"
|
||||
|
||||
# Phase 1 新增类型
|
||||
READING_PREFERENCE = "reading_preference" # 阅读偏好
|
||||
MILESTONE = "milestone" # 里程碑事件
|
||||
SKILL_MASTERED = "skill_mastered" # 掌握的技能
|
||||
|
||||
# 类型配置: (默认权重, 默认TTL天数, 描述)
|
||||
CONFIG = {
|
||||
RECENT_STORY: (1.0, 30, "最近阅读的故事"),
|
||||
FAVORITE_CHARACTER: (1.5, None, "喜欢的角色"), # None = 永久
|
||||
SCARY_ELEMENT: (2.0, None, "回避的元素"), # 高权重,永久有效
|
||||
VOCABULARY_GROWTH: (0.8, 90, "词汇积累"),
|
||||
EMOTIONAL_HIGHLIGHT: (1.2, 60, "情感高光"),
|
||||
READING_PREFERENCE: (1.0, None, "阅读偏好"),
|
||||
MILESTONE: (1.5, None, "里程碑事件"),
|
||||
SKILL_MASTERED: (1.0, 180, "掌握的技能"),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_default_weight(cls, memory_type: str) -> float:
|
||||
"""获取类型的默认权重。"""
|
||||
config = cls.CONFIG.get(memory_type)
|
||||
return config[0] if config else 1.0
|
||||
|
||||
@classmethod
|
||||
def get_default_ttl(cls, memory_type: str) -> int | None:
|
||||
"""获取类型的默认 TTL 天数。"""
|
||||
config = cls.CONFIG.get(memory_type)
|
||||
return config[1] if config else None
|
||||
|
||||
|
||||
def _decay_factor(days: float) -> float:
|
||||
"""计算时间衰减因子。"""
|
||||
if days <= 7:
|
||||
return 1.0
|
||||
if days <= 30:
|
||||
return 0.7
|
||||
if days <= 90:
|
||||
return 0.4
|
||||
return 0.2
|
||||
|
||||
|
||||
async def build_enhanced_memory_context(
|
||||
profile_id: str | None,
|
||||
universe_id: str | None,
|
||||
db: AsyncSession,
|
||||
) -> str | None:
|
||||
"""构建增强版记忆上下文(自然语言 Prompt)。"""
|
||||
if not profile_id and not universe_id:
|
||||
return None
|
||||
|
||||
context_parts: list[str] = []
|
||||
|
||||
# 1. 基础档案 (Identity Layer)
|
||||
if profile_id:
|
||||
profile = await db.scalar(select(ChildProfile).where(ChildProfile.id == profile_id))
|
||||
if profile:
|
||||
context_parts.append(f"【目标读者】\n姓名:{profile.name}")
|
||||
if profile.age:
|
||||
context_parts.append(f"年龄:{profile.age}岁")
|
||||
if profile.interests:
|
||||
context_parts.append(f"兴趣爱好:{'、'.join(profile.interests)}")
|
||||
if profile.growth_themes:
|
||||
context_parts.append(f"当前成长关注点:{'、'.join(profile.growth_themes)}")
|
||||
context_parts.append("") # 空行
|
||||
|
||||
# 2. 故事宇宙 (Universe Layer)
|
||||
if universe_id:
|
||||
universe = await db.scalar(select(StoryUniverse).where(StoryUniverse.id == universe_id))
|
||||
if universe:
|
||||
context_parts.append("【故事宇宙设定】")
|
||||
context_parts.append(f"世界观:{universe.name}")
|
||||
|
||||
# 主角
|
||||
protagonist = universe.protagonist or {}
|
||||
p_desc = f"{protagonist.get('name', '主角')} ({protagonist.get('personality', '')})"
|
||||
context_parts.append(f"主角设定:{p_desc}")
|
||||
|
||||
# 常驻角色
|
||||
if universe.recurring_characters:
|
||||
chars = [f"{c.get('name')} ({c.get('type')})" for c in universe.recurring_characters if isinstance(c, dict)]
|
||||
context_parts.append(f"已知伙伴:{'、'.join(chars)}")
|
||||
|
||||
# 成就
|
||||
if universe.achievements:
|
||||
badges = [str(a.get('type')) for a in universe.achievements if isinstance(a, dict)]
|
||||
if badges:
|
||||
context_parts.append(f"已获荣誉:{'、'.join(badges[:5])}")
|
||||
|
||||
context_parts.append("")
|
||||
|
||||
# 3. 动态记忆 (Working Memory)
|
||||
if profile_id:
|
||||
memories = await _fetch_scored_memories(profile_id, universe_id, db)
|
||||
if memories:
|
||||
memory_text = _format_memories_to_prompt(memories)
|
||||
if memory_text:
|
||||
context_parts.append("【关键记忆回忆】(请在故事中自然地融入或致敬以下元素)")
|
||||
context_parts.append(memory_text)
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
|
||||
async def _fetch_scored_memories(
|
||||
profile_id: str,
|
||||
universe_id: str | None,
|
||||
db: AsyncSession,
|
||||
limit: int = 8
|
||||
) -> list[MemoryItem]:
|
||||
"""获取并评分记忆项,返回 Top N。"""
|
||||
query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
|
||||
if universe_id:
|
||||
query = query.where(
|
||||
(MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
|
||||
)
|
||||
# 取最近 50 条进行评分
|
||||
query = query.order_by(MemoryItem.last_used_at.desc(), MemoryItem.created_at.desc()).limit(50)
|
||||
|
||||
result = await db.execute(query)
|
||||
items = result.scalars().all()
|
||||
|
||||
scored: list[tuple[float, MemoryItem]] = []
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
for item in items:
|
||||
reference = item.last_used_at or item.created_at or now
|
||||
delta_days = max((now - reference).total_seconds() / 86400, 0)
|
||||
|
||||
if item.ttl_days and delta_days > item.ttl_days:
|
||||
continue
|
||||
|
||||
score = (item.base_weight or 1.0) * _decay_factor(delta_days)
|
||||
if score <= 0.1: # 忽略低权重
|
||||
continue
|
||||
|
||||
scored.append((score, item))
|
||||
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
return [item for _, item in scored[:limit]]
|
||||
|
||||
|
||||
def _format_memories_to_prompt(memories: list[MemoryItem]) -> str:
|
||||
"""将记忆项转换为自然语言指令。"""
|
||||
lines = []
|
||||
|
||||
# 分类处理
|
||||
recent_stories = []
|
||||
favorites = []
|
||||
scary = []
|
||||
vocab = []
|
||||
|
||||
for m in memories:
|
||||
if m.type == MemoryType.RECENT_STORY:
|
||||
recent_stories.append(m)
|
||||
elif m.type == MemoryType.FAVORITE_CHARACTER:
|
||||
favorites.append(m)
|
||||
elif m.type == MemoryType.SCARY_ELEMENT:
|
||||
scary.append(m)
|
||||
elif m.type == MemoryType.VOCABULARY_GROWTH:
|
||||
vocab.append(m)
|
||||
|
||||
# 1. 喜欢的角色
|
||||
if favorites:
|
||||
names = []
|
||||
for m in favorites:
|
||||
val = m.value
|
||||
if isinstance(val, dict):
|
||||
names.append(f"{val.get('name')} ({val.get('description', '')})")
|
||||
if names:
|
||||
lines.append(f"- 孩子特别喜欢这些角色,可以让他们客串出场:{', '.join(names)}")
|
||||
|
||||
# 2. 避雷区
|
||||
if scary:
|
||||
items = []
|
||||
for m in scary:
|
||||
val = m.value
|
||||
if isinstance(val, dict):
|
||||
items.append(val.get('keyword', ''))
|
||||
elif isinstance(val, str):
|
||||
items.append(val)
|
||||
if items:
|
||||
lines.append(f"- 【注意禁止】不要出现以下让孩子害怕的元素:{', '.join(items)}")
|
||||
|
||||
# 3. 近期故事 (取最近 2 个)
|
||||
if recent_stories:
|
||||
lines.append("- 近期经历(可作为彩蛋提及):")
|
||||
for m in recent_stories[:2]:
|
||||
val = m.value
|
||||
if isinstance(val, dict):
|
||||
title = val.get('title', '未知故事')
|
||||
lines.append(f" * 之前读过《{title}》")
|
||||
|
||||
# 4. 词汇积累
|
||||
if vocab:
|
||||
words = []
|
||||
for m in vocab:
|
||||
val = m.value
|
||||
if isinstance(val, dict):
|
||||
words.append(val.get('word'))
|
||||
if words:
|
||||
lines.append(f"- 已掌握词汇(可适当复现以巩固):{', '.join([w for w in words if w])}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def prune_expired_memories(db: AsyncSession) -> int:
|
||||
"""清理过期的记忆项。
|
||||
|
||||
Returns:
|
||||
删除的记录数量
|
||||
"""
|
||||
from sqlalchemy import delete
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# 查找所有设置了 TTL 的项目
|
||||
stmt = select(MemoryItem).where(MemoryItem.ttl_days.is_not(None))
|
||||
result = await db.execute(stmt)
|
||||
candidates = result.scalars().all()
|
||||
|
||||
to_delete_ids = []
|
||||
for item in candidates:
|
||||
if not item.ttl_days:
|
||||
continue
|
||||
|
||||
reference = item.last_used_at or item.created_at or now
|
||||
delta_days = (now - reference).total_seconds() / 86400
|
||||
|
||||
if delta_days > item.ttl_days:
|
||||
to_delete_ids.append(item.id)
|
||||
|
||||
if not to_delete_ids:
|
||||
return 0
|
||||
|
||||
delete_stmt = delete(MemoryItem).where(MemoryItem.id.in_(to_delete_ids))
|
||||
await db.execute(delete_stmt)
|
||||
await db.commit()
|
||||
|
||||
logger.info("memory_pruned", count=len(to_delete_ids))
|
||||
return len(to_delete_ids)
|
||||
|
||||
|
||||
async def create_memory(
|
||||
db: AsyncSession,
|
||||
profile_id: str,
|
||||
memory_type: str,
|
||||
value: dict,
|
||||
universe_id: str | None = None,
|
||||
weight: float | None = None,
|
||||
ttl_days: int | None = None,
|
||||
) -> MemoryItem:
|
||||
"""创建新的记忆项。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
profile_id: 孩子档案 ID
|
||||
memory_type: 记忆类型 (使用 MemoryType 常量)
|
||||
value: 记忆内容 (JSON 格式)
|
||||
universe_id: 可选,关联的故事宇宙 ID
|
||||
weight: 可选,权重 (默认使用类型配置)
|
||||
ttl_days: 可选,过期天数 (默认使用类型配置)
|
||||
|
||||
Returns:
|
||||
创建的 MemoryItem
|
||||
"""
|
||||
memory = MemoryItem(
|
||||
child_profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
type=memory_type,
|
||||
value=value,
|
||||
base_weight=weight or MemoryType.get_default_weight(memory_type),
|
||||
ttl_days=ttl_days if ttl_days is not None else MemoryType.get_default_ttl(memory_type),
|
||||
)
|
||||
db.add(memory)
|
||||
await db.commit()
|
||||
await db.refresh(memory)
|
||||
|
||||
logger.info(
|
||||
"memory_created",
|
||||
memory_id=memory.id,
|
||||
profile_id=profile_id,
|
||||
type=memory_type,
|
||||
)
|
||||
return memory
|
||||
|
||||
|
||||
async def update_memory_usage(db: AsyncSession, memory_id: str) -> None:
|
||||
"""更新记忆的最后使用时间。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
memory_id: 记忆项 ID
|
||||
"""
|
||||
result = await db.execute(select(MemoryItem).where(MemoryItem.id == memory_id))
|
||||
memory = result.scalar_one_or_none()
|
||||
|
||||
if memory:
|
||||
memory.last_used_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
logger.debug("memory_usage_updated", memory_id=memory_id)
|
||||
|
||||
|
||||
async def get_profile_memories(
|
||||
db: AsyncSession,
|
||||
profile_id: str,
|
||||
memory_type: str | None = None,
|
||||
universe_id: str | None = None,
|
||||
limit: int = 50,
|
||||
) -> list[MemoryItem]:
|
||||
"""获取档案的记忆列表。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
profile_id: 孩子档案 ID
|
||||
memory_type: 可选,按类型筛选
|
||||
universe_id: 可选,按宇宙筛选
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
MemoryItem 列表
|
||||
"""
|
||||
query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
|
||||
|
||||
if memory_type:
|
||||
query = query.where(MemoryItem.type == memory_type)
|
||||
|
||||
if universe_id:
|
||||
query = query.where(
|
||||
(MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
|
||||
)
|
||||
|
||||
query = query.order_by(MemoryItem.created_at.desc()).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def create_story_memory(
|
||||
db: AsyncSession,
|
||||
profile_id: str,
|
||||
story_id: int,
|
||||
title: str,
|
||||
summary: str | None = None,
|
||||
keywords: list[str] | None = None,
|
||||
universe_id: str | None = None,
|
||||
) -> MemoryItem:
|
||||
"""为故事创建记忆项。
|
||||
|
||||
这是一个便捷函数,专门用于在故事阅读后创建 recent_story 类型的记忆。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
profile_id: 孩子档案 ID
|
||||
story_id: 故事 ID
|
||||
title: 故事标题
|
||||
summary: 故事梗概
|
||||
keywords: 关键词列表
|
||||
universe_id: 可选,关联的故事宇宙 ID
|
||||
|
||||
Returns:
|
||||
创建的 MemoryItem
|
||||
"""
|
||||
value = {
|
||||
"story_id": story_id,
|
||||
"title": title,
|
||||
"summary": summary or "",
|
||||
"keywords": keywords or [],
|
||||
}
|
||||
|
||||
return await create_memory(
|
||||
db=db,
|
||||
profile_id=profile_id,
|
||||
memory_type=MemoryType.RECENT_STORY,
|
||||
value=value,
|
||||
universe_id=universe_id,
|
||||
)
|
||||
|
||||
|
||||
async def create_character_memory(
|
||||
db: AsyncSession,
|
||||
profile_id: str,
|
||||
name: str,
|
||||
description: str | None = None,
|
||||
source_story_id: int | None = None,
|
||||
affinity_score: float = 1.0,
|
||||
universe_id: str | None = None,
|
||||
) -> MemoryItem:
|
||||
"""为喜欢的角色创建记忆项。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
profile_id: 孩子档案 ID
|
||||
name: 角色名称
|
||||
description: 角色描述
|
||||
source_story_id: 来源故事 ID
|
||||
affinity_score: 喜爱程度 (0.0-1.0)
|
||||
universe_id: 可选,关联的故事宇宙 ID
|
||||
|
||||
Returns:
|
||||
创建的 MemoryItem
|
||||
"""
|
||||
value = {
|
||||
"name": name,
|
||||
"description": description or "",
|
||||
"source_story_id": source_story_id,
|
||||
"affinity_score": min(1.0, max(0.0, affinity_score)),
|
||||
}
|
||||
|
||||
return await create_memory(
|
||||
db=db,
|
||||
profile_id=profile_id,
|
||||
memory_type=MemoryType.FAVORITE_CHARACTER,
|
||||
value=value,
|
||||
universe_id=universe_id,
|
||||
)
|
||||
|
||||
|
||||
async def create_scary_element_memory(
|
||||
db: AsyncSession,
|
||||
profile_id: str,
|
||||
keyword: str,
|
||||
category: str = "other",
|
||||
source_story_id: int | None = None,
|
||||
) -> MemoryItem:
|
||||
"""为回避元素创建记忆项。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
profile_id: 孩子档案 ID
|
||||
keyword: 回避的关键词
|
||||
category: 分类 (creature/scene/action/other)
|
||||
source_story_id: 来源故事 ID
|
||||
|
||||
Returns:
|
||||
创建的 MemoryItem
|
||||
"""
|
||||
value = {
|
||||
"keyword": keyword,
|
||||
"category": category,
|
||||
"source_story_id": source_story_id,
|
||||
}
|
||||
|
||||
return await create_memory(
|
||||
db=db,
|
||||
profile_id=profile_id,
|
||||
memory_type=MemoryType.SCARY_ELEMENT,
|
||||
value=value,
|
||||
)
|
||||
|
||||
31
backend/app/services/provider_cache.py
Normal file
31
backend/app/services/provider_cache.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""In-memory cache for providers loaded from DB."""
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.admin_models import Provider
|
||||
|
||||
ProviderType = Literal["text", "image", "tts", "storybook"]
|
||||
|
||||
_cache: dict[ProviderType, list[Provider]] = defaultdict(list)
|
||||
|
||||
|
||||
async def reload_providers(db: AsyncSession):
|
||||
result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712
|
||||
providers = result.scalars().all()
|
||||
grouped: dict[ProviderType, list[Provider]] = defaultdict(list)
|
||||
for p in providers:
|
||||
grouped[p.type].append(p)
|
||||
# sort by priority desc, then weight desc
|
||||
for k in grouped:
|
||||
grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
|
||||
_cache.clear()
|
||||
_cache.update(grouped)
|
||||
return _cache
|
||||
|
||||
|
||||
def get_providers(provider_type: ProviderType) -> list[Provider]:
|
||||
return _cache.get(provider_type, [])
|
||||
248
backend/app/services/provider_metrics.py
Normal file
248
backend/app/services/provider_metrics.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""供应商指标收集和健康检查服务。"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.admin_models import ProviderHealth, ProviderMetrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 熔断阈值:连续失败次数
|
||||
CIRCUIT_BREAKER_THRESHOLD = 3
|
||||
# 熔断恢复时间(秒)
|
||||
CIRCUIT_BREAKER_RECOVERY_SECONDS = 60
|
||||
|
||||
|
||||
class MetricsCollector:
|
||||
"""供应商调用指标收集器。"""
|
||||
|
||||
async def record_call(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
provider_id: str,
|
||||
success: bool,
|
||||
latency_ms: int | None = None,
|
||||
cost_usd: float | None = None,
|
||||
error_message: str | None = None,
|
||||
request_id: str | None = None,
|
||||
) -> None:
|
||||
"""记录一次 API 调用。"""
|
||||
metric = ProviderMetrics(
|
||||
provider_id=provider_id,
|
||||
success=success,
|
||||
latency_ms=latency_ms,
|
||||
cost_usd=Decimal(str(cost_usd)) if cost_usd else None,
|
||||
error_message=error_message,
|
||||
request_id=request_id,
|
||||
)
|
||||
db.add(metric)
|
||||
await db.commit()
|
||||
|
||||
logger.debug(
|
||||
"metrics_recorded",
|
||||
provider_id=provider_id,
|
||||
success=success,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
async def get_success_rate(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
provider_id: str,
|
||||
window_minutes: int = 60,
|
||||
) -> float:
|
||||
"""获取指定时间窗口内的成功率。"""
|
||||
since = datetime.utcnow() - timedelta(minutes=window_minutes)
|
||||
|
||||
result = await db.execute(
|
||||
select(
|
||||
func.count().filter(ProviderMetrics.success.is_(True)).label("success_count"),
|
||||
func.count().label("total_count"),
|
||||
).where(
|
||||
ProviderMetrics.provider_id == provider_id,
|
||||
ProviderMetrics.timestamp >= since,
|
||||
)
|
||||
)
|
||||
row = result.one()
|
||||
success_count, total_count = row.success_count, row.total_count
|
||||
|
||||
if total_count == 0:
|
||||
return 1.0 # 无数据时假设健康
|
||||
|
||||
return success_count / total_count
|
||||
|
||||
async def get_avg_latency(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
provider_id: str,
|
||||
window_minutes: int = 60,
|
||||
) -> float:
|
||||
"""获取指定时间窗口内的平均延迟(毫秒)。"""
|
||||
since = datetime.utcnow() - timedelta(minutes=window_minutes)
|
||||
|
||||
result = await db.execute(
|
||||
select(func.avg(ProviderMetrics.latency_ms)).where(
|
||||
ProviderMetrics.provider_id == provider_id,
|
||||
ProviderMetrics.timestamp >= since,
|
||||
ProviderMetrics.latency_ms.isnot(None),
|
||||
)
|
||||
)
|
||||
avg = result.scalar()
|
||||
return float(avg) if avg else 0.0
|
||||
|
||||
async def get_total_cost(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
provider_id: str,
|
||||
window_minutes: int = 60,
|
||||
) -> float:
|
||||
"""获取指定时间窗口内的总成本(USD)。"""
|
||||
since = datetime.utcnow() - timedelta(minutes=window_minutes)
|
||||
|
||||
result = await db.execute(
|
||||
select(func.sum(ProviderMetrics.cost_usd)).where(
|
||||
ProviderMetrics.provider_id == provider_id,
|
||||
ProviderMetrics.timestamp >= since,
|
||||
)
|
||||
)
|
||||
total = result.scalar()
|
||||
return float(total) if total else 0.0
|
||||
|
||||
|
||||
class HealthChecker:
|
||||
"""供应商健康检查器。"""
|
||||
|
||||
async def check_provider(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
provider_id: str,
|
||||
adapter: "BaseAdapter",
|
||||
) -> bool:
|
||||
"""执行健康检查并更新状态。"""
|
||||
try:
|
||||
is_healthy = await adapter.health_check()
|
||||
except Exception as e:
|
||||
logger.warning("health_check_failed", provider_id=provider_id, error=str(e))
|
||||
is_healthy = False
|
||||
|
||||
await self.update_health_status(
|
||||
db,
|
||||
provider_id,
|
||||
is_healthy,
|
||||
error=None if is_healthy else "Health check failed",
|
||||
)
|
||||
return is_healthy
|
||||
|
||||
async def update_health_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
provider_id: str,
|
||||
is_healthy: bool,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""更新供应商健康状态(含熔断逻辑)。"""
|
||||
result = await db.execute(
|
||||
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
|
||||
)
|
||||
health = result.scalar_one_or_none()
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
if health is None:
|
||||
health = ProviderHealth(
|
||||
provider_id=provider_id,
|
||||
is_healthy=is_healthy,
|
||||
last_check=now,
|
||||
consecutive_failures=0 if is_healthy else 1,
|
||||
last_error=error,
|
||||
)
|
||||
db.add(health)
|
||||
else:
|
||||
health.last_check = now
|
||||
|
||||
if is_healthy:
|
||||
health.is_healthy = True
|
||||
health.consecutive_failures = 0
|
||||
health.last_error = None
|
||||
else:
|
||||
health.consecutive_failures += 1
|
||||
health.last_error = error
|
||||
|
||||
# 熔断逻辑
|
||||
if health.consecutive_failures >= CIRCUIT_BREAKER_THRESHOLD:
|
||||
health.is_healthy = False
|
||||
logger.warning(
|
||||
"circuit_breaker_triggered",
|
||||
provider_id=provider_id,
|
||||
consecutive_failures=health.consecutive_failures,
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
async def record_call_result(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
provider_id: str,
|
||||
success: bool,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""根据调用结果更新健康状态。"""
|
||||
await self.update_health_status(db, provider_id, success, error)
|
||||
|
||||
async def get_healthy_providers(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
provider_ids: list[str],
|
||||
) -> list[str]:
|
||||
"""获取健康的供应商列表。"""
|
||||
if not provider_ids:
|
||||
return []
|
||||
|
||||
# 查询所有已记录的健康状态
|
||||
result = await db.execute(
|
||||
select(ProviderHealth.provider_id, ProviderHealth.is_healthy).where(
|
||||
ProviderHealth.provider_id.in_(provider_ids),
|
||||
)
|
||||
)
|
||||
health_map = {row[0]: row[1] for row in result.all()}
|
||||
|
||||
# 未记录的供应商默认健康,已记录但不健康的排除
|
||||
return [
|
||||
pid for pid in provider_ids
|
||||
if pid not in health_map or health_map[pid]
|
||||
]
|
||||
|
||||
async def is_healthy(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
provider_id: str,
|
||||
) -> bool:
|
||||
"""检查供应商是否健康。"""
|
||||
result = await db.execute(
|
||||
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
|
||||
)
|
||||
health = result.scalar_one_or_none()
|
||||
|
||||
if health is None:
|
||||
return True # 未记录默认健康
|
||||
|
||||
# 检查是否可以恢复
|
||||
if not health.is_healthy and health.last_check:
|
||||
recovery_time = health.last_check + timedelta(seconds=CIRCUIT_BREAKER_RECOVERY_SECONDS)
|
||||
if datetime.utcnow() >= recovery_time:
|
||||
return True # 允许重试
|
||||
|
||||
return health.is_healthy
|
||||
|
||||
|
||||
# 全局单例
|
||||
metrics_collector = MetricsCollector()
|
||||
health_checker = HealthChecker()
|
||||
432
backend/app/services/provider_router.py
Normal file
432
backend/app/services/provider_router.py
Normal file
@@ -0,0 +1,432 @@
|
||||
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Literal, TypeVar
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters import AdapterConfig, AdapterRegistry
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
from app.services.cost_tracker import cost_tracker
|
||||
from app.services.provider_cache import get_providers
|
||||
from app.services.provider_metrics import health_checker, metrics_collector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db.admin_models import Provider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
ProviderType = Literal["text", "image", "tts", "storybook"]
|
||||
|
||||
|
||||
class RoutingStrategy(str, Enum):
|
||||
"""路由策略枚举。"""
|
||||
|
||||
PRIORITY = "priority" # 按优先级排序(默认)
|
||||
COST = "cost" # 按成本排序
|
||||
LATENCY = "latency" # 按延迟排序
|
||||
ROUND_ROBIN = "round_robin" # 轮询
|
||||
|
||||
|
||||
# 默认配置映射(当 DB 无配置时使用)
|
||||
# 默认配置映射(当 DB 无配置时使用)
|
||||
# 这是“代码级”的默认策略,对应 .env 为空的情况
|
||||
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
|
||||
"text": ["gemini", "openai"],
|
||||
"image": ["cqtai"],
|
||||
"tts": ["minimax", "elevenlabs", "edge_tts"],
|
||||
"storybook": ["gemini"],
|
||||
}
|
||||
|
||||
# API Key 映射:adapter_name -> settings 属性名
|
||||
API_KEY_MAP: dict[str, str] = {
|
||||
# Text
|
||||
"gemini": "text_api_key", # Gemini 还是复用 text_api_key 字段
|
||||
"text_primary": "text_api_key", # 兼容旧别名
|
||||
"openai": "openai_api_key",
|
||||
|
||||
# Image
|
||||
"cqtai": "cqtai_api_key",
|
||||
"image_primary": "image_api_key", # 兼容旧别名
|
||||
|
||||
# TTS
|
||||
"minimax": "minimax_api_key",
|
||||
"elevenlabs": "elevenlabs_api_key",
|
||||
"edge_tts": "tts_api_key", # EdgeTTS 复用 tts_api_key (通常为空)
|
||||
"tts_primary": "tts_api_key", # 兼容旧别名
|
||||
}
|
||||
|
||||
# 轮询计数器
|
||||
_round_robin_counters: dict[ProviderType, int] = {
|
||||
"text": 0,
|
||||
"image": 0,
|
||||
"tts": 0,
|
||||
}
|
||||
|
||||
# 延迟缓存(内存中,简化实现)
|
||||
_latency_cache: dict[str, float] = {}
|
||||
|
||||
|
||||
def _get_api_key(config_ref: str | None, adapter_name: str) -> str:
|
||||
"""根据 config_ref 或适配器名称获取 API Key。"""
|
||||
# 优先使用 config_ref
|
||||
key_attr = API_KEY_MAP.get(config_ref or adapter_name, None)
|
||||
if key_attr:
|
||||
return getattr(settings, key_attr, "")
|
||||
# 回退到适配器名称
|
||||
key_attr = API_KEY_MAP.get(adapter_name, None)
|
||||
if key_attr:
|
||||
return getattr(settings, key_attr, "")
|
||||
return ""
|
||||
|
||||
|
||||
def _get_default_config(adapter_name: str) -> AdapterConfig | None:
|
||||
"""获取适配器的默认配置(无 DB 记录时使用)。返回 None 表示未知适配器。"""
|
||||
|
||||
# --- Text Defaults ---
|
||||
if adapter_name in ("gemini", "text_primary"):
|
||||
return AdapterConfig(
|
||||
api_key=settings.text_api_key,
|
||||
model=settings.text_model or "gemini-2.0-flash",
|
||||
timeout_ms=60000,
|
||||
)
|
||||
if adapter_name == "openai":
|
||||
return AdapterConfig(
|
||||
api_key=getattr(settings, "openai_api_key", ""),
|
||||
model="gpt-4o-mini", # 这里可以从 settings 读取,看需求
|
||||
timeout_ms=60000,
|
||||
)
|
||||
|
||||
# --- Image Defaults ---
|
||||
if adapter_name in ("cqtai"):
|
||||
return AdapterConfig(
|
||||
api_key=getattr(settings, "cqtai_api_key", ""),
|
||||
model="nano-banana-pro", # 默认使用 Pro
|
||||
timeout_ms=120000,
|
||||
)
|
||||
if adapter_name == "image_primary":
|
||||
# 如果还有地方在用 image_primary,暂时映射到快或者其他
|
||||
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
|
||||
return AdapterConfig(
|
||||
api_key=settings.image_api_key,
|
||||
timeout_ms=120000
|
||||
)
|
||||
|
||||
# --- TTS Defaults ---
|
||||
if adapter_name == "minimax":
|
||||
# 传递 group_id 到 Adapter
|
||||
# 目前 AdapterConfig 没有 group_id 字段,我们暂时不改 Base,
|
||||
# 而是假设 Adapter 会从 config (通过 kwargs 或其他方式) 拿。
|
||||
# 实际上我们的 MiniMaxTTSAdapter 还没有处理 group_id。
|
||||
# 最简单的方法:把 group_id 藏在 api_base 里或者让 Adapter 自己去 settings 拿。
|
||||
# 鉴于 _build_config_from_provider 里我们无法传递额外参数给 Adapter.__init__,
|
||||
# 我们这里暂时返回基础配置。
|
||||
return AdapterConfig(
|
||||
api_key=getattr(settings, "minimax_api_key", ""),
|
||||
model="speech-2.6-turbo",
|
||||
timeout_ms=60000,
|
||||
)
|
||||
|
||||
if adapter_name == "elevenlabs":
|
||||
return AdapterConfig(
|
||||
api_key=getattr(settings, "elevenlabs_api_key", ""),
|
||||
timeout_ms=120000,
|
||||
)
|
||||
if adapter_name in ("edge_tts", "tts_primary"):
|
||||
return AdapterConfig(
|
||||
api_key=settings.tts_api_key,
|
||||
api_base=settings.tts_api_base,
|
||||
model=settings.tts_model or "zh-CN-XiaoxiaoNeural",
|
||||
timeout_ms=120000,
|
||||
)
|
||||
|
||||
# --- Others ---
|
||||
if adapter_name in ("storybook_primary", "storybook_gemini"):
|
||||
return AdapterConfig(
|
||||
api_key=settings.text_api_key, # 复用 Gemini key
|
||||
model=settings.text_model,
|
||||
timeout_ms=120000,
|
||||
)
|
||||
|
||||
# 未知适配器返回 None
|
||||
return None
|
||||
|
||||
|
||||
def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
|
||||
"""从 DB Provider 记录构建 AdapterConfig。"""
|
||||
api_key = getattr(provider, "api_key", None) or ""
|
||||
if not api_key:
|
||||
api_key = _get_api_key(provider.config_ref, provider.adapter)
|
||||
|
||||
default = _get_default_config(provider.adapter)
|
||||
if default is None:
|
||||
default = AdapterConfig(api_key="", timeout_ms=60000)
|
||||
|
||||
return AdapterConfig(
|
||||
api_key=api_key or default.api_key,
|
||||
api_base=provider.api_base or default.api_base,
|
||||
model=provider.model or default.model,
|
||||
timeout_ms=provider.timeout_ms or default.timeout_ms,
|
||||
max_retries=provider.max_retries or default.max_retries,
|
||||
extra_config=provider.config_json or {},
|
||||
)
|
||||
|
||||
|
||||
def _get_providers_with_config(
|
||||
provider_type: ProviderType,
|
||||
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
|
||||
"""获取供应商列表及其配置。
|
||||
|
||||
Returns:
|
||||
[(adapter_name, config, provider_or_none), ...] 按优先级排序
|
||||
"""
|
||||
db_providers = get_providers(provider_type)
|
||||
|
||||
if db_providers:
|
||||
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
|
||||
|
||||
settings_map = {
|
||||
"text": settings.text_providers,
|
||||
"image": settings.image_providers,
|
||||
"tts": settings.tts_providers,
|
||||
}
|
||||
names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]
|
||||
result = []
|
||||
for name in names:
|
||||
config = _get_default_config(name)
|
||||
if config is None:
|
||||
logger.warning("unknown_adapter_skipped", adapter=name, provider_type=provider_type)
|
||||
continue
|
||||
result.append((name, config, None))
|
||||
return result
|
||||
|
||||
|
||||
def _sort_by_strategy(
|
||||
providers: list[tuple[str, AdapterConfig, "Provider | None"]],
|
||||
strategy: RoutingStrategy,
|
||||
provider_type: ProviderType,
|
||||
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
|
||||
"""按策略排序供应商列表。"""
|
||||
if strategy == RoutingStrategy.PRIORITY:
|
||||
# 按 priority 降序, weight 降序
|
||||
return sorted(
|
||||
providers,
|
||||
key=lambda x: (-(x[2].priority if x[2] else 0), -(x[2].weight if x[2] else 1)),
|
||||
)
|
||||
|
||||
if strategy == RoutingStrategy.COST:
|
||||
# 按预估成本升序
|
||||
def get_cost(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
|
||||
adapter_class = AdapterRegistry.get(provider_type, item[0])
|
||||
if adapter_class:
|
||||
try:
|
||||
adapter = adapter_class(item[1])
|
||||
return adapter.estimated_cost
|
||||
except Exception:
|
||||
pass
|
||||
return float("inf")
|
||||
|
||||
return sorted(providers, key=get_cost)
|
||||
|
||||
if strategy == RoutingStrategy.LATENCY:
|
||||
# 按历史延迟升序
|
||||
def get_latency(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
|
||||
return _latency_cache.get(item[0], float("inf"))
|
||||
|
||||
return sorted(providers, key=get_latency)
|
||||
|
||||
if strategy == RoutingStrategy.ROUND_ROBIN:
|
||||
# 轮询:旋转列表
|
||||
counter = _round_robin_counters[provider_type]
|
||||
_round_robin_counters[provider_type] = (counter + 1) % max(len(providers), 1)
|
||||
return providers[counter:] + providers[:counter]
|
||||
|
||||
return providers
|
||||
|
||||
|
||||
async def _route_with_failover(
|
||||
provider_type: ProviderType,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
user_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""通用 provider failover 路由。
|
||||
|
||||
Args:
|
||||
provider_type: 供应商类型 (text/image/tts/storybook)
|
||||
strategy: 路由策略
|
||||
db: 数据库会话(可选,用于指标收集和熔断检查)
|
||||
user_id: 用户 ID(可选,用于成本追踪和预算检查)
|
||||
**kwargs: 传递给适配器的参数
|
||||
"""
|
||||
providers = _get_providers_with_config(provider_type)
|
||||
|
||||
if not providers:
|
||||
raise ValueError(f"No {provider_type} providers configured.")
|
||||
|
||||
# 按策略排序
|
||||
sorted_providers = _sort_by_strategy(providers, strategy, provider_type)
|
||||
|
||||
# 如果有 db 会话,过滤掉熔断的供应商
|
||||
if db:
|
||||
healthy_providers = []
|
||||
for item in sorted_providers:
|
||||
name, config, db_provider = item
|
||||
provider_id = db_provider.id if db_provider else name
|
||||
if await health_checker.is_healthy(db, provider_id):
|
||||
healthy_providers.append(item)
|
||||
else:
|
||||
logger.debug("provider_circuit_open", adapter=name, provider_id=provider_id)
|
||||
# 如果所有供应商都熔断,仍然尝试第一个(允许恢复)
|
||||
if not healthy_providers:
|
||||
healthy_providers = sorted_providers[:1]
|
||||
sorted_providers = healthy_providers
|
||||
|
||||
errors: list[str] = []
|
||||
for name, config, db_provider in sorted_providers:
|
||||
adapter_class = AdapterRegistry.get(provider_type, name)
|
||||
if not adapter_class:
|
||||
errors.append(f"{name}: 适配器未注册")
|
||||
continue
|
||||
|
||||
provider_id = db_provider.id if db_provider else name
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
"provider_attempt",
|
||||
provider_type=provider_type,
|
||||
adapter=name,
|
||||
strategy=strategy.value,
|
||||
)
|
||||
|
||||
adapter = adapter_class(config)
|
||||
|
||||
# 执行并计时
|
||||
start_time = time.time()
|
||||
result = await adapter.execute(**kwargs)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 更新延迟缓存
|
||||
_latency_cache[name] = latency_ms
|
||||
|
||||
# 记录成功指标
|
||||
if db:
|
||||
await metrics_collector.record_call(
|
||||
db,
|
||||
provider_id=provider_id,
|
||||
success=True,
|
||||
latency_ms=latency_ms,
|
||||
cost_usd=adapter.estimated_cost,
|
||||
)
|
||||
await health_checker.record_call_result(db, provider_id, success=True)
|
||||
|
||||
# 记录用户成本
|
||||
if user_id:
|
||||
await cost_tracker.record_cost(
|
||||
db,
|
||||
user_id=user_id,
|
||||
provider_name=name,
|
||||
capability=provider_type,
|
||||
estimated_cost=adapter.estimated_cost,
|
||||
provider_id=provider_id if db_provider else None,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"provider_success",
|
||||
provider_type=provider_type,
|
||||
adapter=name,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as exc:
|
||||
error_msg = str(exc)
|
||||
logger.warning(
|
||||
"provider_failed",
|
||||
provider_type=provider_type,
|
||||
adapter=name,
|
||||
error=error_msg,
|
||||
)
|
||||
errors.append(f"{name}: {exc}")
|
||||
|
||||
# 记录失败指标
|
||||
if db:
|
||||
await metrics_collector.record_call(
|
||||
db,
|
||||
provider_id=provider_id,
|
||||
success=False,
|
||||
error_message=error_msg,
|
||||
)
|
||||
await health_checker.record_call_result(
|
||||
db, provider_id, success=False, error=error_msg
|
||||
)
|
||||
|
||||
raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
|
||||
|
||||
|
||||
async def generate_story_content(
|
||||
input_type: Literal["keywords", "full_story"],
|
||||
data: str,
|
||||
education_theme: str | None = None,
|
||||
memory_context: str | None = None,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
) -> StoryOutput:
|
||||
"""生成或润色故事,支持 failover。"""
|
||||
return await _route_with_failover(
|
||||
"text",
|
||||
strategy=strategy,
|
||||
db=db,
|
||||
input_type=input_type,
|
||||
data=data,
|
||||
education_theme=education_theme,
|
||||
memory_context=memory_context,
|
||||
)
|
||||
|
||||
|
||||
async def generate_image(
|
||||
prompt: str,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""生成图片,返回 URL,支持 failover。"""
|
||||
return await _route_with_failover("image", strategy=strategy, db=db, prompt=prompt, **kwargs)
|
||||
|
||||
|
||||
async def text_to_speech(
|
||||
text: str,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
) -> bytes:
|
||||
"""文本转语音,返回 MP3 bytes,支持 failover。"""
|
||||
return await _route_with_failover("tts", strategy=strategy, db=db, text=text)
|
||||
|
||||
|
||||
async def generate_storybook(
|
||||
keywords: str,
|
||||
page_count: int = 6,
|
||||
education_theme: str | None = None,
|
||||
memory_context: str | None = None,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
):
|
||||
"""生成分页故事书,支持 failover。"""
|
||||
from app.services.adapters.storybook.primary import Storybook
|
||||
|
||||
result: Storybook = await _route_with_failover(
|
||||
"storybook",
|
||||
strategy=strategy,
|
||||
db=db,
|
||||
keywords=keywords,
|
||||
page_count=page_count,
|
||||
education_theme=education_theme,
|
||||
memory_context=memory_context,
|
||||
)
|
||||
return result
|
||||
207
backend/app/services/secret_service.py
Normal file
207
backend/app/services/secret_service.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""供应商密钥加密存储服务。
|
||||
|
||||
使用 Fernet 对称加密,密钥从 SECRET_KEY 派生。
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.db.admin_models import ProviderSecret
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SecretEncryptionError(Exception):
|
||||
"""密钥加密/解密错误。"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SecretService:
|
||||
"""供应商密钥加密存储服务。"""
|
||||
|
||||
_fernet: Fernet | None = None
|
||||
|
||||
@classmethod
|
||||
def _get_fernet(cls) -> Fernet:
|
||||
"""获取 Fernet 实例,从 SECRET_KEY 派生加密密钥。"""
|
||||
if cls._fernet is None:
|
||||
# 从 SECRET_KEY 派生 32 字节密钥
|
||||
key_bytes = hashlib.sha256(settings.secret_key.encode()).digest()
|
||||
fernet_key = base64.urlsafe_b64encode(key_bytes)
|
||||
cls._fernet = Fernet(fernet_key)
|
||||
return cls._fernet
|
||||
|
||||
@classmethod
|
||||
def encrypt(cls, plaintext: str) -> str:
|
||||
"""加密明文,返回 base64 编码的密文。
|
||||
|
||||
Args:
|
||||
plaintext: 要加密的明文
|
||||
|
||||
Returns:
|
||||
base64 编码的密文
|
||||
"""
|
||||
if not plaintext:
|
||||
return ""
|
||||
fernet = cls._get_fernet()
|
||||
encrypted = fernet.encrypt(plaintext.encode())
|
||||
return encrypted.decode()
|
||||
|
||||
@classmethod
|
||||
def decrypt(cls, ciphertext: str) -> str:
|
||||
"""解密密文,返回明文。
|
||||
|
||||
Args:
|
||||
ciphertext: base64 编码的密文
|
||||
|
||||
Returns:
|
||||
解密后的明文
|
||||
|
||||
Raises:
|
||||
SecretEncryptionError: 解密失败
|
||||
"""
|
||||
if not ciphertext:
|
||||
return ""
|
||||
try:
|
||||
fernet = cls._get_fernet()
|
||||
decrypted = fernet.decrypt(ciphertext.encode())
|
||||
return decrypted.decode()
|
||||
except InvalidToken as e:
|
||||
logger.error("secret_decrypt_failed", error=str(e))
|
||||
raise SecretEncryptionError("密钥解密失败,可能是 SECRET_KEY 已更改") from e
|
||||
|
||||
@classmethod
|
||||
async def get_secret(cls, db: AsyncSession, name: str) -> str | None:
|
||||
"""从数据库获取并解密密钥。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
name: 密钥名称
|
||||
|
||||
Returns:
|
||||
解密后的密钥值,不存在返回 None
|
||||
"""
|
||||
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
|
||||
secret = result.scalar_one_or_none()
|
||||
if secret is None:
|
||||
return None
|
||||
return cls.decrypt(secret.encrypted_value)
|
||||
|
||||
@classmethod
|
||||
async def set_secret(cls, db: AsyncSession, name: str, value: str) -> ProviderSecret:
|
||||
"""存储或更新加密密钥。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
name: 密钥名称
|
||||
value: 密钥明文值
|
||||
|
||||
Returns:
|
||||
ProviderSecret 实例
|
||||
"""
|
||||
encrypted = cls.encrypt(value)
|
||||
|
||||
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
|
||||
secret = result.scalar_one_or_none()
|
||||
|
||||
if secret is None:
|
||||
secret = ProviderSecret(name=name, encrypted_value=encrypted)
|
||||
db.add(secret)
|
||||
else:
|
||||
secret.encrypted_value = encrypted
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(secret)
|
||||
logger.info("secret_stored", name=name)
|
||||
return secret
|
||||
|
||||
@classmethod
|
||||
async def delete_secret(cls, db: AsyncSession, name: str) -> bool:
|
||||
"""删除密钥。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
name: 密钥名称
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
|
||||
secret = result.scalar_one_or_none()
|
||||
if secret is None:
|
||||
return False
|
||||
|
||||
await db.delete(secret)
|
||||
await db.commit()
|
||||
logger.info("secret_deleted", name=name)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def list_secrets(cls, db: AsyncSession) -> list[str]:
|
||||
"""列出所有密钥名称(不返回值)。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
密钥名称列表
|
||||
"""
|
||||
result = await db.execute(select(ProviderSecret.name))
|
||||
return [row[0] for row in result.fetchall()]
|
||||
|
||||
@classmethod
|
||||
async def get_api_key(
|
||||
cls,
|
||||
db: AsyncSession,
|
||||
provider_api_key: str | None,
|
||||
config_ref: str | None,
|
||||
) -> str | None:
|
||||
"""获取 Provider 的 API Key,按优先级查找。
|
||||
|
||||
优先级:
|
||||
1. provider.api_key (数据库明文/加密)
|
||||
2. provider.config_ref 指向的 ProviderSecret
|
||||
3. 环境变量 (config_ref 作为变量名)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider_api_key: Provider 表中的 api_key 字段
|
||||
config_ref: Provider 表中的 config_ref 字段
|
||||
|
||||
Returns:
|
||||
API Key 或 None
|
||||
"""
|
||||
# 1. 直接使用 provider.api_key
|
||||
if provider_api_key:
|
||||
# 尝试解密,如果失败则当作明文
|
||||
try:
|
||||
decrypted = cls.decrypt(provider_api_key)
|
||||
if decrypted:
|
||||
return decrypted
|
||||
except SecretEncryptionError:
|
||||
pass
|
||||
return provider_api_key
|
||||
|
||||
# 2. 从 ProviderSecret 表查找
|
||||
if config_ref:
|
||||
secret_value = await cls.get_secret(db, config_ref)
|
||||
if secret_value:
|
||||
return secret_value
|
||||
|
||||
# 3. 从环境变量查找
|
||||
env_value = getattr(settings, config_ref.lower(), None)
|
||||
if env_value:
|
||||
return env_value
|
||||
|
||||
return None
|
||||
3
backend/app/tasks/__init__.py
Normal file
3
backend/app/tasks/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Celery tasks package."""
|
||||
|
||||
from . import achievements, memory, push_notifications # noqa: F401
|
||||
82
backend/app/tasks/achievements.py
Normal file
82
backend/app/tasks/achievements.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Celery tasks for achievements."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.celery_app import celery_app
|
||||
from app.core.logging import get_logger
|
||||
from app.db.database import _get_session_factory
|
||||
from app.db.models import Story, StoryUniverse
|
||||
from app.services.achievement_extractor import extract_achievements
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@celery_app.task
|
||||
def extract_story_achievements(story_id: int, universe_id: str) -> None:
|
||||
"""Extract achievements and update universe."""
|
||||
asyncio.run(_extract_story_achievements(story_id, universe_id))
|
||||
|
||||
|
||||
async def _extract_story_achievements(story_id: int, universe_id: str) -> None:
|
||||
session_factory = _get_session_factory()
|
||||
async with session_factory() as session:
|
||||
result = await session.execute(select(Story).where(Story.id == story_id))
|
||||
story = result.scalar_one_or_none()
|
||||
if not story:
|
||||
logger.warning("achievement_task_story_missing", story_id=story_id)
|
||||
return
|
||||
|
||||
result = await session.execute(
|
||||
select(StoryUniverse).where(StoryUniverse.id == universe_id)
|
||||
)
|
||||
universe = result.scalar_one_or_none()
|
||||
if not universe:
|
||||
logger.warning("achievement_task_universe_missing", universe_id=universe_id)
|
||||
return
|
||||
|
||||
text_content = story.story_text
|
||||
if not text_content and story.pages:
|
||||
# 如果是绘本,拼接每页文本
|
||||
text_content = "\n".join([str(p.get("text", "")) for p in story.pages])
|
||||
|
||||
if not text_content:
|
||||
logger.warning("achievement_task_empty_content", story_id=story_id)
|
||||
return
|
||||
|
||||
achievements = await extract_achievements(text_content)
|
||||
if not achievements:
|
||||
logger.info("achievement_task_no_new", story_id=story_id)
|
||||
return
|
||||
|
||||
existing = {
|
||||
(str(item.get("type", "")).strip(), str(item.get("description", "")).strip())
|
||||
for item in (universe.achievements or [])
|
||||
if isinstance(item, dict)
|
||||
}
|
||||
merged = list(universe.achievements or [])
|
||||
added_count = 0
|
||||
|
||||
for item in achievements:
|
||||
key = (item.get("type", "").strip(), item.get("description", "").strip())
|
||||
if key in existing:
|
||||
continue
|
||||
merged.append({
|
||||
"type": key[0],
|
||||
"description": key[1],
|
||||
"obtained_at": datetime.now().isoformat(),
|
||||
"source_story_id": story_id,
|
||||
})
|
||||
existing.add(key)
|
||||
added_count += 1
|
||||
|
||||
universe.achievements = merged
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"achievement_task_success",
|
||||
story_id=story_id,
|
||||
universe_id=universe_id,
|
||||
added=added_count,
|
||||
)
|
||||
29
backend/app/tasks/memory.py
Normal file
29
backend/app/tasks/memory.py
Normal file
@@ -0,0 +1,29 @@
|
||||
|
||||
import asyncio
|
||||
|
||||
from app.core.celery_app import celery_app
|
||||
from app.core.logging import get_logger
|
||||
from app.db.database import _get_session_factory
|
||||
from app.services.memory_service import prune_expired_memories
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@celery_app.task
|
||||
def prune_memories_task():
|
||||
"""Daily task to prune expired memories."""
|
||||
logger.info("prune_memories_task_started")
|
||||
|
||||
async def _run():
|
||||
# Ensure engine is initialized in this process
|
||||
session_factory = _get_session_factory()
|
||||
async with session_factory() as session:
|
||||
return await prune_expired_memories(session)
|
||||
|
||||
try:
|
||||
# Create a new event loop for this task execution
|
||||
count = asyncio.run(_run())
|
||||
logger.info("prune_memories_task_completed", deleted_count=count)
|
||||
return f"Deleted {count} expired memories"
|
||||
except Exception as exc:
|
||||
logger.error("prune_memories_task_failed", error=str(exc))
|
||||
raise
|
||||
108
backend/app/tasks/push_notifications.py
Normal file
108
backend/app/tasks/push_notifications.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Celery tasks for push notifications."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, time
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.celery_app import celery_app
|
||||
from app.core.logging import get_logger
|
||||
from app.db.database import _get_session_factory
|
||||
from app.db.models import PushConfig, PushEvent
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
LOCAL_TZ = ZoneInfo("Asia/Shanghai")
|
||||
QUIET_HOURS_START = time(21, 0)
|
||||
QUIET_HOURS_END = time(9, 0)
|
||||
TRIGGER_WINDOW_MINUTES = 30
|
||||
|
||||
|
||||
@celery_app.task
|
||||
def check_push_notifications() -> None:
|
||||
"""Check push configs and create push events."""
|
||||
asyncio.run(_check_push_notifications())
|
||||
|
||||
|
||||
def _is_quiet_hours(current: time) -> bool:
|
||||
if QUIET_HOURS_START < QUIET_HOURS_END:
|
||||
return QUIET_HOURS_START <= current < QUIET_HOURS_END
|
||||
return current >= QUIET_HOURS_START or current < QUIET_HOURS_END
|
||||
|
||||
|
||||
def _within_window(current: time, target: time) -> bool:
|
||||
current_minutes = current.hour * 60 + current.minute
|
||||
target_minutes = target.hour * 60 + target.minute
|
||||
return 0 <= current_minutes - target_minutes < TRIGGER_WINDOW_MINUTES
|
||||
|
||||
|
||||
async def _already_sent_today(
|
||||
session,
|
||||
child_profile_id: str,
|
||||
now: datetime,
|
||||
) -> bool:
|
||||
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end = now.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
result = await session.execute(
|
||||
select(PushEvent.id).where(
|
||||
PushEvent.child_profile_id == child_profile_id,
|
||||
PushEvent.status == "sent",
|
||||
PushEvent.sent_at >= start,
|
||||
PushEvent.sent_at <= end,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none() is not None
|
||||
|
||||
|
||||
async def _check_push_notifications() -> None:
|
||||
session_factory = _get_session_factory()
|
||||
now = datetime.now(LOCAL_TZ)
|
||||
current_day = now.weekday()
|
||||
current_time = now.time()
|
||||
|
||||
async with session_factory() as session:
|
||||
result = await session.execute(
|
||||
select(PushConfig).where(PushConfig.enabled.is_(True))
|
||||
)
|
||||
configs = result.scalars().all()
|
||||
|
||||
for config in configs:
|
||||
if not config.push_time:
|
||||
continue
|
||||
if config.push_days and current_day not in config.push_days:
|
||||
continue
|
||||
if not _within_window(current_time, config.push_time):
|
||||
continue
|
||||
if _is_quiet_hours(current_time):
|
||||
session.add(
|
||||
PushEvent(
|
||||
user_id=config.user_id,
|
||||
child_profile_id=config.child_profile_id,
|
||||
trigger_type="time",
|
||||
status="suppressed",
|
||||
reason="quiet_hours",
|
||||
sent_at=now,
|
||||
)
|
||||
)
|
||||
continue
|
||||
if await _already_sent_today(session, config.child_profile_id, now):
|
||||
continue
|
||||
|
||||
session.add(
|
||||
PushEvent(
|
||||
user_id=config.user_id,
|
||||
child_profile_id=config.child_profile_id,
|
||||
trigger_type="time",
|
||||
status="sent",
|
||||
reason=None,
|
||||
sent_at=now,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"push_event_sent",
|
||||
child_profile_id=config.child_profile_id,
|
||||
user_id=config.user_id,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
14
backend/docs/code_review_report.md
Normal file
14
backend/docs/code_review_report.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# Code Review Report (2nd follow-up)
|
||||
|
||||
## What¡¯s fixed
|
||||
- Provider cache now loads on startup via lifespan (`app/main.py`), so DB providers are honored without manual reload.
|
||||
- Providers support DB-stored `api_key` precedence (`provider_router.py:77-104`) and Provider model added `api_key` column (`db/admin_models.py:25`).
|
||||
- Frontend uses `/api/generate/full` and propagates image-failure warning to detail via query flag; StoryDetail displays banner when image generation failed.
|
||||
- Tests added for full generation, provider failover, config-from-DB, and startup cache load.
|
||||
|
||||
## Remaining issue
|
||||
1) **Missing DB migration for new Provider.api_key column**
|
||||
- Files updated model (`backend/app/db/admin_models.py:25`) but `backend/alembic/versions/0001_init_providers_and_story_mode.py` lacks this column. Existing databases will not have `api_key`, causing runtime errors when accessing or inserting. Add an Alembic migration to add/drop `api_key` to `providers` table and update sample data if needed.
|
||||
|
||||
## Suggested action
|
||||
- Create and apply an Alembic migration adding `api_key` (String, nullable) to `providers`. After migration, verify admin CRUD works with the new field.
|
||||
147
backend/docs/memory_system_dev.md
Normal file
147
backend/docs/memory_system_dev.md
Normal file
@@ -0,0 +1,147 @@
|
||||
# 记忆系统开发指南 (Development Guide)
|
||||
|
||||
本文档详细说明了 PRD 中定义的记忆系统的技术实现细节。
|
||||
|
||||
## 1. 数据库架构变更 (Schema Changes)
|
||||
|
||||
目前的 `MemoryItem` 表结构尚可,但需要增强字段以支持丰富的情感和元数据。
|
||||
|
||||
### 1.1 `MemoryItem` 表优化
|
||||
建议使用 Alembic 进行迁移,增加以下字段或在 `value` JSON 中规范化以下结构:
|
||||
|
||||
```python
|
||||
# 建议在 models.py 中明确这些字段,或者严格定义 value 字段的 Schema
|
||||
class MemoryItem(Base):
|
||||
# ... 现有字段 ...
|
||||
|
||||
# 新增/规范化字段建议
|
||||
# value 字段的 JSON 结构规范:
|
||||
# {
|
||||
# "content": "小兔子战胜了大灰狼", # 记忆的核心文本
|
||||
# "keywords": ["勇敢", "森林"], # 用于检索的关键词
|
||||
# "emotion": "positive", # 情感倾向: positive/negative/neutral
|
||||
# "source_story_id": 123, # 来源故事 ID
|
||||
# "confidence": 0.85 # 记忆置信度 (如果是 AI 自动提取)
|
||||
# }
|
||||
```
|
||||
|
||||
### 1.2 `StoryUniverse` 表优化 (成就系统)
|
||||
目前的成就存储在 `achievements` JSON 字段中。为了支持更复杂的查询(如"获得勇气勋章的所有用户"),建议将其重构为独立关联表,或保持 JSON 但规范化结构。
|
||||
|
||||
**当前 JSON 结构规范**:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"id": "badge_courage_01",
|
||||
"type": "勇气",
|
||||
"name": "小小勇士",
|
||||
"description": "第一次在故事中独自面对困难",
|
||||
"icon_url": "badges/courage.png",
|
||||
"obtained_at": "2023-10-27T10:00:00Z",
|
||||
"source_story_id": 45
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 核心逻辑实现
|
||||
|
||||
### 2.1 记忆注入逻辑 (Prompt Engineering)
|
||||
|
||||
修改 `backend/app/api/stories.py` 中的 `_build_memory_context` 函数。
|
||||
|
||||
**目标**: 生成一段自然的、不仅是罗列数据的 Prompt。
|
||||
|
||||
**伪代码逻辑**:
|
||||
```python
|
||||
def format_memory_for_prompt(memories: list[MemoryItem]) -> str:
|
||||
"""
|
||||
将记忆项转换为自然语言 Prompt 片段。
|
||||
"""
|
||||
context_parts = []
|
||||
|
||||
# 1. 角色记忆
|
||||
chars = [m for m in memories if m.type == 'favorite_character']
|
||||
if chars:
|
||||
names = ", ".join([c.value['name'] for c in chars])
|
||||
context_parts.append(f"孩子特别喜欢的角色有:{names}。请尝试让他们客串出场。")
|
||||
|
||||
# 2. 近期情节
|
||||
recent_stories = [m for m in memories if m.type == 'recent_story'][:2]
|
||||
if recent_stories:
|
||||
for story in recent_stories:
|
||||
context_parts.append(f"最近发生过:{story.value['summary']}。可以在对话中不经意地提及此事。")
|
||||
|
||||
# 3. 避雷区 (负面记忆)
|
||||
scary = [m for m in memories if m.type == 'scary_element']
|
||||
if scary:
|
||||
items = ", ".join([s.value['keyword'] for s in scary])
|
||||
context_parts.append(f"【绝对禁止】不要出现以下让孩子害怕的元素:{items}。")
|
||||
|
||||
return "\n".join(context_parts)
|
||||
```
|
||||
|
||||
### 2.2 成就提取与通知流程
|
||||
|
||||
当前流程在 `app/tasks/achievements.py`。需要完善闭环。
|
||||
|
||||
**改进后的流程**:
|
||||
1. **Story Generation**: 故事生成成功,存入数据库。
|
||||
2. **Async Task**: 触发 Celery 任务 `extract_story_achievements`。
|
||||
3. **LLM Analysis**: 调用 Gemini 分析故事,提取成就。
|
||||
4. **Update DB**: 更新 `StoryUniverse.achievements`。
|
||||
5. **Notification (新增)**:
|
||||
* 创建一个 `Notification` 或 `UserMessage` 记录(需要新建表或使用 Redis Pub/Sub)。
|
||||
* 前端轮询或通过 SSE (Server-Sent Events) 接收通知:"🎉 恭喜!在这个故事里,小明获得了[诚实勋章]!"
|
||||
|
||||
### 2.3 记忆清理与衰减 (Maintenance)
|
||||
|
||||
需要一个后台定时任务(Cron Job),清理无效记忆,避免 Prompt 过长。
|
||||
|
||||
* **频率**: 每天一次。
|
||||
* **逻辑**:
|
||||
* 删除 `ttl_days` 已过期的记录。
|
||||
* 对 `recent_story` 类型的 `base_weight` 进行每日衰减 update(或者只在读取时计算,数据库存静态值,推荐读取时动态计算以减少写操作)。
|
||||
* 当 `MemoryItem` 总数超过 100 条时,触发"记忆总结"任务,将多条旧记忆合并为一条"长期印象" (Long-term Impression)。
|
||||
|
||||
---
|
||||
|
||||
## 3. API 接口规划
|
||||
|
||||
### 3.1 获取成长时间轴
|
||||
`GET /api/profiles/{id}/timeline`
|
||||
|
||||
**Response**:
|
||||
```json
|
||||
{
|
||||
"events": [
|
||||
{
|
||||
"date": "2023-10-01",
|
||||
"type": "milestone",
|
||||
"title": "初次相遇",
|
||||
"description": "创建了角色 [小明]"
|
||||
},
|
||||
{
|
||||
"date": "2023-10-05",
|
||||
"type": "story",
|
||||
"title": "小明与魔法树",
|
||||
"image_url": "..."
|
||||
},
|
||||
{
|
||||
"date": "2023-10-05",
|
||||
"type": "achievement",
|
||||
"badge": {
|
||||
"name": "好奇宝宝",
|
||||
"icon": "..."
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 3.2 记忆反馈 (人工干预)
|
||||
`POST /api/memories/{id}/feedback`
|
||||
|
||||
允许家长手动删除或修正错误的记忆。
|
||||
* **Action**: `delete` | `reinforce` (强化,增加权重)
|
||||
93
backend/docs/memory_system_prd.md
Normal file
93
backend/docs/memory_system_prd.md
Normal file
@@ -0,0 +1,93 @@
|
||||
# 梦语织机 (DreamWeaver) 记忆系统升级 PRD
|
||||
> 版本: v1.0 | 状态: 规划中 | 优先级: High
|
||||
|
||||
## 1. 核心愿景 (Vision)
|
||||
|
||||
将当前的"数据存储"升级为有温度的**"情感连接系统"**。
|
||||
我们不只是在记住数据,而是在**维护孩子与故事世界的关系**。让每一个故事不再是孤立的碎片,而是构建孩子专属"故事宇宙"的砖瓦。
|
||||
|
||||
---
|
||||
|
||||
## 2. 产品痛点与解决方案
|
||||
|
||||
| 用户角色 | 核心痛点 | 解决方案 | 预期价值 |
|
||||
|---------|---------|---------|---------|
|
||||
| **孩子** | "上次的小兔子怎么不认识我了?" <br> 故事之间缺乏连续性,只有单次体验。 | **角色一致性与记忆注入** <br> 故事开头主动提及往事,角色性格延续。 | 建立情感依恋,提升沉浸感。 |
|
||||
| **家长** | "这App除了生成故事还能干嘛?" <br> 无法感知产品的长期教育价值。 | **显性化成长轨迹** <br> 词汇量统计、主题变化、成就徽章可视化。 | 提高付费意愿,提供社交货币。 |
|
||||
| **平台** | 用户用完即走,缺乏留存壁垒。 | **沉没成本与情感资产** <br> 积累的记忆越多,越舍不得离开。 | 提升长期留存率 (LTV)。 |
|
||||
|
||||
---
|
||||
|
||||
## 3. 功能架构:记忆分层模型
|
||||
|
||||
### 3.1 层级 1: 核心档案 (Identity Layer)
|
||||
*性质:永久、静态、显性*
|
||||
* **数据**: 姓名、年龄、性别。
|
||||
* **输入**: 家长在 Onboarding 阶段手动输入。
|
||||
* **作用**: 决定故事的基础适龄性和称呼。
|
||||
|
||||
### 3.2 层级 2: 故事宇宙 (Universe Layer)
|
||||
*性质:长期、动态积累、半显性*
|
||||
* **主角设定**: 姓名、性格特征(勇敢/害羞)、外貌特征(戴眼镜/卷发)。
|
||||
* **常驻配角**: 从随机故事中涌现出的固定伙伴(如"爱吃胡萝卜的松鼠奇奇")。
|
||||
* **世界观**: 故事发生的背景(魔法森林、未来城市、海底世界)。
|
||||
* **成就系统**: 孩子获得的虚拟奖励(勇气勋章、小小探险家)。
|
||||
|
||||
### 3.3 层级 3: 工作记忆 (Working Memory)
|
||||
*性质:短期、自动衰减、隐性*
|
||||
* **关键情节**: 最近 3 个故事的结局和核心冲突。
|
||||
* **情感标记**: 孩子对特定内容的反应(根据“重播”、“跳过”推断)。
|
||||
* **新学词汇**: 故事中出现的高级词汇。
|
||||
|
||||
---
|
||||
|
||||
## 4. 关键功能特性 (Feature Specs)
|
||||
|
||||
### 4.1 智能开场白 (Memory Injection)
|
||||
在生成新故事时,Prompt 必须包含一段"记忆唤醒"指令。
|
||||
* **示例**: "小明,还记得上周我们帮小松鼠找回了松果吗?今天,小松鼠带来了一位新朋友..."
|
||||
* **策略**: 提取权重最高的 Top 3 记忆注入 Prompt。
|
||||
|
||||
### 4.2 成长时间轴 (Growth Timeline)
|
||||
一个可视化的 H5 页面或 App 模块,以时间轴形式展示里程碑。
|
||||
* **节点类型**:
|
||||
* 🌟 **初次相遇**: 创建角色的第一天。
|
||||
* 📖 **阅读打卡**: 累计阅读 10/50/100 本。
|
||||
* 🏅 **获得成就**: 获得"诚实勋章"。
|
||||
* 🧠 **能力解锁**: 第一次阅读"科幻"题材。
|
||||
|
||||
### 4.3 成就仪式感 (Achievement Ceremony)
|
||||
* **触发**: 故事生成并分析后,如果获得新成就。
|
||||
* **表现**: 弹窗动画 + 音效 + "恭喜获得 [勇气] 徽章"。
|
||||
* **分享**:允许生成带二维码的成就海报。
|
||||
|
||||
---
|
||||
|
||||
## 5. 记忆类型扩展 (Memory Types)
|
||||
|
||||
| 类型 Key | 描述 | 来源 | 过期策略 |
|
||||
|---------|------|------|---------|
|
||||
| `recent_story` | 最近读过的故事梗概 | 阅读事件 | 30天衰减 |
|
||||
| `favorite_character` | 孩子喜欢的角色 | 重播/高评分 | 长期有效 |
|
||||
| `scary_element` | 孩子害怕/不喜欢的元素 | 跳过/负反馈 | 长期有效 (避雷) |
|
||||
| `vocabulary_growth` | 新掌握的词汇 | 故事分析 | 90天衰减 |
|
||||
| `emotional_highlight` | 高光时刻 (如: 特别开心的情节) | 互动数据 | 60天衰减 |
|
||||
|
||||
---
|
||||
|
||||
## 6. 实施路线图 (Roadmap)
|
||||
|
||||
### Phase 1: 基础建设 (v0.3.0)
|
||||
* [x] 数据库 `MemoryItem` 表 (已存在)。
|
||||
* [ ] 扩展 `MemoryItem` 类型字段,支持更多维度。
|
||||
* [ ] 优化 `_build_memory_context`,支持更自然的 Prompt 注入。
|
||||
* [ ] 前端:简单的"近期回忆"展示列表。
|
||||
|
||||
### Phase 2: 可视化与成就 (v0.4.0)
|
||||
* [ ] 实现"成就提取器" (Achievement Extractor) 的闭环通知。
|
||||
* [ ] 前端:开发"我的成就"和"成长时间轴"页面。
|
||||
* [ ] 增加故事开场白的动态生成逻辑。
|
||||
|
||||
### Phase 3: 深度智能 (v0.5.0+)
|
||||
* [ ] 引入向量数据库,实现基于语义的记忆检索 (不仅是时间最近)。
|
||||
* [ ] 情感分析模型:分析用户行为推断情感倾向。
|
||||
246
backend/docs/provider_system.md
Normal file
246
backend/docs/provider_system.md
Normal file
@@ -0,0 +1,246 @@
|
||||
# Provider 系统开发文档
|
||||
|
||||
## 当前版本功能 (v0.2.0)
|
||||
|
||||
### 已完成功能
|
||||
|
||||
1. **CQTAI nano 图像适配器** (`app/services/adapters/image/cqtai.py`)
|
||||
- 异步生成 + 轮询获取结果
|
||||
- 支持 nano-banana / nano-banana-pro 模型
|
||||
- 支持多种分辨率和画面比例
|
||||
- 支持图生图 (filesUrl)
|
||||
|
||||
2. **密钥加密存储** (`app/services/secret_service.py`)
|
||||
- Fernet 对称加密,密钥从 SECRET_KEY 派生
|
||||
- Provider API Key 自动加密存储
|
||||
- 密钥管理 API (CRUD)
|
||||
|
||||
3. **指标收集系统** (`app/services/provider_metrics.py`)
|
||||
- 调用成功率、延迟、成本统计
|
||||
- 时间窗口聚合查询
|
||||
- 已集成到 provider_router
|
||||
|
||||
4. **熔断器功能** (`app/services/provider_metrics.py`)
|
||||
- 连续失败 3 次触发熔断
|
||||
- 60 秒后自动恢复尝试
|
||||
- 健康状态持久化到数据库
|
||||
|
||||
5. **管理后台前端** (`app/admin_app.py`)
|
||||
- 独立端口部署 (8001)
|
||||
- Vue 3 + Tailwind CSS 单页应用
|
||||
- Provider CRUD 管理
|
||||
- 密钥管理界面
|
||||
- Basic Auth 认证
|
||||
|
||||
### 配置说明
|
||||
|
||||
```bash
|
||||
# 启动主应用
|
||||
uvicorn app.main:app --port 8000
|
||||
|
||||
# 启动管理后台 (独立端口)
|
||||
uvicorn app.admin_app:app --port 8001
|
||||
```
|
||||
|
||||
环境变量:
|
||||
```
|
||||
CQTAI_API_KEY=your-cqtai-api-key
|
||||
ENABLE_ADMIN_CONSOLE=true
|
||||
ADMIN_USERNAME=admin
|
||||
ADMIN_PASSWORD=your-secure-password
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 下一版本优化计划 (v0.3.0)
|
||||
|
||||
### 高优先级
|
||||
|
||||
#### 1. 智能负载分流 (方案 B)
|
||||
**目标**: 主渠道压力大时自动分流到后备渠道
|
||||
|
||||
**实现方案**:
|
||||
- 监控指标: 并发数、响应延迟、错误率
|
||||
- 分流阈值配置:
|
||||
```python
|
||||
class LoadBalanceConfig:
|
||||
max_concurrent: int = 10 # 并发超过此值时分流
|
||||
max_latency_ms: int = 5000 # 延迟超过此值时分流
|
||||
max_error_rate: float = 0.1 # 错误率超过 10% 时分流
|
||||
```
|
||||
- 分流策略: 加权轮询,根据健康度动态调整权重
|
||||
|
||||
**涉及文件**:
|
||||
- `app/services/provider_router.py` - 添加负载均衡逻辑
|
||||
- `app/services/provider_metrics.py` - 添加并发计数器
|
||||
- `app/db/admin_models.py` - 添加 LoadBalanceConfig 模型
|
||||
|
||||
#### 2. Storybook 适配器
|
||||
**目标**: 生成可翻页的分页故事书
|
||||
|
||||
**实现方案**:
|
||||
- 参考 Gemini AI Story Generator 格式
|
||||
- 输出结构:
|
||||
```python
|
||||
class StorybookPage:
|
||||
page_number: int
|
||||
text: str
|
||||
image_prompt: str
|
||||
image_url: str | None
|
||||
|
||||
class Storybook:
|
||||
title: str
|
||||
pages: list[StorybookPage]
|
||||
cover_url: str | None
|
||||
```
|
||||
- 集成文本 + 图像生成流水线
|
||||
|
||||
**涉及文件**:
|
||||
- `app/services/adapters/storybook/` - 新建目录
|
||||
- `app/api/stories.py` - 添加 storybook 生成端点
|
||||
|
||||
### 中优先级
|
||||
|
||||
#### 3. 成本追踪系统
|
||||
**目标**: 记录实际消费,支持预算控制
|
||||
|
||||
**实现方案**:
|
||||
- 成本记录表:
|
||||
```python
|
||||
class CostRecord:
|
||||
user_id: str
|
||||
provider_id: str
|
||||
capability: str # text/image/tts
|
||||
estimated_cost: Decimal
|
||||
actual_cost: Decimal | None
|
||||
timestamp: datetime
|
||||
```
|
||||
- 预算配置:
|
||||
```python
|
||||
class BudgetConfig:
|
||||
user_id: str
|
||||
daily_limit: Decimal
|
||||
monthly_limit: Decimal
|
||||
alert_threshold: float = 0.8 # 80% 时告警
|
||||
```
|
||||
- 超预算处理: 拒绝请求 / 降级到低成本 provider
|
||||
|
||||
**涉及文件**:
|
||||
- `app/db/admin_models.py` - 添加 CostRecord, BudgetConfig
|
||||
- `app/services/cost_tracker.py` - 新建
|
||||
- `app/api/admin_providers.py` - 添加成本查询 API
|
||||
|
||||
#### 4. 指标可视化
|
||||
**目标**: 管理后台展示供应商指标图表
|
||||
|
||||
**实现方案**:
|
||||
- 添加指标查询 API:
|
||||
- GET /admin/metrics/summary - 汇总统计
|
||||
- GET /admin/metrics/timeline - 时间线数据
|
||||
- GET /admin/metrics/providers/{id} - 单个供应商详情
|
||||
- 前端使用 Chart.js 或 ECharts 展示
|
||||
|
||||
### 低优先级
|
||||
|
||||
#### 5. 多租户 Provider 配置
|
||||
**目标**: 每个租户可配置独立 provider 列表和 API Key
|
||||
|
||||
**实现方案**:
|
||||
- 租户配置表:
|
||||
```python
|
||||
class TenantProviderConfig:
|
||||
tenant_id: str
|
||||
provider_type: str
|
||||
provider_ids: list[str] # 按优先级排序
|
||||
api_key_override: str | None # 加密存储
|
||||
```
|
||||
- 路由时优先使用租户配置,回退到全局配置
|
||||
|
||||
#### 6. Provider 健康检查调度器
|
||||
**目标**: 定期主动检查 provider 健康状态
|
||||
|
||||
**实现方案**:
|
||||
- Celery Beat 定时任务
|
||||
- 每 5 分钟检查一次所有启用的 provider
|
||||
- 更新 ProviderHealth 表
|
||||
|
||||
#### 7. 适配器热加载
|
||||
**目标**: 支持运行时动态加载新适配器
|
||||
|
||||
**实现方案**:
|
||||
- 适配器插件目录: `app/services/adapters/plugins/`
|
||||
- 启动时扫描并注册
|
||||
- 提供 API 触发重新扫描
|
||||
|
||||
---
|
||||
|
||||
## API 变更记录
|
||||
|
||||
### v0.2.0 新增
|
||||
|
||||
| Method | Route | Description |
|
||||
|--------|-------|-------------|
|
||||
| GET | `/admin/secrets` | 列出所有密钥名称 |
|
||||
| POST | `/admin/secrets` | 创建/更新密钥 |
|
||||
| DELETE | `/admin/secrets/{name}` | 删除密钥 |
|
||||
| GET | `/admin/secrets/{name}/verify` | 验证密钥有效性 |
|
||||
|
||||
### 计划中 (v0.3.0)
|
||||
|
||||
| Method | Route | Description |
|
||||
|--------|-------|-------------|
|
||||
| GET | `/admin/metrics/summary` | 指标汇总 |
|
||||
| GET | `/admin/metrics/timeline` | 时间线数据 |
|
||||
| POST | `/api/storybook/generate` | 生成分页故事书 |
|
||||
| GET | `/admin/costs` | 成本统计 |
|
||||
| POST | `/admin/budgets` | 设置预算 |
|
||||
|
||||
---
|
||||
|
||||
## 适配器开发指南
|
||||
|
||||
### 添加新适配器
|
||||
|
||||
1. 创建适配器文件:
|
||||
```python
|
||||
# app/services/adapters/image/new_provider.py
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
@AdapterRegistry.register("image", "new_provider")
|
||||
class NewProviderAdapter(BaseAdapter[str]):
|
||||
adapter_type = "image"
|
||||
adapter_name = "new_provider"
|
||||
|
||||
async def execute(self, prompt: str, **kwargs) -> str:
|
||||
# 实现生成逻辑
|
||||
pass
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
# 实现健康检查
|
||||
pass
|
||||
|
||||
@property
|
||||
def estimated_cost(self) -> float:
|
||||
return 0.01 # USD
|
||||
```
|
||||
|
||||
2. 在 `__init__.py` 中导入:
|
||||
```python
|
||||
# app/services/adapters/__init__.py
|
||||
from app.services.adapters.image import new_provider as _new_provider # noqa: F401
|
||||
```
|
||||
|
||||
3. 添加配置:
|
||||
```python
|
||||
# app/core/config.py
|
||||
new_provider_api_key: str = ""
|
||||
|
||||
# app/services/provider_router.py
|
||||
API_KEY_MAP["new_provider"] = "new_provider_api_key"
|
||||
```
|
||||
|
||||
4. 更新 `.env.example`:
|
||||
```
|
||||
NEW_PROVIDER_API_KEY=
|
||||
```
|
||||
46
backend/pyproject.toml
Normal file
46
backend/pyproject.toml
Normal file
@@ -0,0 +1,46 @@
|
||||
[project]
|
||||
name = "dreamweaver"
|
||||
version = "0.1.0"
|
||||
description = "AI 驱动的儿童故事生成应用"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"fastapi>=0.115.0",
|
||||
"uvicorn[standard]>=0.32.0",
|
||||
"sqlalchemy[asyncio]>=2.0.0",
|
||||
"asyncpg>=0.30.0",
|
||||
"pydantic>=2.0.0",
|
||||
"pydantic-settings>=2.0.0",
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"cryptography>=43.0.0",
|
||||
"httpx>=0.28.0",
|
||||
"alembic>=1.13.0",
|
||||
"cachetools>=5.0.0",
|
||||
"tenacity>=8.0.0",
|
||||
"structlog>=24.0.0",
|
||||
"sse-starlette>=2.0.0",
|
||||
"celery>=5.4.0",
|
||||
"redis>=5.0.0",
|
||||
"openai>=1.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"pytest-cov>=4.0.0",
|
||||
"aiosqlite>=0.20.0",
|
||||
"ruff>=0.8.0",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["app*"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py311"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "N", "W"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
27
backend/scripts/add_config_column.py
Normal file
27
backend/scripts/add_config_column.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from app.db.database import async_engine
|
||||
from sqlalchemy import text
|
||||
|
||||
async def upgrade_db():
|
||||
print("🚀 Checking database schema...")
|
||||
async with async_engine.begin() as conn:
|
||||
# Check if column exists
|
||||
result = await conn.execute(text(
|
||||
"SELECT column_name FROM information_schema.columns WHERE table_name='providers' AND column_name='config_json';"
|
||||
))
|
||||
if result.scalar():
|
||||
print("✅ Column 'config_json' already exists.")
|
||||
else:
|
||||
print("⚠️ Column 'config_json' missing. Adding it now...")
|
||||
await conn.execute(text("ALTER TABLE providers ADD COLUMN config_json JSON;"))
|
||||
print("✅ Column 'config_json' added successfully.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
if sys.platform == 'win32':
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
asyncio.run(upgrade_db())
|
||||
29
backend/scripts/fix_db_schema.py
Normal file
29
backend/scripts/fix_db_schema.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add backend to path
|
||||
sys.path.append(os.path.join(os.getcwd()))
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
||||
from app.core.config import settings
|
||||
|
||||
async def add_column():
|
||||
engine = create_async_engine(settings.database_url)
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||
async with async_session() as session:
|
||||
try:
|
||||
print("Adding config_json column to providers table...")
|
||||
await session.execute(text("ALTER TABLE providers ADD COLUMN IF NOT EXISTS config_json JSONB DEFAULT '{}'::jsonb"))
|
||||
await session.commit()
|
||||
print("Successfully added config_json column.")
|
||||
except Exception as e:
|
||||
print(f"Error adding column: {e}")
|
||||
await session.rollback()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
asyncio.run(add_column())
|
||||
21
backend/scripts/manual_init_db.py
Normal file
21
backend/scripts/manual_init_db.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add backend to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from app.db.database import init_db
|
||||
from app.core.logging import setup_logging
|
||||
|
||||
async def main():
|
||||
setup_logging()
|
||||
print("Initializing database...")
|
||||
try:
|
||||
await init_db()
|
||||
print("Database initialized successfully.")
|
||||
except Exception as e:
|
||||
print(f"Error initializing database: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests package
|
||||
146
backend/tests/conftest.py
Normal file
146
backend/tests/conftest.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""测试配置和 fixtures。"""
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing")
|
||||
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
|
||||
|
||||
from app.core.security import create_access_token
|
||||
from app.api.stories import _request_log
|
||||
from app.db.database import get_db
|
||||
from app.db.models import Base, Story, User
|
||||
from app.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""创建内存数据库引擎。"""
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""创建数据库会话。"""
|
||||
session_factory = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_user(db_session: AsyncSession) -> User:
|
||||
"""创建测试用户。"""
|
||||
user = User(
|
||||
id="github:12345",
|
||||
name="Test User",
|
||||
avatar_url="https://example.com/avatar.png",
|
||||
provider="github",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_story(db_session: AsyncSession, test_user: User) -> Story:
|
||||
"""创建测试故事。"""
|
||||
story = Story(
|
||||
user_id=test_user.id,
|
||||
title="测试故事",
|
||||
story_text="从前有一只小兔子...",
|
||||
cover_prompt="A cute rabbit in a forest",
|
||||
mode="generated",
|
||||
)
|
||||
db_session.add(story)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(story)
|
||||
return story
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_token(test_user: User) -> str:
|
||||
"""生成测试用户的 JWT token。"""
|
||||
return create_access_token({"sub": test_user.id})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(db_session: AsyncSession) -> TestClient:
|
||||
"""创建测试客户端。"""
|
||||
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_client(client: TestClient, auth_token: str) -> TestClient:
|
||||
"""带认证的测试客户端。"""
|
||||
client.cookies.set("access_token", auth_token)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_rate_limit_cache():
|
||||
"""确保每个测试用例的限流缓存互不影响。"""
|
||||
_request_log.clear()
|
||||
yield
|
||||
_request_log.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_text_provider():
|
||||
"""Mock 文本生成适配器 API 调用。"""
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
|
||||
mock_result = StoryOutput(
|
||||
mode="generated",
|
||||
title="小兔子的冒险",
|
||||
story_text="从前有一只小兔子...",
|
||||
cover_prompt_suggestion="A cute rabbit",
|
||||
)
|
||||
|
||||
with patch("app.api.stories.generate_story_content", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = mock_result
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_image_provider():
|
||||
"""Mock 图像生成。"""
|
||||
with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = "https://example.com/image.png"
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tts_provider():
|
||||
"""Mock TTS。"""
|
||||
with patch("app.api.stories.text_to_speech", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = b"fake-audio-bytes"
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_all_providers(mock_text_provider, mock_image_provider, mock_tts_provider):
|
||||
"""Mock 所有 AI 供应商。"""
|
||||
return {
|
||||
"text_primary": mock_text_provider,
|
||||
"image_primary": mock_image_provider,
|
||||
"tts_primary": mock_tts_provider,
|
||||
}
|
||||
65
backend/tests/test_auth.py
Normal file
65
backend/tests/test_auth.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""认证相关测试。"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.core.security import create_access_token, decode_access_token
|
||||
|
||||
|
||||
class TestJWT:
|
||||
"""JWT token 测试。"""
|
||||
|
||||
def test_create_and_decode_token(self):
|
||||
"""测试 token 创建和解码。"""
|
||||
payload = {"sub": "github:12345"}
|
||||
token = create_access_token(payload)
|
||||
decoded = decode_access_token(token)
|
||||
assert decoded is not None
|
||||
assert decoded["sub"] == "github:12345"
|
||||
|
||||
def test_decode_invalid_token(self):
|
||||
"""测试无效 token 解码。"""
|
||||
result = decode_access_token("invalid-token")
|
||||
assert result is None
|
||||
|
||||
def test_decode_empty_token(self):
|
||||
"""测试空 token 解码。"""
|
||||
result = decode_access_token("")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSession:
|
||||
"""Session 端点测试。"""
|
||||
|
||||
def test_session_without_auth(self, client: TestClient):
|
||||
"""未登录时获取 session。"""
|
||||
response = client.get("/auth/session")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user"] is None
|
||||
|
||||
def test_session_with_auth(self, auth_client: TestClient, test_user):
|
||||
"""已登录时获取 session。"""
|
||||
response = auth_client.get("/auth/session")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user"] is not None
|
||||
assert data["user"]["id"] == test_user.id
|
||||
assert data["user"]["name"] == test_user.name
|
||||
|
||||
def test_session_with_invalid_token(self, client: TestClient):
|
||||
"""无效 token 获取 session。"""
|
||||
client.cookies.set("access_token", "invalid-token")
|
||||
response = client.get("/auth/session")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user"] is None
|
||||
|
||||
|
||||
class TestSignout:
|
||||
"""登出测试。"""
|
||||
|
||||
def test_signout(self, auth_client: TestClient):
|
||||
"""测试登出。"""
|
||||
response = auth_client.post("/auth/signout", follow_redirects=False)
|
||||
assert response.status_code == 302
|
||||
78
backend/tests/test_profiles.py
Normal file
78
backend/tests/test_profiles.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Child profile API tests."""
|
||||
|
||||
from datetime import date
|
||||
|
||||
|
||||
def _calc_age(birth_date: date) -> int:
|
||||
today = date.today()
|
||||
return today.year - birth_date.year - (
|
||||
(today.month, today.day) < (birth_date.month, birth_date.day)
|
||||
)
|
||||
|
||||
|
||||
def test_list_profiles_empty(auth_client):
|
||||
response = auth_client.get("/api/profiles")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["profiles"] == []
|
||||
assert data["total"] == 0
|
||||
|
||||
|
||||
def test_create_update_delete_profile(auth_client):
|
||||
payload = {
|
||||
"name": "小明",
|
||||
"birth_date": "2020-05-12",
|
||||
"gender": "male",
|
||||
"interests": ["太空", "机器人"],
|
||||
"growth_themes": ["勇气"],
|
||||
}
|
||||
response = auth_client.post("/api/profiles", json=payload)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == payload["name"]
|
||||
assert data["gender"] == payload["gender"]
|
||||
assert data["interests"] == payload["interests"]
|
||||
assert data["growth_themes"] == payload["growth_themes"]
|
||||
assert data["age"] == _calc_age(date.fromisoformat(payload["birth_date"]))
|
||||
|
||||
profile_id = data["id"]
|
||||
|
||||
update_payload = {"growth_themes": ["分享", "独立"]}
|
||||
response = auth_client.put(f"/api/profiles/{profile_id}", json=update_payload)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["growth_themes"] == update_payload["growth_themes"]
|
||||
|
||||
response = auth_client.delete(f"/api/profiles/{profile_id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Deleted"
|
||||
|
||||
|
||||
def test_profile_limit_and_duplicate(auth_client):
|
||||
# 先测试重复名称(在达到限制前)
|
||||
response = auth_client.post(
|
||||
"/api/profiles",
|
||||
json={"name": "孩子1", "gender": "female"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
response = auth_client.post(
|
||||
"/api/profiles",
|
||||
json={"name": "孩子1", "gender": "female"},
|
||||
)
|
||||
assert response.status_code == 409 # 重复名称
|
||||
|
||||
# 继续创建到上限
|
||||
for i in range(2, 6):
|
||||
response = auth_client.post(
|
||||
"/api/profiles",
|
||||
json={"name": f"孩子{i}", "gender": "female"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# 测试数量限制
|
||||
response = auth_client.post(
|
||||
"/api/profiles",
|
||||
json={"name": "孩子6", "gender": "female"},
|
||||
)
|
||||
assert response.status_code == 400 # 超过5个限制
|
||||
195
backend/tests/test_provider_router.py
Normal file
195
backend/tests/test_provider_router.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Provider router 测试 - failover 和配置加载。"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.adapters import AdapterConfig
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
|
||||
|
||||
class TestProviderFailover:
|
||||
"""Provider failover 测试。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failover_to_second_provider(self):
|
||||
"""第一个 provider 失败时切换到第二个。"""
|
||||
from app.services import provider_router
|
||||
|
||||
# Mock 两个 provider - 使用 spec=False 并显式设置所有属性
|
||||
mock_provider_1 = MagicMock()
|
||||
mock_provider_1.configure_mock(
|
||||
id="provider-1",
|
||||
type="text",
|
||||
adapter="text_primary",
|
||||
api_key="key1",
|
||||
api_base=None,
|
||||
model=None,
|
||||
timeout_ms=60000,
|
||||
max_retries=3,
|
||||
config_ref=None,
|
||||
config_json={},
|
||||
priority=10,
|
||||
weight=1.0,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
mock_provider_2 = MagicMock()
|
||||
mock_provider_2.configure_mock(
|
||||
id="provider-2",
|
||||
type="text",
|
||||
adapter="text_primary",
|
||||
api_key="key2",
|
||||
api_base=None,
|
||||
model=None,
|
||||
timeout_ms=60000,
|
||||
max_retries=3,
|
||||
config_ref=None,
|
||||
config_json={},
|
||||
priority=5,
|
||||
weight=1.0,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
mock_providers = [mock_provider_1, mock_provider_2]
|
||||
|
||||
mock_result = StoryOutput(
|
||||
mode="generated",
|
||||
title="测试故事",
|
||||
story_text="内容",
|
||||
cover_prompt_suggestion="prompt",
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise Exception("First provider failed")
|
||||
return mock_result
|
||||
|
||||
with patch.object(provider_router, "get_providers", return_value=mock_providers):
|
||||
with patch("app.services.adapters.AdapterRegistry.get") as mock_get:
|
||||
mock_adapter_class = MagicMock()
|
||||
mock_adapter_instance = MagicMock()
|
||||
mock_adapter_instance.execute = mock_execute
|
||||
mock_adapter_class.return_value = mock_adapter_instance
|
||||
mock_get.return_value = mock_adapter_class
|
||||
|
||||
result = await provider_router.generate_story_content(
|
||||
input_type="keywords",
|
||||
data="测试",
|
||||
)
|
||||
|
||||
assert result == mock_result
|
||||
assert call_count == 2 # 第一个失败,第二个成功
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_providers_fail(self):
|
||||
"""所有 provider 都失败时抛出异常。"""
|
||||
from app.services import provider_router
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.configure_mock(
|
||||
id="provider-1",
|
||||
type="text",
|
||||
adapter="text_primary",
|
||||
api_key="key1",
|
||||
api_base=None,
|
||||
model=None,
|
||||
timeout_ms=60000,
|
||||
max_retries=3,
|
||||
config_ref=None,
|
||||
config_json={},
|
||||
priority=10,
|
||||
weight=1.0,
|
||||
enabled=True,
|
||||
)
|
||||
mock_providers = [mock_provider]
|
||||
|
||||
async def mock_execute(**kwargs):
|
||||
raise Exception("Provider failed")
|
||||
|
||||
with patch.object(provider_router, "get_providers", return_value=mock_providers):
|
||||
with patch("app.services.adapters.AdapterRegistry.get") as mock_get:
|
||||
mock_adapter_class = MagicMock()
|
||||
mock_adapter_instance = MagicMock()
|
||||
mock_adapter_instance.execute = mock_execute
|
||||
mock_adapter_class.return_value = mock_adapter_instance
|
||||
mock_get.return_value = mock_adapter_class
|
||||
|
||||
with pytest.raises(ValueError, match="No text provider succeeded"):
|
||||
await provider_router.generate_story_content(
|
||||
input_type="keywords",
|
||||
data="测试",
|
||||
)
|
||||
|
||||
|
||||
class TestProviderConfigFromDB:
|
||||
"""从 DB 加载 provider 配置测试。"""
|
||||
|
||||
def test_build_config_from_provider_with_api_key(self):
|
||||
"""Provider 有 api_key 时优先使用。"""
|
||||
from app.services.provider_router import _build_config_from_provider
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.adapter = "text_primary"
|
||||
mock_provider.api_key = "db-api-key"
|
||||
mock_provider.api_base = "https://custom.api.com"
|
||||
mock_provider.model = "custom-model"
|
||||
mock_provider.timeout_ms = 30000
|
||||
mock_provider.max_retries = 5
|
||||
mock_provider.config_ref = None
|
||||
mock_provider.config_json = {}
|
||||
|
||||
config = _build_config_from_provider(mock_provider)
|
||||
|
||||
assert config.api_key == "db-api-key"
|
||||
assert config.api_base == "https://custom.api.com"
|
||||
assert config.model == "custom-model"
|
||||
assert config.timeout_ms == 30000
|
||||
assert config.max_retries == 5
|
||||
|
||||
def test_build_config_fallback_to_settings(self):
|
||||
"""Provider 无 api_key 时回退到 settings。"""
|
||||
from app.services.provider_router import _build_config_from_provider
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.adapter = "text_primary"
|
||||
mock_provider.api_key = None
|
||||
mock_provider.api_base = None
|
||||
mock_provider.model = None
|
||||
mock_provider.timeout_ms = None
|
||||
mock_provider.max_retries = None
|
||||
mock_provider.config_ref = "text_api_key"
|
||||
mock_provider.config_json = {}
|
||||
|
||||
with patch("app.services.provider_router.settings") as mock_settings:
|
||||
mock_settings.text_api_key = "settings-api-key"
|
||||
mock_settings.text_model = "gemini-2.0-flash"
|
||||
|
||||
config = _build_config_from_provider(mock_provider)
|
||||
|
||||
assert config.api_key == "settings-api-key"
|
||||
|
||||
|
||||
class TestProviderCacheStartup:
|
||||
"""Provider cache 启动加载测试。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_loaded_on_startup(self):
|
||||
"""启动时加载 provider cache。"""
|
||||
from app.main import _load_provider_cache
|
||||
|
||||
with patch("app.db.database._get_session_factory") as mock_factory:
|
||||
mock_session = AsyncMock()
|
||||
mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_factory.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
with patch("app.services.provider_cache.reload_providers", new_callable=AsyncMock) as mock_reload:
|
||||
mock_reload.return_value = {"text": [], "image": [], "tts": []}
|
||||
|
||||
await _load_provider_cache()
|
||||
|
||||
mock_reload.assert_called_once()
|
||||
77
backend/tests/test_push_configs.py
Normal file
77
backend/tests/test_push_configs.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Push config API tests."""
|
||||
|
||||
|
||||
def _create_profile(auth_client) -> str:
|
||||
response = auth_client.post(
|
||||
"/api/profiles",
|
||||
json={"name": "小明", "gender": "male"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
def test_create_list_update_push_config(auth_client):
|
||||
profile_id = _create_profile(auth_client)
|
||||
|
||||
response = auth_client.put(
|
||||
"/api/push-configs",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"push_time": "20:30",
|
||||
"push_days": [1, 3, 5],
|
||||
"enabled": True,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["child_profile_id"] == profile_id
|
||||
assert data["push_time"].startswith("20:30")
|
||||
assert data["push_days"] == [1, 3, 5]
|
||||
assert data["enabled"] is True
|
||||
|
||||
response = auth_client.get("/api/push-configs")
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["total"] == 1
|
||||
|
||||
response = auth_client.put(
|
||||
"/api/push-configs",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"enabled": False,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["enabled"] is False
|
||||
assert data["push_time"].startswith("20:30")
|
||||
assert data["push_days"] == [1, 3, 5]
|
||||
|
||||
|
||||
def test_push_config_validation(auth_client):
|
||||
profile_id = _create_profile(auth_client)
|
||||
|
||||
response = auth_client.put(
|
||||
"/api/push-configs",
|
||||
json={"child_profile_id": profile_id},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
response = auth_client.put(
|
||||
"/api/push-configs",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"push_time": "19:00",
|
||||
"push_days": [7],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
response = auth_client.put(
|
||||
"/api/push-configs",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"push_time": None,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
143
backend/tests/test_reading_events.py
Normal file
143
backend/tests/test_reading_events.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Reading event API tests."""
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.db.models import MemoryItem
|
||||
from app.main import app
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def _create_profile(client: AsyncClient) -> str:
|
||||
response = await client.post(
|
||||
"/api/profiles",
|
||||
json={"name": "小明", "gender": "male"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
async def test_create_reading_event_updates_stats_and_memory(
|
||||
db_session,
|
||||
test_user,
|
||||
auth_token,
|
||||
test_story,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
profile_id = await _create_profile(client)
|
||||
|
||||
response = await client.post(
|
||||
"/api/reading-events",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"story_id": test_story.id,
|
||||
"event_type": "completed",
|
||||
"reading_time": 120,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["child_profile_id"] == profile_id
|
||||
assert data["story_id"] == test_story.id
|
||||
assert data["event_type"] == "completed"
|
||||
|
||||
response = await client.get(f"/api/profiles/{profile_id}")
|
||||
assert response.status_code == 200
|
||||
profile = response.json()
|
||||
assert profile["stories_count"] == 1
|
||||
assert profile["total_reading_time"] == 120
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
|
||||
)
|
||||
items = result.scalars().all()
|
||||
assert len(items) == 1
|
||||
assert items[0].type == "recent_story"
|
||||
assert items[0].value["story_id"] == test_story.id
|
||||
|
||||
response = await client.post(
|
||||
"/api/reading-events",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"story_id": test_story.id,
|
||||
"event_type": "skipped",
|
||||
"reading_time": 0,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
|
||||
)
|
||||
assert len(result.scalars().all()) == 1
|
||||
|
||||
response = await client.post(
|
||||
"/api/reading-events",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"story_id": test_story.id,
|
||||
"event_type": "completed",
|
||||
"reading_time": 0,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
response = await client.get(f"/api/profiles/{profile_id}")
|
||||
assert response.status_code == 200
|
||||
profile = response.json()
|
||||
assert profile["stories_count"] == 1
|
||||
assert profile["total_reading_time"] == 120
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_reading_event_validation_errors(
|
||||
db_session,
|
||||
test_user,
|
||||
auth_token,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post(
|
||||
"/api/reading-events",
|
||||
json={
|
||||
"child_profile_id": "not-exist",
|
||||
"event_type": "started",
|
||||
"reading_time": 0,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
profile_id = await _create_profile(client)
|
||||
|
||||
response = await client.post(
|
||||
"/api/reading-events",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"story_id": 999999,
|
||||
"event_type": "completed",
|
||||
"reading_time": 0,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
257
backend/tests/test_stories.py
Normal file
257
backend/tests/test_stories.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""故事 API 测试。"""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.stories import _request_log, RATE_LIMIT_REQUESTS
|
||||
|
||||
|
||||
class TestStoryGenerate:
|
||||
"""故事生成测试。"""
|
||||
|
||||
def test_generate_without_auth(self, client: TestClient):
|
||||
"""未登录时生成故事。"""
|
||||
response = client.post(
|
||||
"/api/generate",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_generate_with_empty_data(self, auth_client: TestClient):
|
||||
"""空数据生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate",
|
||||
json={"type": "keywords", "data": ""},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_generate_with_invalid_type(self, auth_client: TestClient):
|
||||
"""无效类型生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate",
|
||||
json={"type": "invalid", "data": "test"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_generate_story_success(self, auth_client: TestClient, mock_text_provider):
|
||||
"""成功生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate",
|
||||
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert "title" in data
|
||||
assert "story_text" in data
|
||||
assert data["mode"] == "generated"
|
||||
|
||||
|
||||
class TestStoryList:
|
||||
"""故事列表测试。"""
|
||||
|
||||
def test_list_without_auth(self, client: TestClient):
|
||||
"""未登录时获取列表。"""
|
||||
response = client.get("/api/stories")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_list_empty(self, auth_client: TestClient):
|
||||
"""空列表。"""
|
||||
response = auth_client.get("/api/stories")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
def test_list_with_stories(self, auth_client: TestClient, test_story):
|
||||
"""有故事时获取列表。"""
|
||||
response = auth_client.get("/api/stories")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == test_story.id
|
||||
assert data[0]["title"] == test_story.title
|
||||
|
||||
def test_list_pagination(self, auth_client: TestClient, test_story):
|
||||
"""分页测试。"""
|
||||
response = auth_client.get("/api/stories?limit=1&offset=0")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
|
||||
response = auth_client.get("/api/stories?limit=1&offset=1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 0
|
||||
|
||||
|
||||
class TestStoryDetail:
|
||||
"""故事详情测试。"""
|
||||
|
||||
def test_get_story_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时获取详情。"""
|
||||
response = client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_get_story_not_found(self, auth_client: TestClient):
|
||||
"""故事不存在。"""
|
||||
response = auth_client.get("/api/stories/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_story_success(self, auth_client: TestClient, test_story):
|
||||
"""成功获取详情。"""
|
||||
response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == test_story.id
|
||||
assert data["title"] == test_story.title
|
||||
assert data["story_text"] == test_story.story_text
|
||||
|
||||
|
||||
class TestStoryDelete:
|
||||
"""故事删除测试。"""
|
||||
|
||||
def test_delete_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时删除。"""
|
||||
response = client.delete(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_delete_not_found(self, auth_client: TestClient):
|
||||
"""删除不存在的故事。"""
|
||||
response = auth_client.delete("/api/stories/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_delete_success(self, auth_client: TestClient, test_story):
|
||||
"""成功删除故事。"""
|
||||
response = auth_client.delete(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Deleted"
|
||||
|
||||
response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestRateLimit:
|
||||
"""Rate limit 测试。"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清理 rate limit 缓存。"""
|
||||
_request_log.clear()
|
||||
|
||||
def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, test_story):
|
||||
"""正常请求不触发限流。"""
|
||||
for _ in range(RATE_LIMIT_REQUESTS - 1):
|
||||
response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, test_story):
|
||||
"""超限请求被阻止。"""
|
||||
for _ in range(RATE_LIMIT_REQUESTS):
|
||||
auth_client.get(f"/api/stories/{test_story.id}")
|
||||
|
||||
response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 429
|
||||
assert "Too many requests" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestImageGenerate:
|
||||
"""封面图片生成测试。"""
|
||||
|
||||
def test_generate_image_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时生成图片。"""
|
||||
response = client.post(f"/api/image/generate/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_generate_image_not_found(self, auth_client: TestClient):
|
||||
"""故事不存在。"""
|
||||
response = auth_client.post("/api/image/generate/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestAudio:
|
||||
"""语音朗读测试。"""
|
||||
|
||||
def test_get_audio_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时获取音频。"""
|
||||
response = client.get(f"/api/audio/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_get_audio_not_found(self, auth_client: TestClient):
|
||||
"""故事不存在。"""
|
||||
response = auth_client.get("/api/audio/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_audio_success(self, auth_client: TestClient, test_story, mock_tts_provider):
|
||||
"""成功获取音频。"""
|
||||
response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
|
||||
|
||||
class TestGenerateFull:
|
||||
"""完整故事生成测试(/api/generate/full)。"""
|
||||
|
||||
def test_generate_full_without_auth(self, client: TestClient):
|
||||
"""未登录时生成完整故事。"""
|
||||
response = client.post(
|
||||
"/api/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_generate_full_success(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
|
||||
"""成功生成完整故事(含图片)。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert "title" in data
|
||||
assert "story_text" in data
|
||||
assert data["mode"] == "generated"
|
||||
assert data["image_url"] == "https://example.com/image.png"
|
||||
assert data["audio_ready"] is False # 音频按需生成
|
||||
assert data["errors"] == {}
|
||||
|
||||
def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider):
|
||||
"""图片生成失败时返回部分成功。"""
|
||||
with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock_img:
|
||||
mock_img.side_effect = Exception("Image API error")
|
||||
response = auth_client.post(
|
||||
"/api/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["image_url"] is None
|
||||
assert "image" in data["errors"]
|
||||
assert "Image API error" in data["errors"]["image"]
|
||||
|
||||
def test_generate_full_with_education_theme(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
|
||||
"""带教育主题生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate/full",
|
||||
json={
|
||||
"type": "keywords",
|
||||
"data": "小兔子, 森林",
|
||||
"education_theme": "勇气与友谊",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_text_provider.assert_called_once()
|
||||
call_kwargs = mock_text_provider.call_args.kwargs
|
||||
assert call_kwargs["education_theme"] == "勇气与友谊"
|
||||
|
||||
|
||||
class TestImageGenerateSuccess:
|
||||
"""封面图片生成成功测试。"""
|
||||
|
||||
def test_generate_image_success(self, auth_client: TestClient, test_story, mock_image_provider):
|
||||
"""成功生成图片。"""
|
||||
response = auth_client.post(f"/api/image/generate/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["image_url"] == "https://example.com/image.png"
|
||||
68
backend/tests/test_universes.py
Normal file
68
backend/tests/test_universes.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Story universe API tests."""
|
||||
|
||||
|
||||
def _create_profile(auth_client):
|
||||
response = auth_client.post(
|
||||
"/api/profiles",
|
||||
json={
|
||||
"name": "小明",
|
||||
"gender": "male",
|
||||
"interests": ["太空"],
|
||||
"growth_themes": ["勇气"],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
def test_create_list_update_universe(auth_client):
|
||||
profile_id = _create_profile(auth_client)
|
||||
|
||||
payload = {
|
||||
"name": "星际冒险",
|
||||
"protagonist": {"name": "小明", "role": "船长"},
|
||||
"recurring_characters": [{"name": "小七", "role": "机器人"}],
|
||||
"world_settings": {"world_name": "星际学院"},
|
||||
}
|
||||
|
||||
response = auth_client.post(f"/api/profiles/{profile_id}/universes", json=payload)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == payload["name"]
|
||||
|
||||
universe_id = data["id"]
|
||||
|
||||
response = auth_client.get(f"/api/profiles/{profile_id}/universes")
|
||||
assert response.status_code == 200
|
||||
list_data = response.json()
|
||||
assert list_data["total"] == 1
|
||||
|
||||
response = auth_client.put(
|
||||
f"/api/universes/{universe_id}",
|
||||
json={"name": "星际冒险·第二季"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "星际冒险·第二季"
|
||||
|
||||
|
||||
def test_add_achievement(auth_client):
|
||||
profile_id = _create_profile(auth_client)
|
||||
|
||||
response = auth_client.post(
|
||||
f"/api/profiles/{profile_id}/universes",
|
||||
json={
|
||||
"name": "梦幻森林",
|
||||
"protagonist": {"name": "小红"},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
universe_id = response.json()["id"]
|
||||
|
||||
response = auth_client.post(
|
||||
f"/api/universes/{universe_id}/achievements",
|
||||
json={"type": "勇气", "description": "克服黑暗"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert {"type": "勇气", "description": "克服黑暗"} in data["achievements"]
|
||||
Reference in New Issue
Block a user