wip: snapshot full local workspace state
Some checks are pending
Build and Push Docker Images / changes (push) Waiting to run
Build and Push Docker Images / build-backend (push) Blocked by required conditions
Build and Push Docker Images / build-frontend (push) Blocked by required conditions
Build and Push Docker Images / build-admin-frontend (push) Blocked by required conditions

This commit is contained in:
2026-04-17 18:58:11 +08:00
parent fea4ef012f
commit b8d3cb4644
181 changed files with 16964 additions and 17486 deletions

View File

@@ -1,33 +1,33 @@
# Backend service image: dependency layer is cached separately from source code.
FROM python:3.11-slim
WORKDIR /app
# Environment: no .pyc files; unbuffered stdout/stderr so container logs flush immediately
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
# System tools (curl is used for optional container health checks)
RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
# 1. Cache layer: copy only the dependency manifest and install
# Create a stub app package so `pip install .` succeeds before the sources are copied
COPY pyproject.toml .
RUN mkdir app && touch app/__init__.py
RUN pip install --no-cache-dir .
# 2. Source layer: copy the real code
COPY app ./app
COPY alembic ./alembic
COPY alembic.ini .
# Re-install the package itself (without deps) so the updated sources are registered as installed
RUN pip install --no-cache-dir --no-deps .
# Create the static asset directory
RUN mkdir -p static/images
# Expose the API port
EXPOSE 8000
# Default start command (can be overridden by docker-compose)
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
# Backend service image: dependency layer is cached separately from source code.
FROM python:3.11-slim
WORKDIR /app
# Environment: no .pyc files; unbuffered stdout/stderr so container logs flush immediately
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
# System tools (curl is used for optional container health checks)
RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
# 1. Cache layer: copy only the dependency manifest and install
# Create a stub app package so `pip install .` succeeds before the sources are copied
COPY pyproject.toml .
RUN mkdir app && touch app/__init__.py
RUN pip install --no-cache-dir .
# 2. Source layer: copy the real code
COPY app ./app
COPY alembic ./alembic
COPY alembic.ini .
# Re-install the package itself (without deps) so the updated sources are registered as installed
RUN pip install --no-cache-dir --no-deps .
# Create the static asset directory
RUN mkdir -p static/images
# Expose the API port
EXPOSE 8000
# Default start command (can be overridden by docker-compose)
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -1,29 +1,29 @@
"""add api_key to providers
Revision ID: 0002_add_api_key_to_providers
Revises: 0001_init_providers_and_story_mode
Create Date: 2025-01-01
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0002_add_api_key"
down_revision = "0001_init_providers"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 添加 api_key 列,可为空,优先于 config_ref 使用
with op.batch_alter_table("providers", schema=None) as batch_op:
batch_op.add_column(
sa.Column("api_key", sa.String(length=500), nullable=True)
)
def downgrade() -> None:
with op.batch_alter_table("providers", schema=None) as batch_op:
batch_op.drop_column("api_key")
"""add api_key to providers
Revision ID: 0002_add_api_key_to_providers
Revises: 0001_init_providers_and_story_mode
Create Date: 2025-01-01
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0002_add_api_key"
down_revision = "0001_init_providers"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 添加 api_key 列,可为空,优先于 config_ref 使用
with op.batch_alter_table("providers", schema=None) as batch_op:
batch_op.add_column(
sa.Column("api_key", sa.String(length=500), nullable=True)
)
def downgrade() -> None:
with op.batch_alter_table("providers", schema=None) as batch_op:
batch_op.drop_column("api_key")

View File

@@ -1,100 +1,100 @@
"""add provider monitoring tables
Revision ID: 0003_add_monitoring
Revises: 0002_add_api_key
Create Date: 2025-01-01
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0003_add_monitoring"
down_revision = "0002_add_api_key"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 创建 provider_metrics 表
op.create_table(
"provider_metrics",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("provider_id", sa.String(length=36), nullable=False),
sa.Column(
"timestamp",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("success", sa.Boolean(), nullable=False),
sa.Column("latency_ms", sa.Integer(), nullable=True),
sa.Column("cost_usd", sa.Numeric(precision=10, scale=6), nullable=True),
sa.Column("error_message", sa.Text(), nullable=True),
sa.Column("request_id", sa.String(length=100), nullable=True),
sa.ForeignKeyConstraint(
["provider_id"],
["providers.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_provider_metrics_provider_id",
"provider_metrics",
["provider_id"],
unique=False,
)
op.create_index(
"ix_provider_metrics_timestamp",
"provider_metrics",
["timestamp"],
unique=False,
)
# 创建 provider_health 表
op.create_table(
"provider_health",
sa.Column("provider_id", sa.String(length=36), nullable=False),
sa.Column("is_healthy", sa.Boolean(), server_default=sa.text("true"), nullable=True),
sa.Column("last_check", sa.DateTime(timezone=True), nullable=True),
sa.Column("consecutive_failures", sa.Integer(), server_default=sa.text("0"), nullable=True),
sa.Column("last_error", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(
["provider_id"],
["providers.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("provider_id"),
)
# 创建 provider_secrets 表
op.create_table(
"provider_secrets",
sa.Column("id", sa.String(length=36), nullable=False),
sa.Column("name", sa.String(length=100), nullable=False),
sa.Column("encrypted_value", sa.Text(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
def downgrade() -> None:
op.drop_table("provider_secrets")
op.drop_table("provider_health")
op.drop_index("ix_provider_metrics_timestamp", table_name="provider_metrics")
op.drop_index("ix_provider_metrics_provider_id", table_name="provider_metrics")
op.drop_table("provider_metrics")
"""add provider monitoring tables
Revision ID: 0003_add_monitoring
Revises: 0002_add_api_key
Create Date: 2025-01-01
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0003_add_monitoring"
down_revision = "0002_add_api_key"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 创建 provider_metrics 表
op.create_table(
"provider_metrics",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("provider_id", sa.String(length=36), nullable=False),
sa.Column(
"timestamp",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("success", sa.Boolean(), nullable=False),
sa.Column("latency_ms", sa.Integer(), nullable=True),
sa.Column("cost_usd", sa.Numeric(precision=10, scale=6), nullable=True),
sa.Column("error_message", sa.Text(), nullable=True),
sa.Column("request_id", sa.String(length=100), nullable=True),
sa.ForeignKeyConstraint(
["provider_id"],
["providers.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_provider_metrics_provider_id",
"provider_metrics",
["provider_id"],
unique=False,
)
op.create_index(
"ix_provider_metrics_timestamp",
"provider_metrics",
["timestamp"],
unique=False,
)
# 创建 provider_health 表
op.create_table(
"provider_health",
sa.Column("provider_id", sa.String(length=36), nullable=False),
sa.Column("is_healthy", sa.Boolean(), server_default=sa.text("true"), nullable=True),
sa.Column("last_check", sa.DateTime(timezone=True), nullable=True),
sa.Column("consecutive_failures", sa.Integer(), server_default=sa.text("0"), nullable=True),
sa.Column("last_error", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(
["provider_id"],
["providers.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("provider_id"),
)
# 创建 provider_secrets 表
op.create_table(
"provider_secrets",
sa.Column("id", sa.String(length=36), nullable=False),
sa.Column("name", sa.String(length=100), nullable=False),
sa.Column("encrypted_value", sa.Text(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
def downgrade() -> None:
op.drop_table("provider_secrets")
op.drop_table("provider_health")
op.drop_index("ix_provider_metrics_timestamp", table_name="provider_metrics")
op.drop_index("ix_provider_metrics_provider_id", table_name="provider_metrics")
op.drop_table("provider_metrics")

View File

@@ -39,4 +39,4 @@ def upgrade():
def downgrade():
    """Drop the child_profiles table and its user_id index."""
    op.drop_index("idx_child_profiles_user_id", table_name="child_profiles")
    # Fix: the drop_table line appeared twice (diff-render artifact); a second
    # drop of the same table would raise. Keep exactly one.
    op.drop_table("child_profiles")

View File

@@ -64,4 +64,4 @@ def downgrade():
op.drop_index("idx_story_universes_updated_at", table_name="story_universes")
op.drop_index("idx_story_universes_child_id", table_name="story_universes")
op.drop_table("story_universes")
op.drop_table("story_universes")

View File

@@ -75,4 +75,4 @@ def downgrade():
op.drop_index("idx_reading_events_created", table_name="reading_events")
op.drop_index("idx_reading_events_story", table_name="reading_events")
op.drop_index("idx_reading_events_profile", table_name="reading_events")
op.drop_table("reading_events")
op.drop_table("reading_events")

View File

@@ -1,25 +1,25 @@
"""add pages column to stories
Revision ID: 0008_add_pages_to_stories
Revises: 0007_add_push_configs_and_events
Create Date: 2026-01-20
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '0008_add_pages_to_stories'
down_revision = '0007_add_push_configs_and_events'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column('stories', sa.Column('pages', postgresql.JSON(), nullable=True))
def downgrade() -> None:
op.drop_column('stories', 'pages')
"""add pages column to stories
Revision ID: 0008_add_pages_to_stories
Revises: 0007_add_push_configs_and_events
Create Date: 2026-01-20
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '0008_add_pages_to_stories'
down_revision = '0007_add_push_configs_and_events'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column('stories', sa.Column('pages', postgresql.JSON(), nullable=True))
def downgrade() -> None:
op.drop_column('stories', 'pages')

View File

@@ -1,61 +1,61 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api import admin_providers, admin_reload
from app.core.config import settings
from app.core.logging import get_logger, setup_logging
from app.db.database import init_db
# Configure structured logging once at import time, before anything logs.
setup_logging()
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Admin app lifespan: initialize the database on startup, log shutdown."""
    logger.info("admin_app_starting")
    await init_db()
    # Admin-specific caches could be loaded/warmed up here.
    yield
    logger.info("admin_app_shutdown")
# FastAPI application for the administrative control plane.
app = FastAPI(
    title=f"{settings.app_name} Admin Console",
    description="Administrative Control Plane for DreamWeaver.",
    version="0.1.0",
    lifespan=lifespan,
)

# Admin backends usually allow looser CORS, or a dedicated admin domain.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,  # or a dedicated ADMIN_CORS_ORIGINS
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount the admin routers only when enabled by configuration.
if settings.enable_admin_console:
    app.include_router(admin_providers.router, prefix="/admin", tags=["admin-providers"])
    app.include_router(admin_reload.router, prefix="/admin", tags=["admin-reload"])
else:
    # Catch-all: reject every /admin request with 403 when the console is disabled.
    @app.get("/admin/{path:path}")
    @app.post("/admin/{path:path}")
    @app.put("/admin/{path:path}")
    @app.delete("/admin/{path:path}")
    async def admin_disabled(path: str):
        from fastapi import HTTPException

        raise HTTPException(
            status_code=403,
            detail="Admin console is disabled in environment configuration."
        )
@app.get("/health")
async def health_check():
return {"status": "ok", "service": "admin-backend"}
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api import admin_providers, admin_reload
from app.core.config import settings
from app.core.logging import get_logger, setup_logging
from app.db.database import init_db
# Configure structured logging once at import time, before anything logs.
setup_logging()
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Admin app lifespan: initialize the database on startup, log shutdown."""
    logger.info("admin_app_starting")
    await init_db()
    # Admin-specific caches could be loaded/warmed up here.
    yield
    logger.info("admin_app_shutdown")
# FastAPI application for the administrative control plane.
app = FastAPI(
    title=f"{settings.app_name} Admin Console",
    description="Administrative Control Plane for DreamWeaver.",
    version="0.1.0",
    lifespan=lifespan,
)

# Admin backends usually allow looser CORS, or a dedicated admin domain.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,  # or a dedicated ADMIN_CORS_ORIGINS
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount the admin routers only when enabled by configuration.
if settings.enable_admin_console:
    app.include_router(admin_providers.router, prefix="/admin", tags=["admin-providers"])
    app.include_router(admin_reload.router, prefix="/admin", tags=["admin-reload"])
else:
    # Catch-all: reject every /admin request with 403 when the console is disabled.
    @app.get("/admin/{path:path}")
    @app.post("/admin/{path:path}")
    @app.put("/admin/{path:path}")
    @app.delete("/admin/{path:path}")
    async def admin_disabled(path: str):
        from fastapi import HTTPException

        raise HTTPException(
            status_code=403,
            detail="Admin console is disabled in environment configuration."
        )
@app.get("/health")
async def health_check():
return {"status": "ok", "service": "admin-backend"}

View File

@@ -1,268 +1,268 @@
"""Memory management APIs."""
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, User
from app.services import memory_service
from app.services.memory_service import MemoryType
router = APIRouter()
class MemoryItemResponse(BaseModel):
    """Single memory item as returned by the API (timestamps as ISO strings)."""

    id: str
    type: str
    value: dict
    base_weight: float
    ttl_days: int | None
    created_at: str
    last_used_at: str | None

    class Config:
        from_attributes = True
class MemoryListResponse(BaseModel):
    """List of memory items plus the total count returned."""

    memories: list[MemoryItemResponse]
    total: int
class CreateMemoryRequest(BaseModel):
    """Payload for creating a generic memory item."""

    type: str = Field(..., description="记忆类型")
    value: dict = Field(..., description="记忆内容")
    universe_id: str | None = Field(default=None, description="关联的故事宇宙 ID")
    weight: float | None = Field(default=None, description="权重")
    ttl_days: int | None = Field(default=None, description="过期天数")
class CreateCharacterMemoryRequest(BaseModel):
    """Payload for recording a favorite character."""

    name: str = Field(..., description="角色名称")
    description: str | None = Field(default=None, description="角色描述")
    source_story_id: int | None = Field(default=None, description="来源故事 ID")
    affinity_score: float = Field(default=1.0, ge=0.0, le=1.0, description="喜爱程度")
    universe_id: str | None = Field(default=None, description="关联的故事宇宙 ID")
class CreateScaryElementRequest(BaseModel):
    """Payload for recording a keyword that should be avoided in stories."""

    keyword: str = Field(..., description="回避的关键词")
    category: str = Field(default="other", description="分类")
    source_story_id: int | None = Field(default=None, description="来源故事 ID")
async def _verify_profile_ownership(
    profile_id: str, user: User, db: AsyncSession
) -> ChildProfile:
    """Load the child profile and ensure it belongs to *user*; 404 otherwise."""
    from sqlalchemy import select

    stmt = select(ChildProfile).where(
        ChildProfile.id == profile_id,
        ChildProfile.user_id == user.id,
    )
    owned_profile = (await db.execute(stmt)).scalar_one_or_none()
    if owned_profile is None:
        raise HTTPException(status_code=404, detail="档案不存在")
    return owned_profile
@router.get("/profiles/{profile_id}/memories", response_model=MemoryListResponse)
async def list_memories(
profile_id: str,
memory_type: str | None = None,
universe_id: str | None = None,
limit: int = 50,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""获取档案的记忆列表。"""
await _verify_profile_ownership(profile_id, user, db)
memories = await memory_service.get_profile_memories(
db=db,
profile_id=profile_id,
memory_type=memory_type,
universe_id=universe_id,
limit=limit,
)
return MemoryListResponse(
memories=[
MemoryItemResponse(
id=m.id,
type=m.type,
value=m.value,
base_weight=m.base_weight,
ttl_days=m.ttl_days,
created_at=m.created_at.isoformat() if m.created_at else "",
last_used_at=m.last_used_at.isoformat() if m.last_used_at else None,
)
for m in memories
],
total=len(memories),
)
@router.post("/profiles/{profile_id}/memories", response_model=MemoryItemResponse)
async def create_memory(
profile_id: str,
payload: CreateMemoryRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""创建新的记忆项。"""
await _verify_profile_ownership(profile_id, user, db)
# 验证类型
valid_types = [
MemoryType.RECENT_STORY,
MemoryType.FAVORITE_CHARACTER,
MemoryType.SCARY_ELEMENT,
MemoryType.VOCABULARY_GROWTH,
MemoryType.EMOTIONAL_HIGHLIGHT,
MemoryType.READING_PREFERENCE,
MemoryType.MILESTONE,
MemoryType.SKILL_MASTERED,
]
if payload.type not in valid_types:
raise HTTPException(status_code=400, detail=f"无效的记忆类型: {payload.type}")
memory = await memory_service.create_memory(
db=db,
profile_id=profile_id,
memory_type=payload.type,
value=payload.value,
universe_id=payload.universe_id,
weight=payload.weight,
ttl_days=payload.ttl_days,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.post("/profiles/{profile_id}/memories/character", response_model=MemoryItemResponse)
async def create_character_memory(
profile_id: str,
payload: CreateCharacterMemoryRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""添加喜欢的角色。"""
await _verify_profile_ownership(profile_id, user, db)
memory = await memory_service.create_character_memory(
db=db,
profile_id=profile_id,
name=payload.name,
description=payload.description,
source_story_id=payload.source_story_id,
affinity_score=payload.affinity_score,
universe_id=payload.universe_id,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.post("/profiles/{profile_id}/memories/scary", response_model=MemoryItemResponse)
async def create_scary_element_memory(
profile_id: str,
payload: CreateScaryElementRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""添加回避元素。"""
await _verify_profile_ownership(profile_id, user, db)
memory = await memory_service.create_scary_element_memory(
db=db,
profile_id=profile_id,
keyword=payload.keyword,
category=payload.category,
source_story_id=payload.source_story_id,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.delete("/profiles/{profile_id}/memories/{memory_id}")
async def delete_memory(
profile_id: str,
memory_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""删除记忆项。"""
from sqlalchemy import select
from app.db.models import MemoryItem
await _verify_profile_ownership(profile_id, user, db)
result = await db.execute(
select(MemoryItem).where(
MemoryItem.id == memory_id,
MemoryItem.child_profile_id == profile_id,
)
)
memory = result.scalar_one_or_none()
if not memory:
raise HTTPException(status_code=404, detail="记忆不存在")
await db.delete(memory)
await db.commit()
return {"message": "Deleted"}
@router.get("/memory-types")
async def list_memory_types():
"""获取所有可用的记忆类型及其配置。"""
types = []
for type_name, config in MemoryType.CONFIG.items():
types.append({
"type": type_name,
"default_weight": config[0],
"default_ttl_days": config[1],
"description": config[2],
})
return {"types": types}
"""Memory management APIs."""
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.db.database import get_db
from app.db.models import ChildProfile, User
from app.services import memory_service
from app.services.memory_service import MemoryType
router = APIRouter()
class MemoryItemResponse(BaseModel):
    """Single memory item as returned by the API (timestamps as ISO strings)."""

    id: str
    type: str
    value: dict
    base_weight: float
    ttl_days: int | None
    created_at: str
    last_used_at: str | None

    class Config:
        from_attributes = True
class MemoryListResponse(BaseModel):
    """List of memory items plus the total count returned."""

    memories: list[MemoryItemResponse]
    total: int
class CreateMemoryRequest(BaseModel):
    """Payload for creating a generic memory item."""

    type: str = Field(..., description="记忆类型")
    value: dict = Field(..., description="记忆内容")
    universe_id: str | None = Field(default=None, description="关联的故事宇宙 ID")
    weight: float | None = Field(default=None, description="权重")
    ttl_days: int | None = Field(default=None, description="过期天数")
class CreateCharacterMemoryRequest(BaseModel):
    """Payload for recording a favorite character."""

    name: str = Field(..., description="角色名称")
    description: str | None = Field(default=None, description="角色描述")
    source_story_id: int | None = Field(default=None, description="来源故事 ID")
    affinity_score: float = Field(default=1.0, ge=0.0, le=1.0, description="喜爱程度")
    universe_id: str | None = Field(default=None, description="关联的故事宇宙 ID")
class CreateScaryElementRequest(BaseModel):
    """Payload for recording a keyword that should be avoided in stories."""

    keyword: str = Field(..., description="回避的关键词")
    category: str = Field(default="other", description="分类")
    source_story_id: int | None = Field(default=None, description="来源故事 ID")
async def _verify_profile_ownership(
    profile_id: str, user: User, db: AsyncSession
) -> ChildProfile:
    """Load the child profile and ensure it belongs to *user*; 404 otherwise."""
    from sqlalchemy import select

    stmt = select(ChildProfile).where(
        ChildProfile.id == profile_id,
        ChildProfile.user_id == user.id,
    )
    owned_profile = (await db.execute(stmt)).scalar_one_or_none()
    if owned_profile is None:
        raise HTTPException(status_code=404, detail="档案不存在")
    return owned_profile
@router.get("/profiles/{profile_id}/memories", response_model=MemoryListResponse)
async def list_memories(
profile_id: str,
memory_type: str | None = None,
universe_id: str | None = None,
limit: int = 50,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""获取档案的记忆列表。"""
await _verify_profile_ownership(profile_id, user, db)
memories = await memory_service.get_profile_memories(
db=db,
profile_id=profile_id,
memory_type=memory_type,
universe_id=universe_id,
limit=limit,
)
return MemoryListResponse(
memories=[
MemoryItemResponse(
id=m.id,
type=m.type,
value=m.value,
base_weight=m.base_weight,
ttl_days=m.ttl_days,
created_at=m.created_at.isoformat() if m.created_at else "",
last_used_at=m.last_used_at.isoformat() if m.last_used_at else None,
)
for m in memories
],
total=len(memories),
)
@router.post("/profiles/{profile_id}/memories", response_model=MemoryItemResponse)
async def create_memory(
profile_id: str,
payload: CreateMemoryRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""创建新的记忆项。"""
await _verify_profile_ownership(profile_id, user, db)
# 验证类型
valid_types = [
MemoryType.RECENT_STORY,
MemoryType.FAVORITE_CHARACTER,
MemoryType.SCARY_ELEMENT,
MemoryType.VOCABULARY_GROWTH,
MemoryType.EMOTIONAL_HIGHLIGHT,
MemoryType.READING_PREFERENCE,
MemoryType.MILESTONE,
MemoryType.SKILL_MASTERED,
]
if payload.type not in valid_types:
raise HTTPException(status_code=400, detail=f"无效的记忆类型: {payload.type}")
memory = await memory_service.create_memory(
db=db,
profile_id=profile_id,
memory_type=payload.type,
value=payload.value,
universe_id=payload.universe_id,
weight=payload.weight,
ttl_days=payload.ttl_days,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.post("/profiles/{profile_id}/memories/character", response_model=MemoryItemResponse)
async def create_character_memory(
profile_id: str,
payload: CreateCharacterMemoryRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""添加喜欢的角色。"""
await _verify_profile_ownership(profile_id, user, db)
memory = await memory_service.create_character_memory(
db=db,
profile_id=profile_id,
name=payload.name,
description=payload.description,
source_story_id=payload.source_story_id,
affinity_score=payload.affinity_score,
universe_id=payload.universe_id,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.post("/profiles/{profile_id}/memories/scary", response_model=MemoryItemResponse)
async def create_scary_element_memory(
profile_id: str,
payload: CreateScaryElementRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""添加回避元素。"""
await _verify_profile_ownership(profile_id, user, db)
memory = await memory_service.create_scary_element_memory(
db=db,
profile_id=profile_id,
keyword=payload.keyword,
category=payload.category,
source_story_id=payload.source_story_id,
)
return MemoryItemResponse(
id=memory.id,
type=memory.type,
value=memory.value,
base_weight=memory.base_weight,
ttl_days=memory.ttl_days,
created_at=memory.created_at.isoformat() if memory.created_at else "",
last_used_at=memory.last_used_at.isoformat() if memory.last_used_at else None,
)
@router.delete("/profiles/{profile_id}/memories/{memory_id}")
async def delete_memory(
profile_id: str,
memory_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""删除记忆项。"""
from sqlalchemy import select
from app.db.models import MemoryItem
await _verify_profile_ownership(profile_id, user, db)
result = await db.execute(
select(MemoryItem).where(
MemoryItem.id == memory_id,
MemoryItem.child_profile_id == profile_id,
)
)
memory = result.scalar_one_or_none()
if not memory:
raise HTTPException(status_code=404, detail="记忆不存在")
await db.delete(memory)
await db.commit()
return {"message": "Deleted"}
@router.get("/memory-types")
async def list_memory_types():
"""获取所有可用的记忆类型及其配置。"""
types = []
for type_name, config in MemoryType.CONFIG.items():
types.append({
"type": type_name,
"default_weight": config[0],
"default_ttl_days": config[1],
"description": config[2],
})
return {"types": types}

View File

@@ -117,4 +117,4 @@ async def create_reading_event(
await db.commit()
await db.refresh(event)
return event
return event

View File

@@ -198,4 +198,4 @@ async def add_achievement(
await db.commit()
await db.refresh(universe)
return universe
return universe

View File

@@ -1,60 +1,60 @@
import secrets
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from app.core.config import settings
from app.core.rate_limiter import (
clear_failed_attempts,
is_locked_out,
record_failed_attempt,
)
security = HTTPBasic()

# Lockout policy: this many failed attempts locks the client IP out.
MAX_ATTEMPTS = 5
LOCKOUT_SECONDS = 900  # 15 minutes
def _get_client_ip(request: Request) -> str:
    """Best-effort client IP: X-Forwarded-For first, then the socket peer."""
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        # The first entry of the comma-separated chain is the originating client.
        return forwarded.split(",")[0].strip()
    client = request.client
    if client and client.host:
        return client.host
    return "unknown"
async def admin_guard(
    request: Request,
    credentials: HTTPBasicCredentials = Depends(security),
):
    """HTTP Basic auth dependency for admin routes, with per-IP lockout.

    Raises 429 while the client IP is locked out and 401 on bad credentials.
    """
    client_ip = _get_client_ip(request)
    lockout_key = f"admin_login:{client_ip}"
    # Reject immediately if this IP is currently locked out.
    remaining = await is_locked_out(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
    if remaining > 0:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"登录尝试过多,请 {remaining} 秒后重试",
        )
    # secrets.compare_digest guards against timing attacks on the comparison.
    username_ok = secrets.compare_digest(
        credentials.username.encode(), settings.admin_username.encode()
    )
    password_ok = secrets.compare_digest(
        credentials.password.encode(), settings.admin_password.encode()
    )
    if not (username_ok and password_ok):
        await record_failed_attempt(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
        )
    # Successful login clears the failure counter for this IP.
    await clear_failed_attempts(lockout_key)
    return True
import secrets
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from app.core.config import settings
from app.core.rate_limiter import (
clear_failed_attempts,
is_locked_out,
record_failed_attempt,
)
security = HTTPBasic()

# Lockout policy: this many failed attempts locks the client IP out.
MAX_ATTEMPTS = 5
LOCKOUT_SECONDS = 900  # 15 minutes
def _get_client_ip(request: Request) -> str:
    """Best-effort client IP: X-Forwarded-For first, then the socket peer."""
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        # The first entry of the comma-separated chain is the originating client.
        return forwarded.split(",")[0].strip()
    client = request.client
    if client and client.host:
        return client.host
    return "unknown"
async def admin_guard(
    request: Request,
    credentials: HTTPBasicCredentials = Depends(security),
):
    """HTTP Basic auth dependency for admin routes, with per-IP lockout.

    Raises 429 while the client IP is locked out and 401 on bad credentials.
    """
    client_ip = _get_client_ip(request)
    lockout_key = f"admin_login:{client_ip}"
    # Reject immediately if this IP is currently locked out.
    remaining = await is_locked_out(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
    if remaining > 0:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"登录尝试过多,请 {remaining} 秒后重试",
        )
    # secrets.compare_digest guards against timing attacks on the comparison.
    username_ok = secrets.compare_digest(
        credentials.username.encode(), settings.admin_username.encode()
    )
    password_ok = secrets.compare_digest(
        credentials.password.encode(), settings.admin_password.encode()
    )
    if not (username_ok and password_ok):
        await record_failed_attempt(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
        )
    # Successful login clears the failure counter for this IP.
    await clear_failed_attempts(lockout_key)
    return True

View File

@@ -1,48 +1,48 @@
"""结构化日志配置。"""
import logging
import sys
import structlog
from app.core.config import settings
def setup_logging():
    """Configure structlog structured logging.

    Debug mode renders colored console output; otherwise JSON is emitted.
    """
    base_processors = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.add_log_level,
        structlog.stdlib.add_logger_name,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
    ]
    if settings.debug:
        renderers = [structlog.dev.ConsoleRenderer(colors=True)]
    else:
        renderers = [
            structlog.processors.format_exc_info,
            structlog.processors.JSONRenderer(),
        ]
    structlog.configure(
        processors=base_processors + renderers,
        wrapper_class=structlog.stdlib.BoundLogger,
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        cache_logger_on_first_use=True,
    )
    logging.basicConfig(
        format="%(message)s",
        stream=sys.stdout,
        level=logging.DEBUG if settings.debug else logging.INFO,
    )
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
    """Return a structlog logger bound to *name*."""
    return structlog.get_logger(name)
"""结构化日志配置。"""
import logging
import sys
import structlog
from app.core.config import settings
def setup_logging():
    """Configure structlog structured logging.

    Debug mode renders colored console output; otherwise JSON is emitted.
    """
    base_processors = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.add_log_level,
        structlog.stdlib.add_logger_name,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
    ]
    if settings.debug:
        renderers = [structlog.dev.ConsoleRenderer(colors=True)]
    else:
        renderers = [
            structlog.processors.format_exc_info,
            structlog.processors.JSONRenderer(),
        ]
    structlog.configure(
        processors=base_processors + renderers,
        wrapper_class=structlog.stdlib.BoundLogger,
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        cache_logger_on_first_use=True,
    )
    logging.basicConfig(
        format="%(message)s",
        stream=sys.stdout,
        level=logging.DEBUG if settings.debug else logging.INFO,
    )
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
    """Return a structlog logger bound to *name* (typically ``__name__``)."""
    return structlog.get_logger(name)

View File

@@ -1,190 +1,190 @@
# ruff: noqa: E501
"""AI prompt templates (modernized)."""
# Random-element pool: injects a spark of unpredictable magic into each story.
# NOTE: the elements themselves are runtime prompt data and are intentionally
# kept in Chinese, matching the target audience of the generated stories.
RANDOM_ELEMENTS = [
    "一个会打喷嚏的云朵",
    "一本地图上找不到的神秘图书馆",
    "一只能实现小愿望的彩色蜗牛",
    "一扇通往颠倒世界的门",
    "一顶能听懂动物说话的旧帽子",
    "一个装着星星的玻璃罐",
    "一棵结满笑声果实的树",
    "一只能在水上画画的画笔",
    "一个怕黑的影子",
    "一只收集回声的瓶子",
    "一双会自己跳舞的红鞋子",
    "一个只能在月光下看见的邮筒",
    "一张会改变模样的全家福",
    "一把可以打开梦境的钥匙",
    "一个喜欢讲冷笑话的冰箱",
    "一条通往星期八的秘密小径"
]
# ==============================================================================
# Model A: 故事生成 (Story Generation)
# ==============================================================================
SYSTEM_INSTRUCTION_STORYTELLER = """
# Role
You are "**Dream Weaver**", a world-class children's storyteller with the imagination of Pixar and the warmth of Miyazaki.
Your mission is to create engaging, safe, and educational stories for children (ages 3-8).
# Core Philosophy
1. **Show, Don't Tell**: Don't preach the lesson. Let the character's actions and the plot demonstrate the theme.
2. **Safety First**: No violence, horror, or scary elements. Conflict should be emotional or situational, not physical.
3. **Vivid Imagery**: Use sensory details (colors, sounds, smells) that appeal to children.
4. **Empowerment**: The child protagonist should solve the problem using wit, kindness, or courage, not just luck.
# Continuity & Memory (CRITICAL)
- **Universal Context**: The story takes place in the child's established "Story Universe". Respect existing world rules.
- **Character Consistency**: If "Child Profile" or "Sidekicks" are provided, you MUST use their specific names and traits. Do NOT invent new main characters unless asked.
- **Callback**: If "Past Memories" are provided, try to make a natural, one-sentence reference to a past adventure to build a sense of continuity (e.g., "Just like when we found the lost star...").
# Output Format
You MUST return a pure JSON object with NO markdown formatting (no ```json code blocks).
The JSON object must have the following schema:
{
"mode": "generated",
"title": "A catchy, imaginative title",
"story_text": "The full story text. Use \\n\\n for paragraph breaks.",
"cover_prompt_suggestion": "A detailed English image generation prompt for the story cover. Style: whimsical, children's book illustration, soft lighting, vibrant colors."
}
"""
USER_PROMPT_GENERATION = """
# Task: Write a Children's Story
## Contextual Memory (Use these if provided)
{memory_context}
## Inputs
- **Keywords/Topic**: {keywords}
- **Educational Theme**: {education_theme}
- **Magic Element (Must Incorporate)**: {random_element}
## Constraints
- Length: 300-600 words.
- Structure: Beginning (Hook) -> Middle (Challenge) -> End (Resolution & Growth).
"""
# ==============================================================================
# Model B: 故事润色 (Story Enhancement)
# ==============================================================================
SYSTEM_INSTRUCTION_ENHANCER = """
# Role
You are "**Dream Weaver Editor**", an expert children's book editor who turns rough drafts into polished gems.
# Mission
Analyze the user's input story and rewrite it to be:
1. **More Engaging**: Enhance the plot with a "Magic Element" to add surprise.
2. **More Educational**: Weave the "Educational Theme" deeper into the narrative arc.
3. **Better Written**: Polish the sentences for rhythm and flow (suitable for reading aloud).
4. **Safe**: Remove any inappropriate content (violence, scary interaction) and replace it with constructive solutions.
# Output Format
You MUST return a pure JSON object with NO markdown formatting (no ```json code blocks).
The JSON object must have the following schema:
{
"mode": "enhanced",
"title": "An improved title (or the original if perfect)",
"story_text": "The rewritten story text. Use \\n\\n for paragraph breaks.",
"cover_prompt_suggestion": "A detailed English image generation prompt for the cover."
}
"""
USER_PROMPT_ENHANCEMENT = """
# Task: Enhance This Story
## Contextual Memory
{memory_context}
## Inputs
- **Original Story**: {full_story}
- **Target Theme**: {education_theme}
- **Magic Element to Add**: {random_element}
## Constraints
- Length: 300-600 words.
- Keep the original character names if possible, but feel free to upgrade the plot.
"""
# ==============================================================================
# Model C: 成就提取 (Achievement Extraction)
# ==============================================================================
# 保持简单,暂不使用 System Instruction沿用单次提示
ACHIEVEMENT_EXTRACTION_PROMPT = """
Analyze the story and extract key growth moments or achievements for the child protagonist.
# Story
{story_text}
# Target Categories (Examples)
- **Courage**: Overcoming fear, trying something new.
- **Kindness**: Helping others, sharing, empathy.
- **Curiosity**: Asking questions, exploring, learning.
- **Resilience**: Not giving up, handling failure.
- **Wisdom**: Problem-solving, honesty, patience.
# Output Format
Return a pure JSON object (no markdown):
{{
"achievements": [
{{
"type": "Category Name",
"description": "Brief reason (max 10 words)",
"score": 8 // 1-10 intensity
}}
]
}}
"""
# ==============================================================================
# Model D: 绘本生成 (Storybook Generation)
# ==============================================================================
SYSTEM_INSTRUCTION_STORYBOOK = """
# Role
You are "**Dream Weaver Illustrator**", a creative children's book author and visual director.
Your mission is to create a paginated picture book for children (ages 3-8), where each page has text and a matching illustration prompt.
# Core Philosophy
1. **Pacing**: The story must flow logically across the specified number of pages.
2. **Visual Consistency**: Define the "Main Character" and "Art Style" once, and ensure all image prompts adhere to them.
3. **Language**: The story text MUST be in **Chinese (Simplified)**. The image prompts MUST be in **English**.
4. **Memory**: If a memory context is provided, incorporate known characters or references naturally.
# Output Format
You MUST return a pure JSON object using the following schema (no markdown):
{
"title": "Story Title (Chinese)",
"main_character": "Description of the protagonist (e.g., 'A small blue robot with rusty gears')",
"art_style": "Visual style description (e.g., 'Watercolor, soft pastel colors, whimsical')",
"pages": [
{
"page_number": 1,
"text": "Page text in Chinese (30-60 chars).",
"image_prompt": "Detailed English image prompt describing the scene. Include 'main_character' reference."
}
],
"cover_prompt": "English image prompt for the book cover."
}
"""
USER_PROMPT_STORYBOOK = """
# Task: Create a {page_count}-Page Storybook
## Contextual Memory
{memory_context}
## Inputs
- **Keywords/Topic**: {keywords}
- **Educational Theme**: {education_theme}
- **Magic Element**: {random_element}
## Constraints
- Pages: Exactly {page_count} pages.
- Structure: Intro -> Development -> Climax -> Resolution.
"""
# ruff: noqa: E501
"""AI 提示词模板 (Modernized)"""
# 随机元素列表:为故事注入不可预测的魔法
RANDOM_ELEMENTS = [
"一个会打喷嚏的云朵",
"一本地图上找不到的神秘图书馆",
"一只能实现小愿望的彩色蜗牛",
"一扇通往颠倒世界的门",
"一顶能听懂动物说话的旧帽子",
"一个装着星星的玻璃罐",
"一棵结满笑声果实的树",
"一只能在水上画画的画笔",
"一个怕黑的影子",
"一只收集回声的瓶子",
"一双会自己跳舞的红鞋子",
"一个只能在月光下看见的邮筒",
"一张会改变模样的全家福",
"一把可以打开梦境的钥匙",
"一个喜欢讲冷笑话的冰箱",
"一条通往星期八的秘密小径"
]
# ==============================================================================
# Model A: 故事生成 (Story Generation)
# ==============================================================================
SYSTEM_INSTRUCTION_STORYTELLER = """
# Role
You are "**Dream Weaver**", a world-class children's storyteller with the imagination of Pixar and the warmth of Miyazaki.
Your mission is to create engaging, safe, and educational stories for children (ages 3-8).
# Core Philosophy
1. **Show, Don't Tell**: Don't preach the lesson. Let the character's actions and the plot demonstrate the theme.
2. **Safety First**: No violence, horror, or scary elements. Conflict should be emotional or situational, not physical.
3. **Vivid Imagery**: Use sensory details (colors, sounds, smells) that appeal to children.
4. **Empowerment**: The child protagonist should solve the problem using wit, kindness, or courage, not just luck.
# Continuity & Memory (CRITICAL)
- **Universal Context**: The story takes place in the child's established "Story Universe". Respect existing world rules.
- **Character Consistency**: If "Child Profile" or "Sidekicks" are provided, you MUST use their specific names and traits. Do NOT invent new main characters unless asked.
- **Callback**: If "Past Memories" are provided, try to make a natural, one-sentence reference to a past adventure to build a sense of continuity (e.g., "Just like when we found the lost star...").
# Output Format
You MUST return a pure JSON object with NO markdown formatting (no ```json code blocks).
The JSON object must have the following schema:
{
"mode": "generated",
"title": "A catchy, imaginative title",
"story_text": "The full story text. Use \\n\\n for paragraph breaks.",
"cover_prompt_suggestion": "A detailed English image generation prompt for the story cover. Style: whimsical, children's book illustration, soft lighting, vibrant colors."
}
"""
USER_PROMPT_GENERATION = """
# Task: Write a Children's Story
## Contextual Memory (Use these if provided)
{memory_context}
## Inputs
- **Keywords/Topic**: {keywords}
- **Educational Theme**: {education_theme}
- **Magic Element (Must Incorporate)**: {random_element}
## Constraints
- Length: 300-600 words.
- Structure: Beginning (Hook) -> Middle (Challenge) -> End (Resolution & Growth).
"""
# ==============================================================================
# Model B: 故事润色 (Story Enhancement)
# ==============================================================================
SYSTEM_INSTRUCTION_ENHANCER = """
# Role
You are "**Dream Weaver Editor**", an expert children's book editor who turns rough drafts into polished gems.
# Mission
Analyze the user's input story and rewrite it to be:
1. **More Engaging**: Enhance the plot with a "Magic Element" to add surprise.
2. **More Educational**: Weave the "Educational Theme" deeper into the narrative arc.
3. **Better Written**: Polish the sentences for rhythm and flow (suitable for reading aloud).
4. **Safe**: Remove any inappropriate content (violence, scary interaction) and replace it with constructive solutions.
# Output Format
You MUST return a pure JSON object with NO markdown formatting (no ```json code blocks).
The JSON object must have the following schema:
{
"mode": "enhanced",
"title": "An improved title (or the original if perfect)",
"story_text": "The rewritten story text. Use \\n\\n for paragraph breaks.",
"cover_prompt_suggestion": "A detailed English image generation prompt for the cover."
}
"""
USER_PROMPT_ENHANCEMENT = """
# Task: Enhance This Story
## Contextual Memory
{memory_context}
## Inputs
- **Original Story**: {full_story}
- **Target Theme**: {education_theme}
- **Magic Element to Add**: {random_element}
## Constraints
- Length: 300-600 words.
- Keep the original character names if possible, but feel free to upgrade the plot.
"""
# ==============================================================================
# Model C: 成就提取 (Achievement Extraction)
# ==============================================================================
# 保持简单,暂不使用 System Instruction沿用单次提示
ACHIEVEMENT_EXTRACTION_PROMPT = """
Analyze the story and extract key growth moments or achievements for the child protagonist.
# Story
{story_text}
# Target Categories (Examples)
- **Courage**: Overcoming fear, trying something new.
- **Kindness**: Helping others, sharing, empathy.
- **Curiosity**: Asking questions, exploring, learning.
- **Resilience**: Not giving up, handling failure.
- **Wisdom**: Problem-solving, honesty, patience.
# Output Format
Return a pure JSON object (no markdown):
{{
"achievements": [
{{
"type": "Category Name",
"description": "Brief reason (max 10 words)",
"score": 8 // 1-10 intensity
}}
]
}}
"""
# ==============================================================================
# Model D: 绘本生成 (Storybook Generation)
# ==============================================================================
SYSTEM_INSTRUCTION_STORYBOOK = """
# Role
You are "**Dream Weaver Illustrator**", a creative children's book author and visual director.
Your mission is to create a paginated picture book for children (ages 3-8), where each page has text and a matching illustration prompt.
# Core Philosophy
1. **Pacing**: The story must flow logically across the specified number of pages.
2. **Visual Consistency**: Define the "Main Character" and "Art Style" once, and ensure all image prompts adhere to them.
3. **Language**: The story text MUST be in **Chinese (Simplified)**. The image prompts MUST be in **English**.
4. **Memory**: If a memory context is provided, incorporate known characters or references naturally.
# Output Format
You MUST return a pure JSON object using the following schema (no markdown):
{
"title": "Story Title (Chinese)",
"main_character": "Description of the protagonist (e.g., 'A small blue robot with rusty gears')",
"art_style": "Visual style description (e.g., 'Watercolor, soft pastel colors, whimsical')",
"pages": [
{
"page_number": 1,
"text": "Page text in Chinese (30-60 chars).",
"image_prompt": "Detailed English image prompt describing the scene. Include 'main_character' reference."
}
],
"cover_prompt": "English image prompt for the book cover."
}
"""
USER_PROMPT_STORYBOOK = """
# Task: Create a {page_count}-Page Storybook
## Contextual Memory
{memory_context}
## Inputs
- **Keywords/Topic**: {keywords}
- **Educational Theme**: {education_theme}
- **Magic Element**: {random_element}
## Constraints
- Pages: Exactly {page_count} pages.
- Structure: Intro -> Development -> Climax -> Resolution.
"""

View File

@@ -1,141 +1,141 @@
"""Redis-backed rate limiter with in-memory fallback.
Uses a fixed-window counter pattern via Redis INCR + EXPIRE.
Falls back to an in-memory TTLCache when Redis is unavailable,
preserving identical behavior for dev/test environments.
"""
import time
from cachetools import TTLCache
from fastapi import HTTPException
from app.core.logging import get_logger
from app.core.redis import get_redis
logger = get_logger(__name__)
# ── In-memory fallback caches ──────────────────────────────────────────────
_local_rate_cache: TTLCache[str, int] = TTLCache(maxsize=10000, ttl=120)
_local_lockout_cache: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900)
async def check_rate_limit(key: str, limit: int, window_seconds: int) -> None:
    """Check and increment a fixed-window rate counter.

    Primary path uses Redis ``INCR`` + ``EXPIRE`` on a per-window bucket key.
    When Redis is unavailable, falls back to an in-memory counter keyed by
    the same window bucket, so the fallback window also resets every
    ``window_seconds`` instead of only expiring with the cache's fixed TTL.

    Args:
        key: Unique identifier (e.g. ``"story:<user_id>"``).
        limit: Maximum requests allowed within the window.
        window_seconds: Window duration in seconds.

    Raises:
        HTTPException: 429 when the limit is exceeded.
    """
    # Fixed-window bucket: every request in the same window shares one counter.
    bucket = int(time.time() // window_seconds)
    try:
        redis = await get_redis()
        redis_key = f"ratelimit:{key}:{bucket}"
        count = await redis.incr(redis_key)
        if count == 1:
            # First hit in this window: start the window's countdown.
            await redis.expire(redis_key, window_seconds)
        if count > limit:
            raise HTTPException(
                status_code=429,
                detail="Too many requests, please slow down.",
            )
        return
    except HTTPException:
        raise
    except Exception as exc:
        logger.warning("rate_limit_redis_fallback", error=str(exc))
    # ── Fallback: in-memory counter, bucketed per window like the Redis key ─
    local_key = f"{key}:{bucket}"
    count = _local_rate_cache.get(local_key, 0) + 1
    _local_rate_cache[local_key] = count
    if count > limit:
        raise HTTPException(
            status_code=429,
            detail="Too many requests, please slow down.",
        )
async def record_failed_attempt(
    key: str,
    max_attempts: int,
    lockout_seconds: int,
) -> bool:
    """Record a failed login attempt and return whether the key is locked out.

    Args:
        key: Unique identifier (e.g. ``"admin_login:<ip>"``).
        max_attempts: Number of failures before lockout.
        lockout_seconds: Duration of lockout in seconds.

    Returns:
        ``True`` if the key is now locked out, ``False`` otherwise.
    """
    try:
        redis = await get_redis()
        redis_key = f"lockout:{key}"
        count = await redis.incr(redis_key)
        if count == 1:
            # First failure starts the lockout window; later failures inherit it.
            await redis.expire(redis_key, lockout_seconds)
        return count >= max_attempts
    except Exception as exc:
        logger.warning("lockout_redis_fallback", error=str(exc))
    # ── Fallback ───────────────────────────────────────────────────────────
    # Local state is (attempt_count, first_failure_timestamp); the window is
    # measured from the first failure, mirroring the Redis EXPIRE above.
    # NOTE(review): fallback entries also expire with the cache's fixed 900 s
    # TTL, which may differ from lockout_seconds — confirm this is acceptable.
    if key in _local_lockout_cache:
        attempts, first_fail = _local_lockout_cache[key]
        _local_lockout_cache[key] = (attempts + 1, first_fail)
        return (attempts + 1) >= max_attempts
    else:
        _local_lockout_cache[key] = (1, time.time())
        return 1 >= max_attempts
async def is_locked_out(key: str, max_attempts: int, lockout_seconds: int) -> int:
    """Check if a key is currently locked out.

    Args:
        key: Identifier previously passed to ``record_failed_attempt``.
        max_attempts: Failure threshold that triggers a lockout.
        lockout_seconds: Lockout duration; used only by the in-memory
            fallback (the Redis path derives remaining time from the TTL).

    Returns:
        Remaining lockout seconds (> 0 means locked), 0 means not locked.
    """
    try:
        redis = await get_redis()
        redis_key = f"lockout:{key}"
        count = await redis.get(redis_key)
        if count is not None and int(count) >= max_attempts:
            # TTL can be negative (no expiry / missing key); clamp to 0 = free.
            ttl = await redis.ttl(redis_key)
            return max(ttl, 0)
        return 0
    except Exception as exc:
        logger.warning("lockout_check_redis_fallback", error=str(exc))
    # ── Fallback ───────────────────────────────────────────────────────────
    if key in _local_lockout_cache:
        attempts, first_fail = _local_lockout_cache[key]
        if attempts >= max_attempts:
            # Window is measured from the first recorded failure.
            remaining = int(lockout_seconds - (time.time() - first_fail))
            if remaining > 0:
                return remaining
            else:
                # Lockout has expired: drop the stale entry.
                del _local_lockout_cache[key]
    return 0
async def clear_failed_attempts(key: str) -> None:
    """Reset lockout tracking for *key* after a successful login."""
    try:
        client = await get_redis()
        await client.delete(f"lockout:{key}")
    except Exception as exc:
        # Redis being down is non-fatal here; the local cache is still purged.
        logger.warning("lockout_clear_redis_fallback", error=str(exc))
    # The in-memory fallback entry is removed unconditionally.
    _local_lockout_cache.pop(key, None)
"""Redis-backed rate limiter with in-memory fallback.
Uses a fixed-window counter pattern via Redis INCR + EXPIRE.
Falls back to an in-memory TTLCache when Redis is unavailable,
preserving identical behavior for dev/test environments.
"""
import time
from cachetools import TTLCache
from fastapi import HTTPException
from app.core.logging import get_logger
from app.core.redis import get_redis
logger = get_logger(__name__)
# ── In-memory fallback caches ──────────────────────────────────────────────
_local_rate_cache: TTLCache[str, int] = TTLCache(maxsize=10000, ttl=120)
_local_lockout_cache: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900)
async def check_rate_limit(key: str, limit: int, window_seconds: int) -> None:
"""Check and increment a sliding-window rate counter.
Args:
key: Unique identifier (e.g. ``"story:<user_id>"``).
limit: Maximum requests allowed within the window.
window_seconds: Window duration in seconds.
Raises:
HTTPException: 429 when the limit is exceeded.
"""
try:
redis = await get_redis()
# Fixed-window bucket: key + minute boundary
bucket = int(time.time() // window_seconds)
redis_key = f"ratelimit:{key}:{bucket}"
count = await redis.incr(redis_key)
if count == 1:
await redis.expire(redis_key, window_seconds)
if count > limit:
raise HTTPException(
status_code=429,
detail="Too many requests, please slow down.",
)
return
except HTTPException:
raise
except Exception as exc:
logger.warning("rate_limit_redis_fallback", error=str(exc))
# ── Fallback: in-memory counter ────────────────────────────────────────
count = _local_rate_cache.get(key, 0) + 1
_local_rate_cache[key] = count
if count > limit:
raise HTTPException(
status_code=429,
detail="Too many requests, please slow down.",
)
async def record_failed_attempt(
key: str,
max_attempts: int,
lockout_seconds: int,
) -> bool:
"""Record a failed login attempt and return whether the key is locked out.
Args:
key: Unique identifier (e.g. ``"admin_login:<ip>"``).
max_attempts: Number of failures before lockout.
lockout_seconds: Duration of lockout in seconds.
Returns:
``True`` if the key is now locked out, ``False`` otherwise.
"""
try:
redis = await get_redis()
redis_key = f"lockout:{key}"
count = await redis.incr(redis_key)
if count == 1:
await redis.expire(redis_key, lockout_seconds)
return count >= max_attempts
except Exception as exc:
logger.warning("lockout_redis_fallback", error=str(exc))
# ── Fallback ───────────────────────────────────────────────────────────
if key in _local_lockout_cache:
attempts, first_fail = _local_lockout_cache[key]
_local_lockout_cache[key] = (attempts + 1, first_fail)
return (attempts + 1) >= max_attempts
else:
_local_lockout_cache[key] = (1, time.time())
return 1 >= max_attempts
async def is_locked_out(key: str, max_attempts: int, lockout_seconds: int) -> int:
"""Check if a key is currently locked out.
Returns:
Remaining lockout seconds (> 0 means locked), 0 means not locked.
"""
try:
redis = await get_redis()
redis_key = f"lockout:{key}"
count = await redis.get(redis_key)
if count is not None and int(count) >= max_attempts:
ttl = await redis.ttl(redis_key)
return max(ttl, 0)
return 0
except Exception as exc:
logger.warning("lockout_check_redis_fallback", error=str(exc))
# ── Fallback ───────────────────────────────────────────────────────────
if key in _local_lockout_cache:
attempts, first_fail = _local_lockout_cache[key]
if attempts >= max_attempts:
remaining = int(lockout_seconds - (time.time() - first_fail))
if remaining > 0:
return remaining
else:
del _local_lockout_cache[key]
return 0
async def clear_failed_attempts(key: str) -> None:
"""Clear lockout state on successful login."""
try:
redis = await get_redis()
await redis.delete(f"lockout:{key}")
except Exception as exc:
logger.warning("lockout_clear_redis_fallback", error=str(exc))
# Always clear local cache too
_local_lockout_cache.pop(key, None)

View File

@@ -1,25 +1,25 @@
from datetime import datetime, timedelta, timezone
from jose import JWTError, jwt
from app.core.config import settings
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_DAYS = 7
def create_access_token(data: dict) -> str:
    """Build a signed JWT containing *data* plus an ``exp`` claim.

    The token expires ``ACCESS_TOKEN_EXPIRE_DAYS`` days from now (UTC) and
    is signed with the application secret using HS256.
    """
    expires_at = datetime.now(timezone.utc) + timedelta(days=ACCESS_TOKEN_EXPIRE_DAYS)
    payload = {**data, "exp": expires_at}
    return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
def decode_access_token(token: str) -> dict | None:
    """Decode and verify a JWT; return its claims, or ``None`` if invalid/expired."""
    try:
        return jwt.decode(token, settings.secret_key, algorithms=[ALGORITHM])
    except JWTError:
        return None
from datetime import datetime, timedelta, timezone
from jose import JWTError, jwt
from app.core.config import settings
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_DAYS = 7
def create_access_token(data: dict) -> str:
    """Create a signed JWT from *data* with an ``exp`` claim set
    ``ACCESS_TOKEN_EXPIRE_DAYS`` days from now (UTC)."""
    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + timedelta(days=ACCESS_TOKEN_EXPIRE_DAYS)
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, settings.secret_key, algorithm=ALGORITHM)
def decode_access_token(token: str) -> dict | None:
    """Decode and verify a JWT; return its claims, or ``None`` when the
    token is invalid or expired."""
    try:
        payload = jwt.decode(token, settings.secret_key, algorithms=[ALGORITHM])
        return payload
    except JWTError:
        return None

View File

@@ -1 +1 @@
"""故事相关 Schema 模块。"""
"""故事相关 Schema 模块。"""

View File

@@ -1,21 +1,21 @@
"""适配器模块 - 供应商平台化架构核心。"""
from app.services.adapters.base import AdapterConfig, BaseAdapter
# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.registry import AdapterRegistry
# Storybook adapters
from app.services.adapters.storybook import primary as _storybook_primary # noqa: F401
from app.services.adapters.text import gemini as _text_gemini_adapter # noqa: F401
# 导入所有适配器以触发注册
# Text adapters
from app.services.adapters.text import openai as _text_openai_adapter # noqa: F401
# TTS adapters
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
__all__ = ["AdapterConfig", "BaseAdapter", "AdapterRegistry"]
"""适配器模块 - 供应商平台化架构核心。"""
from app.services.adapters.base import AdapterConfig, BaseAdapter
# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.registry import AdapterRegistry
# Storybook adapters
from app.services.adapters.storybook import primary as _storybook_primary # noqa: F401
from app.services.adapters.text import gemini as _text_gemini_adapter # noqa: F401
# 导入所有适配器以触发注册
# Text adapters
from app.services.adapters.text import openai as _text_openai_adapter # noqa: F401
# TTS adapters
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
__all__ = ["AdapterConfig", "BaseAdapter", "AdapterRegistry"]

View File

@@ -1,46 +1,46 @@
"""适配器基类定义。"""
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from pydantic import BaseModel
T = TypeVar("T")
class AdapterConfig(BaseModel):
    """Base configuration shared by all provider adapters."""
    # Provider API credential (required).
    api_key: str
    # Override for the provider base URL; None selects the adapter default.
    api_base: str | None = None
    # Model identifier; None selects the adapter default.
    model: str | None = None
    # Request timeout in milliseconds.
    timeout_ms: int = 60000
    # Maximum retry attempts for a single call.
    max_retries: int = 3
    # Free-form provider-specific settings.
    # NOTE(review): relies on pydantic copying mutable field defaults per
    # instance — confirm, or switch to Field(default_factory=dict).
    extra_config: dict = {}
class BaseAdapter(ABC, Generic[T]):
    """Abstract base class every provider adapter must inherit.

    ``T`` is the result type produced by :meth:`execute`.
    """
    # Subclasses must define both class attributes below.
    adapter_type: str  # Category: text / image / tts
    adapter_name: str  # Registry name: text_primary / image_primary / tts_primary
    def __init__(self, config: AdapterConfig):
        self.config = config
    @abstractmethod
    async def execute(self, **kwargs) -> T:
        """Run the adapter and return the provider result."""
        pass
    @abstractmethod
    async def health_check(self) -> bool:
        """Return whether the provider is currently usable."""
        pass
    @property
    @abstractmethod
    def estimated_cost(self) -> float:
        """Estimated cost of a single call in USD."""
        pass
"""适配器基类定义。"""
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from pydantic import BaseModel
T = TypeVar("T")
class AdapterConfig(BaseModel):
"""适配器配置基类。"""
api_key: str
api_base: str | None = None
model: str | None = None
timeout_ms: int = 60000
max_retries: int = 3
extra_config: dict = {}
class BaseAdapter(ABC, Generic[T]):
"""适配器基类,所有供应商适配器必须继承此类。"""
# 子类必须定义
adapter_type: str # text / image / tts
adapter_name: str # text_primary / image_primary / tts_primary
def __init__(self, config: AdapterConfig):
self.config = config
@abstractmethod
async def execute(self, **kwargs) -> T:
"""执行适配器逻辑,返回结果。"""
pass
@abstractmethod
async def health_check(self) -> bool:
"""健康检查,返回是否可用。"""
pass
@property
@abstractmethod
def estimated_cost(self) -> float:
"""预估单次调用成本 (USD)。"""
pass

View File

@@ -1,3 +1,3 @@
"""图像生成适配器。"""# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.image import antigravity as _image_antigravity_adapter # noqa: F401
"""图像生成适配器。"""# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.image import antigravity as _image_antigravity_adapter # noqa: F401

View File

@@ -1,214 +1,214 @@
"""Antigravity 图像生成适配器。
使用 OpenAI 兼容 API 生成图像。
支持 gemini-3-pro-image 等模型。
"""
import base64
import re
import time
from typing import Any

from openai import AsyncOpenAI
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# Default configuration (overridable per-call or via AdapterConfig).
DEFAULT_API_BASE = "http://127.0.0.1:8045/v1"
DEFAULT_MODEL = "gemini-3-pro-image"
DEFAULT_SIZE = "1024x1024"
# Supported pixel sizes mapped to their aspect ratios.
SUPPORTED_SIZES = {
    "1024x1024": "1:1",
    "1280x720": "16:9",
    "720x1280": "9:16",
    "1216x896": "4:3",
}
@AdapterRegistry.register("image", "antigravity")
class AntigravityImageAdapter(BaseAdapter[str]):
    """Antigravity image-generation adapter (OpenAI-compatible API).

    Characteristics:
    - Uses the OpenAI-compatible ``chat.completions`` endpoint.
    - Image size is passed through ``extra_body.size``.
    - Supports models such as ``gemini-3-pro-image``.
    - Returns an image URL or a base64 data URI.
    """

    adapter_type = "image"
    adapter_name = "antigravity"

    def __init__(self, config: AdapterConfig):
        super().__init__(config)
        self.api_base = config.api_base or DEFAULT_API_BASE
        self.client = AsyncOpenAI(
            base_url=self.api_base,
            api_key=config.api_key,
            timeout=config.timeout_ms / 1000,
        )

    async def execute(
        self,
        prompt: str,
        model: str | None = None,
        size: str | None = None,
        num_images: int = 1,
        **kwargs,
    ) -> str | list[str]:
        """Generate an image for *prompt*; return its URL or base64 data URI.

        Args:
            prompt: Image description prompt.
            model: Model name (gemini-3-pro-image / gemini-3-pro-image-16-9 ...).
            size: Image size (1024x1024, 1280x720, 720x1280, 1216x896).
            num_images: Number of images; currently ignored — exactly one
                image is produced per call.

        Returns:
            Image URL or base64 data-URI string.
        """
        # Precedence: explicit argument > adapter config > module default.
        model = model or self.config.model or DEFAULT_MODEL
        cfg = self.config.extra_config or {}
        size = size or cfg.get("size") or DEFAULT_SIZE
        start_time = time.time()
        logger.info(
            "antigravity_generate_start",
            prompt_length=len(prompt),
            model=model,
            size=size,
        )
        image_url = await self._generate_image(prompt, model, size)
        elapsed = time.time() - start_time
        logger.info(
            "antigravity_generate_success",
            elapsed_seconds=round(elapsed, 2),
            model=model,
        )
        return image_url

    async def health_check(self) -> bool:
        """Probe the API with a minimal completion; ``True`` if reachable."""
        try:
            await self.client.chat.completions.create(
                model=self.config.model or DEFAULT_MODEL,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=1,
            )
            return True
        except Exception as e:
            logger.warning("antigravity_health_check_failed", error=str(e))
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost per image in USD (Gemini-backed, roughly $0.02)."""
        return 0.02

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((Exception,)),
        reraise=True,
    )
    async def _generate_image(
        self,
        prompt: str,
        model: str,
        size: str,
    ) -> str:
        """Call the Antigravity API (with retries) and return the image.

        Returns:
            Image URL or base64 data URI.

        Raises:
            ValueError: If the response is empty or cannot be parsed as an image.
        """
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                extra_body={"size": size},
            )
            content = response.choices[0].message.content
            if not content:
                raise ValueError("Antigravity 未返回内容")
            # The response may be a bare URL, base64 payload, or markdown image.
            image_url = self._extract_image_url(content)
            if image_url:
                return image_url
            raise ValueError(f"Antigravity 响应无法解析为图片: {content[:200]}")
        except Exception as e:
            logger.error(
                "antigravity_generate_error",
                error=str(e),
                model=model,
            )
            raise

    def _extract_image_url(self, content: str) -> str | None:
        """Extract an image URL from the response text.

        Accepted formats:
        - Bare URL: ``https://...``
        - Markdown image: ``![...](https://...)``
        - Base64 data URI: ``data:image/...``
        - Raw base64 payload (assumed to be PNG)
        """
        content = content.strip()
        # 1. Already a data URI.
        if content.startswith("data:image/"):
            return content
        # 2. Bare URL; may span multiple lines — keep only the first.
        if content.startswith(("http://", "https://")):
            return content.split("\n")[0].strip()
        # 3. Markdown image syntax ![...](url).
        md_match = re.search(r"!\[.*?\]\((https?://[^\)]+)\)", content)
        if md_match:
            return md_match.group(1)
        # 4. Looks like raw base64 image data; no MIME info, so assume PNG.
        if self._looks_like_base64(content):
            return f"data:image/png;base64,{content}"
        return None

    def _looks_like_base64(self, s: str) -> bool:
        """Heuristic: long string containing only base64 alphabet characters."""
        # Short strings are far more likely plain text than image data.
        if len(s) < 100:
            return False
        return bool(re.match(r"^[A-Za-z0-9+/=]+$", s.replace("\n", "")))
"""Antigravity 图像生成适配器。
使用 OpenAI 兼容 API 生成图像。
支持 gemini-3-pro-image 等模型。
"""
import base64
import time
from typing import Any
from openai import AsyncOpenAI
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认配置
DEFAULT_API_BASE = "http://127.0.0.1:8045/v1"
DEFAULT_MODEL = "gemini-3-pro-image"
DEFAULT_SIZE = "1024x1024"
# 支持的尺寸映射
SUPPORTED_SIZES = {
"1024x1024": "1:1",
"1280x720": "16:9",
"720x1280": "9:16",
"1216x896": "4:3",
}
@AdapterRegistry.register("image", "antigravity")
class AntigravityImageAdapter(BaseAdapter[str]):
    """Antigravity image generation adapter (OpenAI-compatible API).

    Characteristics:
    - uses the OpenAI-compatible chat.completions endpoint
    - passes the image size via ``extra_body.size``
    - supports gemini-3-pro-image and similar models
    - returns either an image URL or base64 data
    """
    adapter_type = "image"
    adapter_name = "antigravity"
    def __init__(self, config: AdapterConfig):
        super().__init__(config)
        self.api_base = config.api_base or DEFAULT_API_BASE
        self.client = AsyncOpenAI(
            base_url=self.api_base,
            api_key=config.api_key,
            timeout=config.timeout_ms / 1000,
        )
    async def execute(
        self,
        prompt: str,
        model: str | None = None,
        size: str | None = None,
        num_images: int = 1,
        **kwargs,
    ) -> str | list[str]:
        """Generate an image for *prompt*; return a URL or base64 data URI.

        Args:
            prompt: Image description prompt.
            model: Model name (gemini-3-pro-image / gemini-3-pro-image-16-9 ...).
            size: Image size (1024x1024, 1280x720, 720x1280, 1216x896).
            num_images: Number of images (currently only 1 is supported).

        Returns:
            Image URL or base64 string.
        """
        # Parameter precedence: explicit argument > adapter config > default.
        model = model or self.config.model or DEFAULT_MODEL
        cfg = self.config.extra_config or {}
        size = size or cfg.get("size") or DEFAULT_SIZE
        start_time = time.time()
        logger.info(
            "antigravity_generate_start",
            prompt_length=len(prompt),
            model=model,
            size=size,
        )
        # Call the API (retried internally by _generate_image).
        image_url = await self._generate_image(prompt, model, size)
        elapsed = time.time() - start_time
        logger.info(
            "antigravity_generate_success",
            elapsed_seconds=round(elapsed, 2),
            model=model,
        )
        return image_url
    async def health_check(self) -> bool:
        """Return True when the Antigravity API answers a minimal request."""
        try:
            # Cheap connectivity probe: one token of chat completion.
            response = await self.client.chat.completions.create(
                model=self.config.model or DEFAULT_MODEL,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=1,
            )
            return True
        except Exception as e:
            logger.warning("antigravity_health_check_failed", error=str(e))
            return False
    @property
    def estimated_cost(self) -> float:
        """Estimated cost per image in USD.

        Antigravity proxies Gemini models; roughly $0.02 per image.
        """
        return 0.02
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((Exception,)),
        reraise=True,
    )
    async def _generate_image(
        self,
        prompt: str,
        model: str,
        size: str,
    ) -> str:
        """Call the Antigravity API to generate an image.

        Returns:
            Image URL or base64 data URI.
        """
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                extra_body={"size": size},
            )
            # Parse the response text into an image reference.
            content = response.choices[0].message.content
            if not content:
                raise ValueError("Antigravity 未返回内容")
            # The response may be a bare URL, base64, or a markdown image.
            image_url = self._extract_image_url(content)
            if image_url:
                return image_url
            raise ValueError(f"Antigravity 响应无法解析为图片: {content[:200]}")
        except Exception as e:
            logger.error(
                "antigravity_generate_error",
                error=str(e),
                model=model,
            )
            raise
    def _extract_image_url(self, content: str) -> str | None:
        """Extract an image URL from the response content.

        Supported formats:
        - bare URL: https://...
        - Markdown: ![...](https://...)
        - Base64 data URI: data:image/...
        - raw base64 string
        """
        content = content.strip()
        # 1. Already a data URI?
        if content.startswith("data:image/"):
            return content
        # 2. Bare URL?
        if content.startswith("http://") or content.startswith("https://"):
            # May span multiple lines; only the first line is the URL.
            return content.split("\n")[0].strip()
        # 3. Markdown image syntax ![...](url)
        import re
        md_match = re.search(r"!\[.*?\]\((https?://[^\)]+)\)", content)
        if md_match:
            return md_match.group(1)
        # 4. Looks like raw base64 image bytes?
        if self._looks_like_base64(content):
            # Assume PNG when only raw bytes are returned.
            return f"data:image/png;base64,{content}"
        return None
    def _looks_like_base64(self, s: str) -> bool:
        """Heuristically check whether *s* looks like base64-encoded data."""
        # Base64 uses only A-Z, a-z, 0-9, +, /, = and payloads are long.
        if len(s) < 100:
            return False
        import re
        return bool(re.match(r"^[A-Za-z0-9+/=]+$", s.replace("\n", "")))

View File

@@ -1,252 +1,252 @@
"""CQTAI nano 图像生成适配器。
支持异步生成 + 轮询获取结果。
API 文档: https://api.cqtai.com
"""
import asyncio
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# Default configuration for the CQTAI API.
DEFAULT_API_BASE = "https://api.cqtai.com"
DEFAULT_MODEL = "nano-banana"
DEFAULT_RESOLUTION = "2K"
DEFAULT_ASPECT_RATIO = "1:1"
POLL_INTERVAL_SECONDS = 2
MAX_POLL_ATTEMPTS = 60  # poll for at most ~2 minutes
@AdapterRegistry.register("image", "cqtai")
class CQTAIImageAdapter(BaseAdapter[str]):
    """CQTAI nano image generation adapter; returns image URLs.

    Characteristics:
    - asynchronous submit + polling for the result
    - supports nano-banana (standard) and nano-banana-pro (high quality)
    - supports several resolutions and aspect ratios
    - supports image-to-image via ``filesUrl``
    """

    adapter_type = "image"
    adapter_name = "cqtai"

    def __init__(self, config: AdapterConfig):
        super().__init__(config)
        self.api_base = config.api_base or DEFAULT_API_BASE

    async def execute(
        self,
        prompt: str,
        model: str | None = None,
        resolution: str | None = None,
        aspect_ratio: str | None = None,
        num_images: int = 1,
        files_url: list[str] | None = None,
        **kwargs,
    ) -> str | list[str]:
        """Generate image(s) for *prompt* and return the URL(s).

        Args:
            prompt: Image description prompt.
            model: Model name (nano-banana / nano-banana-pro).
            resolution: 1K / 2K / 4K.
            aspect_ratio: Aspect ratio (1:1, 16:9, 9:16, 4:3, 3:4, ...).
            num_images: Number of images, clamped to 1-4.
            files_url: Input image URLs for image-to-image.

        Returns:
            A single URL (str) for one image, otherwise a list of URLs.
        """
        # Parameter precedence:
        # 1. explicit argument
        # 2. adapter extra_config defaults
        # 3. system defaults
        model = model or self.config.model or DEFAULT_MODEL
        cfg = self.config.extra_config or {}
        resolution = resolution or cfg.get("resolution") or DEFAULT_RESOLUTION
        aspect_ratio = aspect_ratio or cfg.get("aspect_ratio") or DEFAULT_ASPECT_RATIO
        num_images = min(max(num_images, 1), 4)  # clamp to the supported 1-4 range
        start_time = time.time()
        logger.info(
            "cqtai_generate_start",
            prompt_length=len(prompt),
            model=model,
            resolution=resolution,
            aspect_ratio=aspect_ratio,
            num_images=num_images,
        )
        # 1. Submit the generation task.
        task_id = await self._submit_task(
            prompt=prompt,
            model=model,
            resolution=resolution,
            aspect_ratio=aspect_ratio,
            num_images=num_images,
            files_url=files_url or [],
        )
        logger.info("cqtai_task_submitted", task_id=task_id)
        # 2. Poll until the result is ready.
        result = await self._poll_result(task_id)
        elapsed = time.time() - start_time
        logger.info(
            "cqtai_generate_success",
            task_id=task_id,
            elapsed_seconds=round(elapsed, 2),
            image_count=len(result) if isinstance(result, list) else 1,
        )
        # One image is returned as a plain string, several as a list.
        if num_images == 1 and isinstance(result, list) and len(result) == 1:
            return result[0]
        return result

    async def health_check(self) -> bool:
        """Return True when the CQTAI API endpoint is reachable."""
        try:
            async with httpx.AsyncClient(timeout=10) as client:
                # Simple connectivity probe.
                response = await client.get(
                    f"{self.api_base}/api/cqt/info/nano",
                    params={"id": "health_check_test"},
                    headers={"Authorization": self.config.api_key},
                )
                # Even an error status proves the service is reachable.
                return response.status_code in (200, 400, 401, 403, 404)
        except Exception:
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost per image in USD.

        nano-banana: ¥0.1 ≈ $0.014; nano-banana-pro: ¥0.2 ≈ $0.028.
        """
        model = self.config.model or DEFAULT_MODEL
        if model == "nano-banana-pro":
            return 0.028
        return 0.014

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _submit_task(
        self,
        prompt: str,
        model: str,
        resolution: str,
        aspect_ratio: str,
        num_images: int,
        files_url: list[str],
    ) -> str:
        """Submit a generation task and return its task ID.

        Raises:
            ValueError: the API rejected the task or returned no ID.
        """
        timeout = self.config.timeout_ms / 1000
        payload = {
            "prompt": prompt,
            "numImages": num_images,
            "aspectRatio": aspect_ratio,
            "filesUrl": files_url,
        }
        # Optional parameters: omit them to let the server apply defaults.
        if model != DEFAULT_MODEL:
            payload["model"] = model
        if resolution != DEFAULT_RESOLUTION:
            payload["resolution"] = resolution
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                f"{self.api_base}/api/cqt/generator/nano",
                json=payload,
                headers={
                    "Authorization": self.config.api_key,
                    "Content-Type": "application/json",
                },
            )
            response.raise_for_status()
            data = response.json()
            if data.get("code") != 200:
                raise ValueError(f"CQTAI 任务提交失败: {data.get('msg', '未知错误')}")
            task_id = data.get("data")
            if not task_id:
                raise ValueError("CQTAI 未返回任务 ID")
            return task_id

    async def _poll_result(self, task_id: str) -> list[str]:
        """Poll until the task completes; return the image URL list.

        Raises:
            ValueError: the API reports a failure or returns no image URL.
            TimeoutError: the task did not finish within the polling budget.
        """
        timeout = self.config.timeout_ms / 1000
        # Fix: reuse one HTTP client (and its connection pool) across all
        # polling attempts instead of opening a fresh connection per attempt
        # (previously up to MAX_POLL_ATTEMPTS clients were created).
        async with httpx.AsyncClient(timeout=timeout) as client:
            for attempt in range(MAX_POLL_ATTEMPTS):
                response = await client.get(
                    f"{self.api_base}/api/cqt/info/nano",
                    params={"id": task_id},
                    headers={"Authorization": self.config.api_key},
                )
                response.raise_for_status()
                data = response.json()
                if data.get("code") != 200:
                    raise ValueError(f"CQTAI 查询失败: {data.get('msg', '未知错误')}")
                result_data = data.get("data", {})
                status = result_data.get("status")
                if status == "completed":
                    images = result_data.get("images", [])
                    if not images:
                        # Tolerate alternative response field names.
                        image_url = result_data.get("imageUrl") or result_data.get("url")
                        if image_url:
                            images = [image_url]
                    if not images:
                        raise ValueError("CQTAI 未返回图片 URL")
                    return images
                elif status == "failed":
                    error_msg = result_data.get("error", "生成失败")
                    raise ValueError(f"CQTAI 图像生成失败: {error_msg}")
                # Still pending: wait and retry.
                logger.debug(
                    "cqtai_poll_waiting",
                    task_id=task_id,
                    attempt=attempt + 1,
                    status=status,
                )
                await asyncio.sleep(POLL_INTERVAL_SECONDS)
        raise TimeoutError(f"CQTAI 任务超时: {task_id}")
"""CQTAI nano 图像生成适配器。
支持异步生成 + 轮询获取结果。
API 文档: https://api.cqtai.com
"""
import asyncio
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认配置
DEFAULT_API_BASE = "https://api.cqtai.com"
DEFAULT_MODEL = "nano-banana"
DEFAULT_RESOLUTION = "2K"
DEFAULT_ASPECT_RATIO = "1:1"
POLL_INTERVAL_SECONDS = 2
MAX_POLL_ATTEMPTS = 60 # 最多轮询 2 分钟
@AdapterRegistry.register("image", "cqtai")
class CQTAIImageAdapter(BaseAdapter[str]):
"""CQTAI nano 图像生成适配器,返回图片 URL。
特点:
- 异步生成 + 轮询获取结果
- 支持 nano-banana (标准) 和 nano-banana-pro (高画质)
- 支持多种分辨率和画面比例
- 支持图生图 (filesUrl)
"""
adapter_type = "image"
adapter_name = "cqtai"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or DEFAULT_API_BASE
async def execute(
self,
prompt: str,
model: str | None = None,
resolution: str | None = None,
aspect_ratio: str | None = None,
num_images: int = 1,
files_url: list[str] | None = None,
**kwargs,
) -> str | list[str]:
"""根据提示词生成图片,返回 URL 或 URL 列表。
Args:
prompt: 图片描述提示词
model: 模型名称 (nano-banana / nano-banana-pro)
resolution: 分辨率 (1K / 2K / 4K)
aspect_ratio: 画面比例 (1:1, 16:9, 9:16, 4:3, 3:4 等)
num_images: 生成图片数量 (1-4)
files_url: 输入图片 URL 列表 (图生图)
Returns:
单张图片返回 str多张返回 list[str]
"""
# 1. 优先使用传入参数
# 2. 其次使用 Adapter 配置里的 default (extra_config)
# 3. 最后使用系统默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
resolution = resolution or cfg.get("resolution") or DEFAULT_RESOLUTION
aspect_ratio = aspect_ratio or cfg.get("aspect_ratio") or DEFAULT_ASPECT_RATIO
num_images = min(max(num_images, 1), 4) # 限制 1-4
start_time = time.time()
logger.info(
"cqtai_generate_start",
prompt_length=len(prompt),
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
num_images=num_images,
)
# 1. 提交生成任务
task_id = await self._submit_task(
prompt=prompt,
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
num_images=num_images,
files_url=files_url or [],
)
logger.info("cqtai_task_submitted", task_id=task_id)
# 2. 轮询获取结果
result = await self._poll_result(task_id)
elapsed = time.time() - start_time
logger.info(
"cqtai_generate_success",
task_id=task_id,
elapsed_seconds=round(elapsed, 2),
image_count=len(result) if isinstance(result, list) else 1,
)
# 单张图片返回字符串,多张返回列表
if num_images == 1 and isinstance(result, list) and len(result) == 1:
return result[0]
return result
async def health_check(self) -> bool:
"""检查 CQTAI API 是否可用。"""
try:
async with httpx.AsyncClient(timeout=10) as client:
# 简单的连通性测试
response = await client.get(
f"{self.api_base}/api/cqt/info/nano",
params={"id": "health_check_test"},
headers={"Authorization": self.config.api_key},
)
# 即使返回错误也说明服务可达
return response.status_code in (200, 400, 401, 403, 404)
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估每张图片成本 (USD)。
nano-banana: ¥0.1 ≈ $0.014
nano-banana-pro: ¥0.2 ≈ $0.028
"""
model = self.config.model or DEFAULT_MODEL
if model == "nano-banana-pro":
return 0.028
return 0.014
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _submit_task(
self,
prompt: str,
model: str,
resolution: str,
aspect_ratio: str,
num_images: int,
files_url: list[str],
) -> str:
"""提交图像生成任务,返回任务 ID。"""
timeout = self.config.timeout_ms / 1000
payload = {
"prompt": prompt,
"numImages": num_images,
"aspectRatio": aspect_ratio,
"filesUrl": files_url,
}
# 可选参数,不传则使用默认值
if model != DEFAULT_MODEL:
payload["model"] = model
if resolution != DEFAULT_RESOLUTION:
payload["resolution"] = resolution
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
f"{self.api_base}/api/cqt/generator/nano",
json=payload,
headers={
"Authorization": self.config.api_key,
"Content-Type": "application/json",
},
)
response.raise_for_status()
data = response.json()
if data.get("code") != 200:
raise ValueError(f"CQTAI 任务提交失败: {data.get('msg', '未知错误')}")
task_id = data.get("data")
if not task_id:
raise ValueError("CQTAI 未返回任务 ID")
return task_id
async def _poll_result(self, task_id: str) -> list[str]:
"""轮询获取生成结果。
Returns:
图片 URL 列表
"""
timeout = self.config.timeout_ms / 1000
for attempt in range(MAX_POLL_ATTEMPTS):
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(
f"{self.api_base}/api/cqt/info/nano",
params={"id": task_id},
headers={"Authorization": self.config.api_key},
)
response.raise_for_status()
data = response.json()
if data.get("code") != 200:
raise ValueError(f"CQTAI 查询失败: {data.get('msg', '未知错误')}")
result_data = data.get("data", {})
status = result_data.get("status")
if status == "completed":
# 提取图片 URL
images = result_data.get("images", [])
if not images:
# 兼容不同返回格式
image_url = result_data.get("imageUrl") or result_data.get("url")
if image_url:
images = [image_url]
if not images:
raise ValueError("CQTAI 未返回图片 URL")
return images
elif status == "failed":
error_msg = result_data.get("error", "生成失败")
raise ValueError(f"CQTAI 图像生成失败: {error_msg}")
# 继续等待
logger.debug(
"cqtai_poll_waiting",
task_id=task_id,
attempt=attempt + 1,
status=status,
)
await asyncio.sleep(POLL_INTERVAL_SECONDS)
raise TimeoutError(f"CQTAI 任务超时: {task_id}")

View File

@@ -1,73 +1,73 @@
"""适配器注册表 - 支持动态注册和工厂创建。"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.services.adapters.base import AdapterConfig, BaseAdapter
class AdapterRegistry:
"""适配器注册表,管理所有已注册的适配器类。"""
_adapters: dict[str, type["BaseAdapter"]] = {}
@classmethod
def register(cls, adapter_type: str, adapter_name: str):
"""装饰器:注册适配器类。
用法:
"""适配器注册表 - 支持动态注册和工厂创建。"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.services.adapters.base import AdapterConfig, BaseAdapter
class AdapterRegistry:
"""适配器注册表,管理所有已注册的适配器类。"""
_adapters: dict[str, type["BaseAdapter"]] = {}
@classmethod
def register(cls, adapter_type: str, adapter_name: str):
"""装饰器:注册适配器类。
用法:
@AdapterRegistry.register("text", "text_primary")
class TextPrimaryAdapter(BaseAdapter[StoryOutput]):
...
"""
def decorator(adapter_class: type["BaseAdapter"]):
key = f"{adapter_type}:{adapter_name}"
cls._adapters[key] = adapter_class
# 自动设置类属性
adapter_class.adapter_type = adapter_type
adapter_class.adapter_name = adapter_name
return adapter_class
return decorator
@classmethod
def get(cls, adapter_type: str, adapter_name: str) -> type["BaseAdapter"] | None:
"""获取已注册的适配器类。"""
key = f"{adapter_type}:{adapter_name}"
return cls._adapters.get(key)
@classmethod
def list_adapters(cls, adapter_type: str | None = None) -> list[str]:
"""列出所有已注册的适配器。
Args:
adapter_type: 可选,筛选特定类型 (text/image/tts)
Returns:
适配器键列表,格式为 "type:name"
"""
if adapter_type:
return [k for k in cls._adapters if k.startswith(f"{adapter_type}:")]
return list(cls._adapters.keys())
@classmethod
def create(
cls,
adapter_type: str,
adapter_name: str,
config: "AdapterConfig",
) -> "BaseAdapter":
"""工厂方法:创建适配器实例。
Raises:
ValueError: 适配器未注册
"""
adapter_class = cls.get(adapter_type, adapter_name)
if not adapter_class:
available = cls.list_adapters(adapter_type)
raise ValueError(
f"适配器 '{adapter_type}:{adapter_name}' 未注册。"
f"可用: {available}"
)
return adapter_class(config)
...
"""
def decorator(adapter_class: type["BaseAdapter"]):
key = f"{adapter_type}:{adapter_name}"
cls._adapters[key] = adapter_class
# 自动设置类属性
adapter_class.adapter_type = adapter_type
adapter_class.adapter_name = adapter_name
return adapter_class
return decorator
@classmethod
def get(cls, adapter_type: str, adapter_name: str) -> type["BaseAdapter"] | None:
"""获取已注册的适配器类。"""
key = f"{adapter_type}:{adapter_name}"
return cls._adapters.get(key)
@classmethod
def list_adapters(cls, adapter_type: str | None = None) -> list[str]:
"""列出所有已注册的适配器。
Args:
adapter_type: 可选,筛选特定类型 (text/image/tts)
Returns:
适配器键列表,格式为 "type:name"
"""
if adapter_type:
return [k for k in cls._adapters if k.startswith(f"{adapter_type}:")]
return list(cls._adapters.keys())
@classmethod
def create(
cls,
adapter_type: str,
adapter_name: str,
config: "AdapterConfig",
) -> "BaseAdapter":
"""工厂方法:创建适配器实例。
Raises:
ValueError: 适配器未注册
"""
adapter_class = cls.get(adapter_type, adapter_name)
if not adapter_class:
available = cls.list_adapters(adapter_type)
raise ValueError(
f"适配器 '{adapter_type}:{adapter_name}' 未注册。"
f"可用: {available}"
)
return adapter_class(config)

View File

@@ -1 +1 @@
"""Storybook 适配器模块。"""
"""Storybook 适配器模块。"""

View File

@@ -1,195 +1,195 @@
"""Storybook 适配器 - 生成可翻页的分页故事书。"""
import json
import random
import re
import time
from dataclasses import dataclass, field
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_STORYBOOK,
USER_PROMPT_STORYBOOK,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@dataclass
class StorybookPage:
    """A single page of a storybook."""
    page_number: int  # 1-based page index
    text: str  # narrative text shown on the page
    image_prompt: str  # prompt used to generate this page's illustration
    image_url: str | None = None  # filled in after image generation
@dataclass
class Storybook:
    """Generated storybook output."""
    title: str  # book title
    main_character: str  # description of the protagonist
    art_style: str  # illustration style shared by all pages
    pages: list[StorybookPage] = field(default_factory=list)  # ordered pages
    cover_prompt: str = ""  # prompt for the cover illustration
    cover_url: str | None = None  # filled in after cover generation
@AdapterRegistry.register("storybook", "storybook_primary")
class StorybookPrimaryAdapter(BaseAdapter[Storybook]):
    """Default storybook generation adapter.

    Produces the paginated storybook structure (per-page text plus an image
    prompt for every page); the actual images must be generated separately
    through an image adapter.
    """

    adapter_type = "storybook"
    adapter_name = "storybook_primary"

    async def execute(
        self,
        keywords: str,
        page_count: int = 6,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> Storybook:
        """Generate a paginated storybook.

        Args:
            keywords: Story keywords.
            page_count: Number of pages, clamped to 4-12.
            education_theme: Educational theme; defaults to "成长" (growth).
            memory_context: Optional memory context injected into the prompt.

        Returns:
            A Storybook with title, page list and cover prompt.

        Raises:
            ValueError: the service returned no content or unparseable JSON.
        """
        start_time = time.time()
        page_count = max(4, min(page_count, 12))  # clamp to the supported 4-12 range
        logger.info(
            "storybook_generate_start",
            keywords=keywords,
            page_count=page_count,
            has_memory=bool(memory_context),
        )
        theme = education_theme or "成长"
        random_element = random.choice(RANDOM_ELEMENTS)
        prompt = USER_PROMPT_STORYBOOK.format(
            keywords=keywords,
            education_theme=theme,
            random_element=random_element,
            page_count=page_count,
            memory_context=memory_context or "",
        )
        payload = {
            "system_instruction": {"parts": [{"text": SYSTEM_INSTRUCTION_STORYBOOK}]},
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "responseMimeType": "application/json",
                "temperature": 0.95,
                "topP": 0.9,
            },
        }
        result = await self._call_api(payload)
        candidates = result.get("candidates") or []
        if not candidates:
            raise ValueError("Storybook 服务未返回内容")
        parts = candidates[0].get("content", {}).get("parts") or []
        if not parts or "text" not in parts[0]:
            raise ValueError("Storybook 服务响应缺少文本")
        response_text = parts[0]["text"]
        clean_json = response_text
        if response_text.startswith("```json"):
            # Strip the markdown code fence some models wrap JSON in.
            clean_json = re.sub(r"^```json\n|```$", "", response_text)
        try:
            parsed = json.loads(clean_json)
        except json.JSONDecodeError as exc:
            # Fix: chain the original decode error for easier debugging (PEP 3134).
            raise ValueError(f"Storybook JSON 解析失败: {exc}") from exc
        # Build the Storybook object from the parsed payload.
        pages = [
            StorybookPage(
                page_number=p.get("page_number", i + 1),
                text=p.get("text", ""),
                image_prompt=p.get("image_prompt", ""),
            )
            for i, p in enumerate(parsed.get("pages", []))
        ]
        storybook = Storybook(
            title=parsed.get("title", "未命名故事"),
            main_character=parsed.get("main_character", ""),
            art_style=parsed.get("art_style", ""),
            pages=pages,
            cover_prompt=parsed.get("cover_prompt", ""),
        )
        elapsed = time.time() - start_time
        logger.info(
            "storybook_generate_success",
            elapsed_seconds=round(elapsed, 2),
            title=storybook.title,
            page_count=len(pages),
        )
        return storybook

    async def health_check(self) -> bool:
        """Probe the API with a minimal request; True if it responds."""
        try:
            payload = {
                "contents": [{"parts": [{"text": "Hi"}]}],
                "generationConfig": {"maxOutputTokens": 10},
            }
            await self._call_api(payload)
            return True
        except Exception:
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost in USD (text generation only, no images)."""
        return 0.002  # slightly above a plain story: the output is longer

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, payload: dict) -> dict:
        """POST *payload* to the generateContent endpoint, with retries."""
        model = self.config.model or "gemini-2.0-flash"
        url = f"{TEXT_API_BASE}/{model}:generateContent?key={self.config.api_key}"
        timeout = self.config.timeout_ms / 1000
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            return response.json()
"""Storybook 适配器 - 生成可翻页的分页故事书。"""
import json
import random
import re
import time
from dataclasses import dataclass, field
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_STORYBOOK,
USER_PROMPT_STORYBOOK,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@dataclass
class StorybookPage:
"""故事书单页。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
@dataclass
class Storybook:
"""故事书输出。"""
title: str
main_character: str
art_style: str
pages: list[StorybookPage] = field(default_factory=list)
cover_prompt: str = ""
cover_url: str | None = None
@AdapterRegistry.register("storybook", "storybook_primary")
class StorybookPrimaryAdapter(BaseAdapter[Storybook]):
"""Storybook 生成适配器(默认)。
生成分页故事书结构,包含每页文字和图像提示词。
图像生成需要单独调用 image adapter。
"""
adapter_type = "storybook"
adapter_name = "storybook_primary"
async def execute(
self,
keywords: str,
page_count: int = 6,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> Storybook:
"""生成分页故事书。
Args:
keywords: 故事关键词
page_count: 页数 (4-12)
education_theme: 教育主题
memory_context: 记忆上下文
Returns:
Storybook 对象,包含标题、页面列表和封面提示词
"""
start_time = time.time()
page_count = max(4, min(page_count, 12)) # 限制 4-12 页
logger.info(
"storybook_generate_start",
keywords=keywords,
page_count=page_count,
has_memory=bool(memory_context),
)
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
prompt = USER_PROMPT_STORYBOOK.format(
keywords=keywords,
education_theme=theme,
random_element=random_element,
page_count=page_count,
memory_context=memory_context or "",
)
payload = {
"system_instruction": {"parts": [{"text": SYSTEM_INSTRUCTION_STORYBOOK}]},
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.95,
"topP": 0.9,
},
}
result = await self._call_api(payload)
candidates = result.get("candidates") or []
if not candidates:
raise ValueError("Storybook 服务未返回内容")
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
raise ValueError("Storybook 服务响应缺少文本")
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"Storybook JSON 解析失败: {exc}")
# 构建 Storybook 对象
pages = [
StorybookPage(
page_number=p.get("page_number", i + 1),
text=p.get("text", ""),
image_prompt=p.get("image_prompt", ""),
)
for i, p in enumerate(parsed.get("pages", []))
]
storybook = Storybook(
title=parsed.get("title", "未命名故事"),
main_character=parsed.get("main_character", ""),
art_style=parsed.get("art_style", ""),
pages=pages,
cover_prompt=parsed.get("cover_prompt", ""),
)
elapsed = time.time() - start_time
logger.info(
"storybook_generate_success",
elapsed_seconds=round(elapsed, 2),
title=storybook.title,
page_count=len(pages),
)
return storybook
async def health_check(self) -> bool:
"""检查 API 是否可用。"""
try:
payload = {
"contents": [{"parts": [{"text": "Hi"}]}],
"generationConfig": {"maxOutputTokens": 10},
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估成本(仅文本生成,不含图像)。"""
return 0.002 # 比普通故事稍贵,因为输出更长
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, payload: dict) -> dict:
"""调用 API带重试机制。"""
model = self.config.model or "gemini-2.0-flash"
url = f"{TEXT_API_BASE}/{model}:generateContent?key={self.config.api_key}"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
return response.json()

View File

@@ -1 +1 @@
"""文本生成适配器。"""
"""文本生成适配器。"""

View File

@@ -1,164 +1,164 @@
"""文本生成适配器 (Google Gemini)。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@AdapterRegistry.register("text", "gemini")
class GeminiTextAdapter(BaseAdapter[StoryOutput]):
    """Google Gemini text generation adapter."""

    adapter_type = "text"
    adapter_name = "gemini"

    async def execute(
        self,
        input_type: Literal["keywords", "full_story"],
        data: str,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> StoryOutput:
        """Generate a story from keywords, or enhance a full story.

        Args:
            input_type: "keywords" to generate, "full_story" to enhance.
            data: The keywords or the full story text.
            education_theme: Educational theme; defaults to "成长" (growth).
            memory_context: Optional memory context injected into the prompt.

        Returns:
            A validated StoryOutput.

        Raises:
            ValueError: Gemini returned no content, invalid JSON, or a
                payload missing required fields.
        """
        start_time = time.time()
        logger.info("request_start", adapter="gemini", input_type=input_type, data_length=len(data))
        theme = education_theme or "成长"
        random_element = random.choice(RANDOM_ELEMENTS)
        if input_type == "keywords":
            system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
            prompt = USER_PROMPT_GENERATION.format(
                keywords=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )
        else:
            system_instruction = SYSTEM_INSTRUCTION_ENHANCER
            prompt = USER_PROMPT_ENHANCEMENT.format(
                full_story=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )
        # The Gemini API payload supports a top-level 'system_instruction'.
        payload = {
            "system_instruction": {"parts": [{"text": system_instruction}]},
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "responseMimeType": "application/json",
                "temperature": 0.95,
                "topP": 0.9,
            },
        }
        result = await self._call_api(payload)
        candidates = result.get("candidates") or []
        if not candidates:
            raise ValueError("Gemini 未返回内容")
        parts = candidates[0].get("content", {}).get("parts") or []
        if not parts or "text" not in parts[0]:
            raise ValueError("Gemini 响应缺少文本")
        response_text = parts[0]["text"]
        clean_json = response_text
        if response_text.startswith("```json"):
            # Strip the markdown code fence some models wrap JSON in.
            clean_json = re.sub(r"^```json\n|```$", "", response_text)
        try:
            parsed = json.loads(clean_json)
        except json.JSONDecodeError as exc:
            # Fix: chain the original decode error for easier debugging (PEP 3134).
            raise ValueError(f"Gemini 输出 JSON 解析失败: {exc}") from exc
        required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
        if any(field not in parsed for field in required_fields):
            raise ValueError("Gemini 输出缺少必要字段")
        elapsed = time.time() - start_time
        logger.info(
            "request_success",
            adapter="gemini",
            elapsed_seconds=round(elapsed, 2),
            title=parsed["title"],
        )
        return StoryOutput(
            mode=parsed["mode"],
            title=parsed["title"],
            story_text=parsed["story_text"],
            cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
        )

    async def health_check(self) -> bool:
        """Probe the Gemini API with a minimal request; True if it responds."""
        try:
            payload = {
                "contents": [{"parts": [{"text": "Hi"}]}],
                "generationConfig": {"maxOutputTokens": 10},
            }
            await self._call_api(payload)
            return True
        except Exception:
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost per request in USD."""
        return 0.001

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, payload: dict) -> dict:
        """POST *payload* to the Gemini generateContent endpoint, with retries."""
        model = self.config.model or "gemini-2.0-flash"
        base_url = self.config.api_base or TEXT_API_BASE
        # Smart path completion for custom API bases:
        # 1. A full path (ending in /models) is used as-is (supports v1 or v1beta).
        if self.config.api_base and base_url.rstrip("/").endswith("/models"):
            pass
        # 2. A bare domain gets the /v1beta/models path this code targets.
        elif self.config.api_base:
            base_url = f"{base_url.rstrip('/')}/v1beta/models"
        url = f"{base_url}/{model}:generateContent?key={self.config.api_key}"
        timeout = self.config.timeout_ms / 1000
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            return response.json()
"""文本生成适配器 (Google Gemini)。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@AdapterRegistry.register("text", "gemini")
class GeminiTextAdapter(BaseAdapter[StoryOutput]):
    """Google Gemini text-generation adapter.

    Generates a new children's story from keywords, or enhances a
    user-supplied story, returning a structured ``StoryOutput``.
    """

    adapter_type = "text"
    adapter_name = "gemini"

    async def execute(
        self,
        input_type: Literal["keywords", "full_story"],
        data: str,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> StoryOutput:
        """Generate (from keywords) or enhance (from full text) a story.

        Args:
            input_type: "keywords" to generate, "full_story" to enhance.
            data: The keywords or the full story text.
            education_theme: Optional theme; defaults to "成长".
            memory_context: Optional prior context injected into the prompt.

        Returns:
            A validated ``StoryOutput``.

        Raises:
            ValueError: When the API returns no content, malformed JSON,
                or JSON missing required fields.
        """
        start_time = time.time()
        logger.info("request_start", adapter="gemini", input_type=input_type, data_length=len(data))
        theme = education_theme or "成长"
        random_element = random.choice(RANDOM_ELEMENTS)
        if input_type == "keywords":
            system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
            prompt = USER_PROMPT_GENERATION.format(
                keywords=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )
        else:
            system_instruction = SYSTEM_INSTRUCTION_ENHANCER
            prompt = USER_PROMPT_ENHANCEMENT.format(
                full_story=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )
        # The Gemini API payload supports a top-level 'system_instruction'.
        payload = {
            "system_instruction": {"parts": [{"text": system_instruction}]},
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "responseMimeType": "application/json",
                "temperature": 0.95,
                "topP": 0.9,
            },
        }
        result = await self._call_api(payload)
        candidates = result.get("candidates") or []
        if not candidates:
            raise ValueError("Gemini 未返回内容")
        parts = candidates[0].get("content", {}).get("parts") or []
        if not parts or "text" not in parts[0]:
            raise ValueError("Gemini 响应缺少文本")
        response_text = parts[0]["text"]
        # Strip a possible ```json fence before parsing.
        clean_json = response_text
        if response_text.startswith("```json"):
            clean_json = re.sub(r"^```json\n|```$", "", response_text)
        try:
            parsed = json.loads(clean_json)
        except json.JSONDecodeError as exc:
            # Fix: chain the decode error so the root cause stays visible.
            raise ValueError(f"Gemini 输出 JSON 解析失败: {exc}") from exc
        required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
        if any(field not in parsed for field in required_fields):
            raise ValueError("Gemini 输出缺少必要字段")
        elapsed = time.time() - start_time
        logger.info(
            "request_success",
            adapter="gemini",
            elapsed_seconds=round(elapsed, 2),
            title=parsed["title"],
        )
        return StoryOutput(
            mode=parsed["mode"],
            title=parsed["title"],
            story_text=parsed["story_text"],
            cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
        )

    async def health_check(self) -> bool:
        """Return True when a minimal Gemini request succeeds."""
        try:
            payload = {
                "contents": [{"parts": [{"text": "Hi"}]}],
                "generationConfig": {"maxOutputTokens": 10},
            }
            await self._call_api(payload)
            return True
        except Exception:
            return False

    @property
    def estimated_cost(self) -> float:
        """Rough per-call cost estimate in USD."""
        return 0.001

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, payload: dict) -> dict:
        """POST *payload* to Gemini's generateContent endpoint; return parsed JSON.

        URL completion rules:
        1. A configured api_base already ending in "/models" (v1 or v1beta)
           is used verbatim.
        2. A configured bare domain gets "/v1beta/models" appended.
        """
        model = self.config.model or "gemini-2.0-flash"
        base_url = self.config.api_base or TEXT_API_BASE
        if self.config.api_base and not base_url.rstrip("/").endswith("/models"):
            base_url = f"{base_url.rstrip('/')}/v1beta/models"
        url = f"{base_url}/{model}:generateContent?key={self.config.api_key}"
        timeout = self.config.timeout_ms / 1000
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            return response.json()

View File

@@ -1,11 +1,11 @@
from dataclasses import dataclass
from typing import Literal
@dataclass
class StoryOutput:
    """Structured result of story generation or enhancement.

    Attributes:
        mode: "generated" for keyword-driven creation, "enhanced" for polishing.
        title: Story title.
        story_text: Full story body.
        cover_prompt_suggestion: Suggested prompt for cover-image generation.
    """
    mode: Literal["generated", "enhanced"]
    title: str
    story_text: str
    cover_prompt_suggestion: str
from dataclasses import dataclass
from typing import Literal
@dataclass
class StoryOutput:
    """Structured result of story generation or enhancement.

    Attributes:
        mode: "generated" for keyword-driven creation, "enhanced" for polishing.
        title: Story title.
        story_text: Full story body.
        cover_prompt_suggestion: Suggested prompt for cover-image generation.
    """
    mode: Literal["generated", "enhanced"]
    title: str
    story_text: str
    cover_prompt_suggestion: str

View File

@@ -1,172 +1,172 @@
"""OpenAI 文本生成适配器。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
OPENAI_API_BASE = "https://api.openai.com/v1/chat/completions"
@AdapterRegistry.register("text", "openai")
class OpenAITextAdapter(BaseAdapter[StoryOutput]):
    """OpenAI text-generation adapter.

    Generates a new story from keywords, or enhances an existing one,
    returning a structured ``StoryOutput``.
    """

    adapter_type = "text"
    adapter_name = "openai"

    async def execute(
        self,
        input_type: Literal["keywords", "full_story"],
        data: str,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> StoryOutput:
        """Generate (from keywords) or enhance (from full text) a story.

        Args:
            input_type: "keywords" to generate, "full_story" to enhance.
            data: The keywords or the full story text.
            education_theme: Optional theme; defaults to "成长".
            memory_context: Optional prior context injected into the prompt.

        Raises:
            ValueError: On empty response, invalid JSON, or missing fields.
        """
        start_time = time.time()
        logger.info("openai_text_request_start", input_type=input_type, data_length=len(data))
        theme = education_theme or "成长"
        random_element = random.choice(RANDOM_ELEMENTS)
        if input_type == "keywords":
            system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
            prompt = USER_PROMPT_GENERATION.format(
                keywords=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )
        else:
            system_instruction = SYSTEM_INSTRUCTION_ENHANCER
            prompt = USER_PROMPT_ENHANCEMENT.format(
                full_story=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )
        model = self.config.model or "gpt-4o-mini"
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": prompt},
            ],
            # Force strict JSON output from the model.
            "response_format": {"type": "json_object"},
            "temperature": 0.95,
            "top_p": 0.9,
        }
        result = await self._call_api(payload)
        choices = result.get("choices") or []
        if not choices:
            raise ValueError("OpenAI 未返回内容")
        response_text = choices[0].get("message", {}).get("content", "")
        if not response_text:
            raise ValueError("OpenAI 响应缺少文本")
        # Strip a possible ```json fence before parsing.
        clean_json = response_text
        if response_text.startswith("```json"):
            clean_json = re.sub(r"^```json\n|```$", "", response_text)
        try:
            parsed = json.loads(clean_json)
        except json.JSONDecodeError as exc:
            # Fix: chain the decode error so the root cause stays visible.
            raise ValueError(f"OpenAI 输出 JSON 解析失败: {exc}") from exc
        required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
        if any(field not in parsed for field in required_fields):
            raise ValueError("OpenAI 输出缺少必要字段")
        elapsed = time.time() - start_time
        logger.info(
            "openai_text_request_success",
            elapsed_seconds=round(elapsed, 2),
            title=parsed["title"],
            mode=parsed["mode"],
        )
        return StoryOutput(
            mode=parsed["mode"],
            title=parsed["title"],
            story_text=parsed["story_text"],
            cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
        )

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(httpx.HTTPStatusError),
        # Fix: re-raise the original HTTP error after retries are exhausted,
        # matching the sibling adapters instead of leaking tenacity.RetryError.
        reraise=True,
    )
    async def _call_api(self, payload: dict) -> dict:
        """Call the OpenAI chat-completions API with retries; return parsed JSON.

        A configured api_base missing the "/chat/completions" suffix gets
        it appended automatically.
        """
        url = self.config.api_base or OPENAI_API_BASE
        if self.config.api_base and not url.endswith("/chat/completions"):
            base = url.rstrip("/")
            url = f"{base}/chat/completions"
        timeout = self.config.timeout_ms / 1000
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                url,
                json=payload,
                headers={
                    "Authorization": f"Bearer {self.config.api_key}",
                    "Content-Type": "application/json",
                },
            )
            response.raise_for_status()
            return response.json()

    async def health_check(self) -> bool:
        """Return True when a minimal chat completion succeeds."""
        try:
            payload = {
                "model": self.config.model or "gpt-4o-mini",
                "messages": [{"role": "user", "content": "Hi"}],
                "max_tokens": 5,
            }
            await self._call_api(payload)
            return True
        except Exception:
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated text-generation cost per call (USD)."""
        return 0.01
"""OpenAI 文本生成适配器。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
OPENAI_API_BASE = "https://api.openai.com/v1/chat/completions"
@AdapterRegistry.register("text", "openai")
class OpenAITextAdapter(BaseAdapter[StoryOutput]):
    """OpenAI text-generation adapter.

    Generates a new story from keywords, or enhances an existing one,
    returning a structured ``StoryOutput``.
    """

    adapter_type = "text"
    adapter_name = "openai"

    async def execute(
        self,
        input_type: Literal["keywords", "full_story"],
        data: str,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> StoryOutput:
        """Generate (from keywords) or enhance (from full text) a story.

        Args:
            input_type: "keywords" to generate, "full_story" to enhance.
            data: The keywords or the full story text.
            education_theme: Optional theme; defaults to "成长".
            memory_context: Optional prior context injected into the prompt.

        Raises:
            ValueError: On empty response, invalid JSON, or missing fields.
        """
        start_time = time.time()
        logger.info("openai_text_request_start", input_type=input_type, data_length=len(data))
        theme = education_theme or "成长"
        random_element = random.choice(RANDOM_ELEMENTS)
        if input_type == "keywords":
            system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
            prompt = USER_PROMPT_GENERATION.format(
                keywords=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )
        else:
            system_instruction = SYSTEM_INSTRUCTION_ENHANCER
            prompt = USER_PROMPT_ENHANCEMENT.format(
                full_story=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )
        model = self.config.model or "gpt-4o-mini"
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": prompt},
            ],
            # Force strict JSON output from the model.
            "response_format": {"type": "json_object"},
            "temperature": 0.95,
            "top_p": 0.9,
        }
        result = await self._call_api(payload)
        choices = result.get("choices") or []
        if not choices:
            raise ValueError("OpenAI 未返回内容")
        response_text = choices[0].get("message", {}).get("content", "")
        if not response_text:
            raise ValueError("OpenAI 响应缺少文本")
        # Strip a possible ```json fence before parsing.
        clean_json = response_text
        if response_text.startswith("```json"):
            clean_json = re.sub(r"^```json\n|```$", "", response_text)
        try:
            parsed = json.loads(clean_json)
        except json.JSONDecodeError as exc:
            # Fix: chain the decode error so the root cause stays visible.
            raise ValueError(f"OpenAI 输出 JSON 解析失败: {exc}") from exc
        required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
        if any(field not in parsed for field in required_fields):
            raise ValueError("OpenAI 输出缺少必要字段")
        elapsed = time.time() - start_time
        logger.info(
            "openai_text_request_success",
            elapsed_seconds=round(elapsed, 2),
            title=parsed["title"],
            mode=parsed["mode"],
        )
        return StoryOutput(
            mode=parsed["mode"],
            title=parsed["title"],
            story_text=parsed["story_text"],
            cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
        )

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(httpx.HTTPStatusError),
        # Fix: re-raise the original HTTP error after retries are exhausted,
        # matching the sibling adapters instead of leaking tenacity.RetryError.
        reraise=True,
    )
    async def _call_api(self, payload: dict) -> dict:
        """Call the OpenAI chat-completions API with retries; return parsed JSON.

        A configured api_base missing the "/chat/completions" suffix gets
        it appended automatically.
        """
        url = self.config.api_base or OPENAI_API_BASE
        if self.config.api_base and not url.endswith("/chat/completions"):
            base = url.rstrip("/")
            url = f"{base}/chat/completions"
        timeout = self.config.timeout_ms / 1000
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                url,
                json=payload,
                headers={
                    "Authorization": f"Bearer {self.config.api_key}",
                    "Content-Type": "application/json",
                },
            )
            response.raise_for_status()
            return response.json()

    async def health_check(self) -> bool:
        """Return True when a minimal chat completion succeeds."""
        try:
            payload = {
                "model": self.config.model or "gpt-4o-mini",
                "messages": [{"role": "user", "content": "Hi"}],
                "max_tokens": 5,
            }
            await self._call_api(payload)
            return True
        except Exception:
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated text-generation cost per call (USD)."""
        return 0.01

View File

@@ -1,5 +1,5 @@
"""TTS 语音合成适配器。"""
from app.services.adapters.tts import edge_tts as _tts_edge_tts_adapter # noqa: F401
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
"""TTS 语音合成适配器。"""
from app.services.adapters.tts import edge_tts as _tts_edge_tts_adapter # noqa: F401
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401

View File

@@ -1,66 +1,66 @@
"""EdgeTTS 免费语音生成适配器。"""
import time
import edge_tts
from app.core.logging import get_logger
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认中文女声 (晓晓)
DEFAULT_VOICE = "zh-CN-XiaoxiaoNeural"
@AdapterRegistry.register("tts", "edge_tts")
class EdgeTTSAdapter(BaseAdapter[bytes]):
    """EdgeTTS speech adapter (free tier) — no API key required."""

    adapter_type = "tts"
    adapter_name = "edge_tts"

    async def execute(self, text: str, **kwargs) -> bytes:
        """Synthesize *text* to audio bytes, streaming chunks in memory."""
        # Voice resolution: explicit kwarg, then configured model, then default.
        chosen_voice = kwargs.get("voice") or self.config.model or DEFAULT_VOICE
        started = time.time()
        logger.info("edge_tts_generate_start", text_length=len(text), voice=chosen_voice)
        # Stream audio chunks straight into memory — no temp-file round trip.
        communicate = edge_tts.Communicate(text, chosen_voice)
        pieces: list[bytes] = []
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                pieces.append(chunk["data"])
        audio = b"".join(pieces)
        logger.info(
            "edge_tts_generate_success",
            elapsed_seconds=round(time.time() - started, 2),
            audio_size_bytes=len(audio),
        )
        return audio

    async def health_check(self) -> bool:
        """Check availability by synthesizing a single short word."""
        try:
            await self.execute("Hi")
        except Exception:
            return False
        return True

    @property
    def estimated_cost(self) -> float:
        """EdgeTTS is free of charge."""
        return 0.0
"""EdgeTTS 免费语音生成适配器。"""
import time
import edge_tts
from app.core.logging import get_logger
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认中文女声 (晓晓)
DEFAULT_VOICE = "zh-CN-XiaoxiaoNeural"
@AdapterRegistry.register("tts", "edge_tts")
class EdgeTTSAdapter(BaseAdapter[bytes]):
    """EdgeTTS speech adapter (free tier) — no API key required."""

    adapter_type = "tts"
    adapter_name = "edge_tts"

    async def execute(self, text: str, **kwargs) -> bytes:
        """Synthesize *text* to audio bytes, streaming chunks in memory."""
        # Voice resolution: explicit kwarg, then configured model, then default.
        chosen_voice = kwargs.get("voice") or self.config.model or DEFAULT_VOICE
        started = time.time()
        logger.info("edge_tts_generate_start", text_length=len(text), voice=chosen_voice)
        # Stream audio chunks straight into memory — no temp-file round trip.
        communicate = edge_tts.Communicate(text, chosen_voice)
        pieces: list[bytes] = []
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                pieces.append(chunk["data"])
        audio = b"".join(pieces)
        logger.info(
            "edge_tts_generate_success",
            elapsed_seconds=round(time.time() - started, 2),
            audio_size_bytes=len(audio),
        )
        return audio

    async def health_check(self) -> bool:
        """Check availability by synthesizing a single short word."""
        try:
            await self.execute("Hi")
        except Exception:
            return False
        return True

    @property
    def estimated_cost(self) -> float:
        """EdgeTTS is free of charge."""
        return 0.0

View File

@@ -1,104 +1,104 @@
"""ElevenLabs TTS 语音合成适配器。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
ELEVENLABS_API_BASE = "https://api.elevenlabs.io/v1"
DEFAULT_VOICE_ID = "21m00Tcm4TlvDq8ikWAM" # Rachel
@AdapterRegistry.register("tts", "elevenlabs")
class ElevenLabsTtsAdapter(BaseAdapter[bytes]):
    """ElevenLabs TTS adapter; synthesizes text and returns MP3 bytes."""

    adapter_type = "tts"
    adapter_name = "elevenlabs"

    def __init__(self, config: AdapterConfig):
        super().__init__(config)
        # Fall back to the public ElevenLabs endpoint when no base is configured.
        self.api_base = config.api_base or ELEVENLABS_API_BASE

    async def execute(self, text: str, **kwargs) -> bytes:
        """Convert *text* to speech; returns raw MP3 bytes."""
        started = time.time()
        logger.info("elevenlabs_tts_start", text_length=len(text))
        chosen_voice = kwargs.get("voice_id") or DEFAULT_VOICE_ID
        chosen_model = kwargs.get("model") or self.config.model or "eleven_multilingual_v2"
        body = {
            "text": text,
            "model_id": chosen_model,
            "voice_settings": {
                "stability": kwargs.get("stability", 0.5),
                "similarity_boost": kwargs.get("similarity_boost", 0.75),
            },
        }
        endpoint = f"{self.api_base}/text-to-speech/{chosen_voice}"
        audio = await self._call_api(endpoint, body)
        logger.info(
            "elevenlabs_tts_success",
            elapsed_seconds=round(time.time() - started, 2),
            audio_size_bytes=len(audio),
        )
        return audio

    async def health_check(self) -> bool:
        """Return True when the voices endpoint responds with HTTP 200."""
        try:
            async with httpx.AsyncClient(timeout=10) as client:
                resp = await client.get(
                    f"{self.api_base}/voices",
                    headers={"xi-api-key": self.config.api_key},
                )
                return resp.status_code == 200
        except Exception:
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost per thousand characters (USD)."""
        return 0.03

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, url: str, payload: dict) -> bytes:
        """POST *payload* to *url* with retries; return raw MP3 response bytes."""
        async with httpx.AsyncClient(timeout=self.config.timeout_ms / 1000) as client:
            resp = await client.post(
                url,
                json=payload,
                headers={
                    "xi-api-key": self.config.api_key,
                    "Content-Type": "application/json",
                    "Accept": "audio/mpeg",
                },
            )
            resp.raise_for_status()
            return resp.content
"""ElevenLabs TTS 语音合成适配器。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
ELEVENLABS_API_BASE = "https://api.elevenlabs.io/v1"
DEFAULT_VOICE_ID = "21m00Tcm4TlvDq8ikWAM" # Rachel
@AdapterRegistry.register("tts", "elevenlabs")
class ElevenLabsTtsAdapter(BaseAdapter[bytes]):
    """ElevenLabs TTS adapter; synthesizes text and returns MP3 bytes."""

    adapter_type = "tts"
    adapter_name = "elevenlabs"

    def __init__(self, config: AdapterConfig):
        super().__init__(config)
        # Fall back to the public ElevenLabs endpoint when no base is configured.
        self.api_base = config.api_base or ELEVENLABS_API_BASE

    async def execute(self, text: str, **kwargs) -> bytes:
        """Convert *text* to speech; returns raw MP3 bytes."""
        started = time.time()
        logger.info("elevenlabs_tts_start", text_length=len(text))
        chosen_voice = kwargs.get("voice_id") or DEFAULT_VOICE_ID
        chosen_model = kwargs.get("model") or self.config.model or "eleven_multilingual_v2"
        body = {
            "text": text,
            "model_id": chosen_model,
            "voice_settings": {
                "stability": kwargs.get("stability", 0.5),
                "similarity_boost": kwargs.get("similarity_boost", 0.75),
            },
        }
        endpoint = f"{self.api_base}/text-to-speech/{chosen_voice}"
        audio = await self._call_api(endpoint, body)
        logger.info(
            "elevenlabs_tts_success",
            elapsed_seconds=round(time.time() - started, 2),
            audio_size_bytes=len(audio),
        )
        return audio

    async def health_check(self) -> bool:
        """Return True when the voices endpoint responds with HTTP 200."""
        try:
            async with httpx.AsyncClient(timeout=10) as client:
                resp = await client.get(
                    f"{self.api_base}/voices",
                    headers={"xi-api-key": self.config.api_key},
                )
                return resp.status_code == 200
        except Exception:
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost per thousand characters (USD)."""
        return 0.03

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, url: str, payload: dict) -> bytes:
        """POST *payload* to *url* with retries; return raw MP3 response bytes."""
        async with httpx.AsyncClient(timeout=self.config.timeout_ms / 1000) as client:
            resp = await client.post(
                url,
                json=payload,
                headers={
                    "xi-api-key": self.config.api_key,
                    "Content-Type": "application/json",
                    "Accept": "audio/mpeg",
                },
            )
            resp.raise_for_status()
            return resp.content

View File

@@ -1,149 +1,149 @@
"""MiniMax 语音生成适配器 (T2A V2)。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# MiniMax API 配置
DEFAULT_API_URL = "https://api.minimaxi.com/v1/t2a_v2"
DEFAULT_MODEL = "speech-2.6-turbo"
@AdapterRegistry.register("tts", "minimax")
class MiniMaxTTSAdapter(BaseAdapter[bytes]):
    """MiniMax speech-synthesis adapter (T2A V2).

    Configuration:
        - api_key: MiniMax API key.
        - minimax_group_id: optional (depends on model/account type).
    """

    adapter_type = "tts"
    adapter_name = "minimax"

    def __init__(self, config: AdapterConfig):
        super().__init__(config)
        self.api_url = DEFAULT_API_URL

    async def execute(
        self,
        text: str,
        voice_id: str | None = None,
        model: str | None = None,
        speed: float | None = None,
        vol: float | None = None,
        pitch: int | None = None,
        emotion: str | None = None,
        **kwargs,
    ) -> bytes:
        """Synthesize *text* to MP3 bytes.

        Parameter resolution order:
        1. explicit call arguments,
        2. adapter extra_config defaults,
        3. built-in defaults.

        Raises:
            ValueError: On an API-level error or malformed audio payload.
        """
        model = model or self.config.model or DEFAULT_MODEL
        cfg = self.config.extra_config or {}
        voice_id = voice_id or cfg.get("voice_id") or "male-qn-qingse"
        speed = speed if speed is not None else (cfg.get("speed") or 1.0)
        vol = vol if vol is not None else (cfg.get("vol") or 1.0)
        pitch = pitch if pitch is not None else (cfg.get("pitch") or 0)
        emotion = emotion or cfg.get("emotion")
        group_id = kwargs.get("group_id") or settings.minimax_group_id
        url = self.api_url
        if group_id:
            url = f"{self.api_url}?GroupId={group_id}"
        payload = {
            "model": model,
            "text": text,
            "stream": False,
            "voice_setting": {
                "voice_id": voice_id,
                "speed": speed,
                "vol": vol,
                "pitch": pitch,
            },
            "audio_setting": {
                "sample_rate": 32000,
                "bitrate": 128000,
                "format": "mp3",
                "channel": 1
            }
        }
        if emotion:
            payload["voice_setting"]["emotion"] = emotion
        start_time = time.time()
        logger.info("minimax_generate_start", text_length=len(text), model=model)
        result = await self._call_api(url, payload)
        # MiniMax signals errors in base_resp even on HTTP 200.
        if result.get("base_resp", {}).get("status_code") != 0:
            error_msg = result.get("base_resp", {}).get("status_msg", "未知错误")
            raise ValueError(f"MiniMax API 错误: {error_msg}")
        # Audio arrives hex-encoded in data.audio (logic migrated from primary.py).
        hex_audio = result.get("data", {}).get("audio")
        if not hex_audio:
            raise ValueError("API 响应中未找到音频数据 (data.audio)")
        try:
            audio_bytes = bytes.fromhex(hex_audio)
        except ValueError as exc:
            # Fix: chain the decode error so the root cause stays visible.
            raise ValueError("MiniMax 返回的音频数据不是有效的 Hex 字符串") from exc
        elapsed = time.time() - start_time
        logger.info(
            "minimax_generate_success",
            elapsed_seconds=round(elapsed, 2),
            audio_size_bytes=len(audio_bytes),
        )
        return audio_bytes

    async def health_check(self) -> bool:
        """Return True when synthesizing a trivial text succeeds."""
        try:
            # Generate an extremely short text as a probe.
            await self.execute("Hi")
            return True
        except Exception:
            return False

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, url: str, payload: dict) -> dict:
        """POST *payload* to *url* with retries; return parsed JSON."""
        timeout = self.config.timeout_ms / 1000
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                url,
                json=payload,
                headers={
                    "Authorization": f"Bearer {self.config.api_key}",
                    "Content-Type": "application/json",
                },
            )
            response.raise_for_status()
            return response.json()
"""MiniMax 语音生成适配器 (T2A V2)。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# MiniMax API 配置
DEFAULT_API_URL = "https://api.minimaxi.com/v1/t2a_v2"
DEFAULT_MODEL = "speech-2.6-turbo"
@AdapterRegistry.register("tts", "minimax")
class MiniMaxTTSAdapter(BaseAdapter[bytes]):
    """MiniMax speech-synthesis adapter (T2A V2).

    Configuration:
        - api_key: MiniMax API key.
        - minimax_group_id: optional (depends on model/account type).
    """

    adapter_type = "tts"
    adapter_name = "minimax"

    def __init__(self, config: AdapterConfig):
        super().__init__(config)
        self.api_url = DEFAULT_API_URL

    async def execute(
        self,
        text: str,
        voice_id: str | None = None,
        model: str | None = None,
        speed: float | None = None,
        vol: float | None = None,
        pitch: int | None = None,
        emotion: str | None = None,
        **kwargs,
    ) -> bytes:
        """Synthesize *text* to MP3 bytes.

        Parameter resolution order:
        1. explicit call arguments,
        2. adapter extra_config defaults,
        3. built-in defaults.

        Raises:
            ValueError: On an API-level error or malformed audio payload.
        """
        model = model or self.config.model or DEFAULT_MODEL
        cfg = self.config.extra_config or {}
        voice_id = voice_id or cfg.get("voice_id") or "male-qn-qingse"
        speed = speed if speed is not None else (cfg.get("speed") or 1.0)
        vol = vol if vol is not None else (cfg.get("vol") or 1.0)
        pitch = pitch if pitch is not None else (cfg.get("pitch") or 0)
        emotion = emotion or cfg.get("emotion")
        group_id = kwargs.get("group_id") or settings.minimax_group_id
        url = self.api_url
        if group_id:
            url = f"{self.api_url}?GroupId={group_id}"
        payload = {
            "model": model,
            "text": text,
            "stream": False,
            "voice_setting": {
                "voice_id": voice_id,
                "speed": speed,
                "vol": vol,
                "pitch": pitch,
            },
            "audio_setting": {
                "sample_rate": 32000,
                "bitrate": 128000,
                "format": "mp3",
                "channel": 1
            }
        }
        if emotion:
            payload["voice_setting"]["emotion"] = emotion
        start_time = time.time()
        logger.info("minimax_generate_start", text_length=len(text), model=model)
        result = await self._call_api(url, payload)
        # MiniMax signals errors in base_resp even on HTTP 200.
        if result.get("base_resp", {}).get("status_code") != 0:
            error_msg = result.get("base_resp", {}).get("status_msg", "未知错误")
            raise ValueError(f"MiniMax API 错误: {error_msg}")
        # Audio arrives hex-encoded in data.audio (logic migrated from primary.py).
        hex_audio = result.get("data", {}).get("audio")
        if not hex_audio:
            raise ValueError("API 响应中未找到音频数据 (data.audio)")
        try:
            audio_bytes = bytes.fromhex(hex_audio)
        except ValueError as exc:
            # Fix: chain the decode error so the root cause stays visible.
            raise ValueError("MiniMax 返回的音频数据不是有效的 Hex 字符串") from exc
        elapsed = time.time() - start_time
        logger.info(
            "minimax_generate_success",
            elapsed_seconds=round(elapsed, 2),
            audio_size_bytes=len(audio_bytes),
        )
        return audio_bytes

    async def health_check(self) -> bool:
        """Return True when synthesizing a trivial text succeeds."""
        try:
            # Generate an extremely short text as a probe.
            await self.execute("Hi")
            return True
        except Exception:
            return False

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, url: str, payload: dict) -> dict:
        """POST *payload* to *url* with retries; return parsed JSON."""
        timeout = self.config.timeout_ms / 1000
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                url,
                json=payload,
                headers={
                    "Authorization": f"Bearer {self.config.api_key}",
                    "Content-Type": "application/json",
                },
            )
            response.raise_for_status()
            return response.json()

View File

@@ -1,196 +1,196 @@
"""成本追踪服务。
记录 API 调用成本,支持预算控制。
"""
from datetime import datetime, timedelta
from decimal import Decimal
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import CostRecord, UserBudget
logger = get_logger(__name__)
class BudgetExceededError(Exception):
    """Raised when a call would push spending past a configured budget."""

    def __init__(self, limit_type: str, used: Decimal, limit: Decimal):
        message = f"{limit_type} 预算已超限: {used}/{limit} USD"
        super().__init__(message)
        self.limit_type = limit_type
        self.used = used
        self.limit = limit
class CostTracker:
    """Tracks per-user API spend and enforces optional budgets."""

    async def record_cost(
        self,
        db: AsyncSession,
        user_id: str,
        provider_name: str,
        capability: str,
        estimated_cost: float,
        provider_id: str | None = None,
    ) -> CostRecord:
        """Persist one API-call cost record and commit it."""
        record = CostRecord(
            user_id=user_id,
            provider_id=provider_id,
            provider_name=provider_name,
            capability=capability,
            # Decimal(str(...)) avoids binary-float representation noise.
            estimated_cost=Decimal(str(estimated_cost)),
        )
        db.add(record)
        await db.commit()
        logger.debug(
            "cost_recorded",
            user_id=user_id,
            provider=provider_name,
            capability=capability,
            cost=estimated_cost,
        )
        return record

    async def get_daily_cost(self, db: AsyncSession, user_id: str) -> Decimal:
        """Sum of the user's recorded costs since UTC midnight today."""
        today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
        result = await db.execute(
            select(func.sum(CostRecord.estimated_cost)).where(
                CostRecord.user_id == user_id,
                CostRecord.timestamp >= today_start,
            )
        )
        total = result.scalar()
        return Decimal(str(total)) if total else Decimal("0")

    async def get_monthly_cost(self, db: AsyncSession, user_id: str) -> Decimal:
        """Sum of the user's recorded costs since the 1st of the current UTC month."""
        now = datetime.utcnow()
        month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        result = await db.execute(
            select(func.sum(CostRecord.estimated_cost)).where(
                CostRecord.user_id == user_id,
                CostRecord.timestamp >= month_start,
            )
        )
        total = result.scalar()
        return Decimal(str(total)) if total else Decimal("0")

    async def get_cost_by_capability(
        self,
        db: AsyncSession,
        user_id: str,
        days: int = 30,
    ) -> dict[str, Decimal]:
        """Aggregate the user's costs per capability over the last *days* days."""
        since = datetime.utcnow() - timedelta(days=days)
        result = await db.execute(
            select(CostRecord.capability, func.sum(CostRecord.estimated_cost))
            .where(CostRecord.user_id == user_id, CostRecord.timestamp >= since)
            .group_by(CostRecord.capability)
        )
        return {row[0]: Decimal(str(row[1])) for row in result.all()}

    async def check_budget(
        self,
        db: AsyncSession,
        user_id: str,
        estimated_cost: float,
    ) -> bool:
        """Check whether the budget allows a call costing *estimated_cost* USD.

        Returns:
            True when allowed (or when no budget is configured/enabled).

        Raises:
            BudgetExceededError: When the daily or monthly limit would be exceeded.
        """
        budget = await self.get_user_budget(db, user_id)
        if not budget or not budget.enabled:
            return True
        # Daily limit. Fix: pass a meaningful limit_type (was "") so the
        # exception message and .limit_type attribute identify which
        # budget tripped.
        daily_cost = await self.get_daily_cost(db, user_id)
        if daily_cost + Decimal(str(estimated_cost)) > budget.daily_limit_usd:
            raise BudgetExceededError("daily", daily_cost, budget.daily_limit_usd)
        # Monthly limit.
        monthly_cost = await self.get_monthly_cost(db, user_id)
        if monthly_cost + Decimal(str(estimated_cost)) > budget.monthly_limit_usd:
            raise BudgetExceededError("monthly", monthly_cost, budget.monthly_limit_usd)
        return True

    async def get_user_budget(self, db: AsyncSession, user_id: str) -> UserBudget | None:
        """Fetch the user's budget row, or None when not configured."""
        result = await db.execute(
            select(UserBudget).where(UserBudget.user_id == user_id)
        )
        return result.scalar_one_or_none()

    async def set_user_budget(
        self,
        db: AsyncSession,
        user_id: str,
        daily_limit: float | None = None,
        monthly_limit: float | None = None,
        alert_threshold: float | None = None,
        enabled: bool | None = None,
    ) -> UserBudget:
        """Create or update the user's budget; only non-None fields are changed."""
        budget = await self.get_user_budget(db, user_id)
        if budget is None:
            budget = UserBudget(user_id=user_id)
            db.add(budget)
        if daily_limit is not None:
            budget.daily_limit_usd = Decimal(str(daily_limit))
        if monthly_limit is not None:
            budget.monthly_limit_usd = Decimal(str(monthly_limit))
        if alert_threshold is not None:
            budget.alert_threshold = Decimal(str(alert_threshold))
        if enabled is not None:
            budget.enabled = enabled
        await db.commit()
        await db.refresh(budget)
        return budget

    async def get_cost_summary(
        self,
        db: AsyncSession,
        user_id: str,
    ) -> dict:
        """Assemble daily/monthly totals, per-capability breakdown, and budget usage."""
        daily = await self.get_daily_cost(db, user_id)
        monthly = await self.get_monthly_cost(db, user_id)
        by_capability = await self.get_cost_by_capability(db, user_id)
        budget = await self.get_user_budget(db, user_id)
        return {
            "daily_cost_usd": float(daily),
            "monthly_cost_usd": float(monthly),
            "by_capability": {k: float(v) for k, v in by_capability.items()},
            "budget": {
                "daily_limit_usd": float(budget.daily_limit_usd) if budget else None,
                "monthly_limit_usd": float(budget.monthly_limit_usd) if budget else None,
                "daily_usage_percent": float(daily / budget.daily_limit_usd * 100)
                if budget and budget.daily_limit_usd
                else None,
                "monthly_usage_percent": float(monthly / budget.monthly_limit_usd * 100)
                if budget and budget.monthly_limit_usd
                else None,
                "enabled": budget.enabled if budget else False,
            },
        }
# Module-level singleton shared by the rest of the application.
cost_tracker = CostTracker()
"""成本追踪服务。
记录 API 调用成本,支持预算控制。
"""
from datetime import datetime, timedelta
from decimal import Decimal
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import CostRecord, UserBudget
logger = get_logger(__name__)
class BudgetExceededError(Exception):
    """Raised when a projected spend would exceed a configured budget."""

    def __init__(self, limit_type: str, used: Decimal, limit: Decimal):
        super().__init__(f"{limit_type} 预算已超限: {used}/{limit} USD")
        # Keep the structured fields so callers can build their own messages.
        self.limit_type = limit_type
        self.used = used
        self.limit = limit
class CostTracker:
    """Tracks per-user API spend and enforces optional per-user budgets."""

    async def record_cost(
        self,
        db: AsyncSession,
        user_id: str,
        provider_name: str,
        capability: str,
        estimated_cost: float,
        provider_id: str | None = None,
    ) -> CostRecord:
        """Persist the estimated cost of a single API call.

        Args:
            db: Async database session.
            user_id: User who triggered the call.
            provider_name: Human-readable provider name.
            capability: Capability used (text/image/tts/...).
            estimated_cost: Estimated cost in USD.
            provider_id: Optional provider row id.

        Returns:
            The persisted CostRecord.
        """
        record = CostRecord(
            user_id=user_id,
            provider_id=provider_id,
            provider_name=provider_name,
            capability=capability,
            # str() round-trip avoids binary-float artifacts in the Decimal column.
            estimated_cost=Decimal(str(estimated_cost)),
        )
        db.add(record)
        await db.commit()
        logger.debug(
            "cost_recorded",
            user_id=user_id,
            provider=provider_name,
            capability=capability,
            cost=estimated_cost,
        )
        return record

    async def get_daily_cost(self, db: AsyncSession, user_id: str) -> Decimal:
        """Total recorded cost for the user since UTC midnight today."""
        today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
        result = await db.execute(
            select(func.sum(CostRecord.estimated_cost)).where(
                CostRecord.user_id == user_id,
                CostRecord.timestamp >= today_start,
            )
        )
        total = result.scalar()
        # SUM over zero rows yields NULL -> treat as zero spend.
        return Decimal(str(total)) if total else Decimal("0")

    async def get_monthly_cost(self, db: AsyncSession, user_id: str) -> Decimal:
        """Total recorded cost since the 1st of the current month (UTC)."""
        now = datetime.utcnow()
        month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        result = await db.execute(
            select(func.sum(CostRecord.estimated_cost)).where(
                CostRecord.user_id == user_id,
                CostRecord.timestamp >= month_start,
            )
        )
        total = result.scalar()
        return Decimal(str(total)) if total else Decimal("0")

    async def get_cost_by_capability(
        self,
        db: AsyncSession,
        user_id: str,
        days: int = 30,
    ) -> dict[str, Decimal]:
        """Per-capability cost totals over the trailing ``days`` days."""
        since = datetime.utcnow() - timedelta(days=days)
        result = await db.execute(
            select(CostRecord.capability, func.sum(CostRecord.estimated_cost))
            .where(CostRecord.user_id == user_id, CostRecord.timestamp >= since)
            .group_by(CostRecord.capability)
        )
        return {row[0]: Decimal(str(row[1])) for row in result.all()}

    async def check_budget(
        self,
        db: AsyncSession,
        user_id: str,
        estimated_cost: float,
    ) -> bool:
        """Verify that a prospective call fits within the user's budget.

        Returns:
            True when the call is allowed (no budget configured, budget
            disabled, or projected spend within both limits).

        Raises:
            BudgetExceededError: if the projected spend would exceed a
                configured daily or monthly limit.
        """
        budget = await self.get_user_budget(db, user_id)
        if not budget or not budget.enabled:
            return True
        projected = Decimal(str(estimated_cost))
        # Either limit may be unset (None) when only the other is configured;
        # guard each before comparing, and name the limit in the error so the
        # caller can tell which one tripped.
        if budget.daily_limit_usd is not None:
            daily_cost = await self.get_daily_cost(db, user_id)
            if daily_cost + projected > budget.daily_limit_usd:
                raise BudgetExceededError("daily", daily_cost, budget.daily_limit_usd)
        if budget.monthly_limit_usd is not None:
            monthly_cost = await self.get_monthly_cost(db, user_id)
            if monthly_cost + projected > budget.monthly_limit_usd:
                raise BudgetExceededError("monthly", monthly_cost, budget.monthly_limit_usd)
        return True

    async def get_user_budget(self, db: AsyncSession, user_id: str) -> UserBudget | None:
        """Fetch the user's budget row, or None when not configured."""
        result = await db.execute(
            select(UserBudget).where(UserBudget.user_id == user_id)
        )
        return result.scalar_one_or_none()

    async def set_user_budget(
        self,
        db: AsyncSession,
        user_id: str,
        daily_limit: float | None = None,
        monthly_limit: float | None = None,
        alert_threshold: float | None = None,
        enabled: bool | None = None,
    ) -> UserBudget:
        """Create or update a user's budget; only non-None fields are changed.

        Returns:
            The persisted UserBudget row, refreshed after commit.
        """
        budget = await self.get_user_budget(db, user_id)
        if budget is None:
            # First configuration for this user: insert a fresh row.
            budget = UserBudget(user_id=user_id)
            db.add(budget)
        if daily_limit is not None:
            budget.daily_limit_usd = Decimal(str(daily_limit))
        if monthly_limit is not None:
            budget.monthly_limit_usd = Decimal(str(monthly_limit))
        if alert_threshold is not None:
            budget.alert_threshold = Decimal(str(alert_threshold))
        if enabled is not None:
            budget.enabled = enabled
        await db.commit()
        await db.refresh(budget)
        return budget

    async def get_cost_summary(
        self,
        db: AsyncSession,
        user_id: str,
    ) -> dict:
        """Assemble the user's spend summary as a JSON-serializable dict."""
        daily = await self.get_daily_cost(db, user_id)
        monthly = await self.get_monthly_cost(db, user_id)
        by_capability = await self.get_cost_by_capability(db, user_id)
        budget = await self.get_user_budget(db, user_id)
        # Either limit may be NULL on an existing row; guard so float(None)
        # is never attempted.
        daily_limit = budget.daily_limit_usd if budget else None
        monthly_limit = budget.monthly_limit_usd if budget else None
        return {
            "daily_cost_usd": float(daily),
            "monthly_cost_usd": float(monthly),
            "by_capability": {k: float(v) for k, v in by_capability.items()},
            "budget": {
                "daily_limit_usd": float(daily_limit) if daily_limit is not None else None,
                "monthly_limit_usd": float(monthly_limit) if monthly_limit is not None else None,
                # Percentages omitted for unset/zero limits (avoids div-by-zero).
                "daily_usage_percent": float(daily / daily_limit * 100)
                if daily_limit
                else None,
                "monthly_usage_percent": float(monthly / monthly_limit * 100)
                if monthly_limit
                else None,
                "enabled": budget.enabled if budget else False,
            },
        }
# Global singleton instance shared by the application.
cost_tracker = CostTracker()

View File

@@ -1,471 +1,471 @@
"""Memory service handles memory retrieval, scoring, and prompt injection."""
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models import ChildProfile, MemoryItem, StoryUniverse
logger = get_logger(__name__)
class MemoryType:
"""记忆类型常量及配置。"""
# 基础类型
RECENT_STORY = "recent_story"
FAVORITE_CHARACTER = "favorite_character"
SCARY_ELEMENT = "scary_element"
VOCABULARY_GROWTH = "vocabulary_growth"
EMOTIONAL_HIGHLIGHT = "emotional_highlight"
# Phase 1 新增类型
READING_PREFERENCE = "reading_preference" # 阅读偏好
MILESTONE = "milestone" # 里程碑事件
SKILL_MASTERED = "skill_mastered" # 掌握的技能
# 类型配置: (默认权重, 默认TTL天数, 描述)
CONFIG = {
RECENT_STORY: (1.0, 30, "最近阅读的故事"),
FAVORITE_CHARACTER: (1.5, None, "喜欢的角色"), # None = 永久
SCARY_ELEMENT: (2.0, None, "回避的元素"), # 高权重,永久有效
VOCABULARY_GROWTH: (0.8, 90, "词汇积累"),
EMOTIONAL_HIGHLIGHT: (1.2, 60, "情感高光"),
READING_PREFERENCE: (1.0, None, "阅读偏好"),
MILESTONE: (1.5, None, "里程碑事件"),
SKILL_MASTERED: (1.0, 180, "掌握的技能"),
}
@classmethod
def get_default_weight(cls, memory_type: str) -> float:
"""获取类型的默认权重。"""
config = cls.CONFIG.get(memory_type)
return config[0] if config else 1.0
@classmethod
def get_default_ttl(cls, memory_type: str) -> int | None:
"""获取类型的默认 TTL 天数。"""
config = cls.CONFIG.get(memory_type)
return config[1] if config else None
def _decay_factor(days: float) -> float:
"""计算时间衰减因子。"""
if days <= 7:
return 1.0
if days <= 30:
return 0.7
if days <= 90:
return 0.4
return 0.2
async def build_enhanced_memory_context(
    profile_id: str | None,
    universe_id: str | None,
    db: AsyncSession,
) -> str | None:
    """Build the enhanced memory context as a natural-language prompt.

    Combines three layers — child profile, story universe, and scored
    dynamic memories — into one prompt string; returns None when neither
    id is provided.
    """
    if not profile_id and not universe_id:
        return None
    context_parts: list[str] = []
    # 1. Identity layer: basic child profile.
    if profile_id:
        profile = await db.scalar(select(ChildProfile).where(ChildProfile.id == profile_id))
        if profile:
            context_parts.append(f"【目标读者】\n姓名:{profile.name}")
            if profile.age:
                context_parts.append(f"年龄:{profile.age}岁")
            if profile.interests:
                # NOTE(review): ''.join concatenates items with no separator —
                # a delimiter (e.g. "、") may have been lost; confirm intent.
                context_parts.append(f"兴趣爱好:{''.join(profile.interests)}")
            if profile.growth_themes:
                context_parts.append(f"当前成长关注点:{''.join(profile.growth_themes)}")
            context_parts.append("")  # blank line between sections
    # 2. Universe layer: story-world settings.
    if universe_id:
        universe = await db.scalar(select(StoryUniverse).where(StoryUniverse.id == universe_id))
        if universe:
            context_parts.append("【故事宇宙设定】")
            context_parts.append(f"世界观:{universe.name}")
            # Protagonist
            protagonist = universe.protagonist or {}
            p_desc = f"{protagonist.get('name', '主角')} ({protagonist.get('personality', '')})"
            context_parts.append(f"主角设定:{p_desc}")
            # Recurring characters
            if universe.recurring_characters:
                chars = [f"{c.get('name')} ({c.get('type')})" for c in universe.recurring_characters if isinstance(c, dict)]
                context_parts.append(f"已知伙伴:{''.join(chars)}")
            # Achievements (at most 5 shown)
            if universe.achievements:
                badges = [str(a.get('type')) for a in universe.achievements if isinstance(a, dict)]
                if badges:
                    context_parts.append(f"已获荣誉:{''.join(badges[:5])}")
            context_parts.append("")
    # 3. Working memory: top-scored dynamic memories.
    if profile_id:
        memories = await _fetch_scored_memories(profile_id, universe_id, db)
        if memories:
            memory_text = _format_memories_to_prompt(memories)
            if memory_text:
                context_parts.append("【关键记忆回忆】(请在故事中自然地融入或致敬以下元素)")
                context_parts.append(memory_text)
    return "\n".join(context_parts)
async def _fetch_scored_memories(
    profile_id: str,
    universe_id: str | None,
    db: AsyncSession,
    limit: int = 8
) -> list[MemoryItem]:
    """Fetch memory items, score them by weight x recency decay, return top N.

    When a universe is given, both universe-scoped items and
    universe-agnostic items (universe_id IS NULL) are eligible.
    """
    query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
    if universe_id:
        query = query.where(
            (MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
        )
    # Score only the 50 most recently used/created candidates.
    query = query.order_by(MemoryItem.last_used_at.desc(), MemoryItem.created_at.desc()).limit(50)
    result = await db.execute(query)
    items = result.scalars().all()
    scored: list[tuple[float, MemoryItem]] = []
    now = datetime.now(timezone.utc)
    for item in items:
        # Age is measured from last use, falling back to creation time.
        reference = item.last_used_at or item.created_at or now
        delta_days = max((now - reference).total_seconds() / 86400, 0)
        if item.ttl_days and delta_days > item.ttl_days:
            continue
        score = (item.base_weight or 1.0) * _decay_factor(delta_days)
        if score <= 0.1:  # drop negligible-weight items
            continue
        scored.append((score, item))
    scored.sort(key=lambda x: x[0], reverse=True)
    return [item for _, item in scored[:limit]]
def _format_memories_to_prompt(memories: list[MemoryItem]) -> str:
    """Render memory items as natural-language instructions for the LLM."""
    lines = []
    # Bucket memories by type; other types are ignored here.
    recent_stories = []
    favorites = []
    scary = []
    vocab = []
    for m in memories:
        if m.type == MemoryType.RECENT_STORY:
            recent_stories.append(m)
        elif m.type == MemoryType.FAVORITE_CHARACTER:
            favorites.append(m)
        elif m.type == MemoryType.SCARY_ELEMENT:
            scary.append(m)
        elif m.type == MemoryType.VOCABULARY_GROWTH:
            vocab.append(m)
    # 1. Favorite characters — invite cameos.
    if favorites:
        names = []
        for m in favorites:
            val = m.value
            if isinstance(val, dict):
                names.append(f"{val.get('name')} ({val.get('description', '')})")
        if names:
            lines.append(f"- 孩子特别喜欢这些角色,可以让他们客串出场:{', '.join(names)}")
    # 2. Elements to avoid (hard prohibition).
    if scary:
        items = []
        for m in scary:
            val = m.value
            if isinstance(val, dict):
                items.append(val.get('keyword', ''))
            elif isinstance(val, str):
                items.append(val)
        if items:
            lines.append(f"- 【注意禁止】不要出现以下让孩子害怕的元素:{', '.join(items)}")
    # 3. Recent stories (latest 2 only).
    if recent_stories:
        lines.append("- 近期经历(可作为彩蛋提及):")
        for m in recent_stories[:2]:
            val = m.value
            if isinstance(val, dict):
                title = val.get('title', '未知故事')
                lines.append(f" * 之前读过《{title}》")
    # 4. Vocabulary to reinforce; None/empty words are filtered out.
    if vocab:
        words = []
        for m in vocab:
            val = m.value
            if isinstance(val, dict):
                words.append(val.get('word'))
        if words:
            lines.append(f"- 已掌握词汇(可适当复现以巩固):{', '.join([w for w in words if w])}")
    return "\n".join(lines)
async def prune_expired_memories(db: AsyncSession) -> int:
    """Delete memory items whose TTL has elapsed.

    Age is measured from last_used_at, falling back to created_at.

    Returns:
        Number of rows deleted.
    """
    from sqlalchemy import delete
    now = datetime.now(timezone.utc)
    # Only items with a TTL configured can expire.
    stmt = select(MemoryItem).where(MemoryItem.ttl_days.is_not(None))
    result = await db.execute(stmt)
    candidates = result.scalars().all()
    to_delete_ids = []
    for item in candidates:
        if not item.ttl_days:
            continue
        reference = item.last_used_at or item.created_at or now
        delta_days = (now - reference).total_seconds() / 86400
        if delta_days > item.ttl_days:
            to_delete_ids.append(item.id)
    if not to_delete_ids:
        return 0
    delete_stmt = delete(MemoryItem).where(MemoryItem.id.in_(to_delete_ids))
    await db.execute(delete_stmt)
    await db.commit()
    logger.info("memory_pruned", count=len(to_delete_ids))
    return len(to_delete_ids)
async def create_memory(
    db: AsyncSession,
    profile_id: str,
    memory_type: str,
    value: dict,
    universe_id: str | None = None,
    weight: float | None = None,
    ttl_days: int | None = None,
) -> MemoryItem:
    """Create and persist a new memory item.

    Args:
        db: Database session.
        profile_id: Child profile id.
        memory_type: Memory type (use MemoryType constants).
        value: Memory payload (JSON-serializable dict).
        universe_id: Optional story-universe scope.
        weight: Optional weight; defaults to the type's configured weight.
        ttl_days: Optional expiry in days; defaults to the type's configured TTL.

    Returns:
        The created MemoryItem, refreshed after commit.
    """
    memory = MemoryItem(
        child_profile_id=profile_id,
        universe_id=universe_id,
        type=memory_type,
        # `is not None` so an explicit weight of 0.0 is honored instead of
        # silently falling back to the type default (mirrors ttl_days below).
        base_weight=weight if weight is not None else MemoryType.get_default_weight(memory_type),
        ttl_days=ttl_days if ttl_days is not None else MemoryType.get_default_ttl(memory_type),
        value=value,
    )
    db.add(memory)
    await db.commit()
    await db.refresh(memory)
    logger.info(
        "memory_created",
        memory_id=memory.id,
        profile_id=profile_id,
        type=memory_type,
    )
    return memory
async def update_memory_usage(db: AsyncSession, memory_id: str) -> None:
    """Stamp a memory item's last-used time with the current UTC time.

    Args:
        db: Database session.
        memory_id: Id of the memory item to touch; silently ignored if absent.
    """
    lookup = await db.execute(select(MemoryItem).where(MemoryItem.id == memory_id))
    item = lookup.scalar_one_or_none()
    if item is None:
        return
    item.last_used_at = datetime.now(timezone.utc)
    await db.commit()
    logger.debug("memory_usage_updated", memory_id=memory_id)
async def get_profile_memories(
    db: AsyncSession,
    profile_id: str,
    memory_type: str | None = None,
    universe_id: str | None = None,
    limit: int = 50,
) -> list[MemoryItem]:
    """List a profile's memory items, newest first.

    Args:
        db: Database session.
        profile_id: Child profile id.
        memory_type: Optional filter by memory type.
        universe_id: Optional filter; also includes universe-agnostic items.
        limit: Maximum number of rows returned.

    Returns:
        Matching MemoryItem rows ordered by creation time descending.
    """
    conditions = [MemoryItem.child_profile_id == profile_id]
    if memory_type:
        conditions.append(MemoryItem.type == memory_type)
    if universe_id:
        conditions.append(
            (MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
        )
    stmt = (
        select(MemoryItem)
        .where(*conditions)
        .order_by(MemoryItem.created_at.desc())
        .limit(limit)
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
async def create_story_memory(
    db: AsyncSession,
    profile_id: str,
    story_id: int,
    title: str,
    summary: str | None = None,
    keywords: list[str] | None = None,
    universe_id: str | None = None,
) -> MemoryItem:
    """Convenience wrapper: record a just-read story as a recent_story memory.

    Args:
        db: Database session.
        profile_id: Child profile id.
        story_id: Story id.
        title: Story title.
        summary: Optional story synopsis.
        keywords: Optional keyword list.
        universe_id: Optional story-universe scope.

    Returns:
        The created MemoryItem.
    """
    payload = {
        "story_id": story_id,
        "title": title,
        "summary": summary or "",
        "keywords": keywords or [],
    }
    return await create_memory(
        db=db,
        profile_id=profile_id,
        memory_type=MemoryType.RECENT_STORY,
        value=payload,
        universe_id=universe_id,
    )
async def create_character_memory(
    db: AsyncSession,
    profile_id: str,
    name: str,
    description: str | None = None,
    source_story_id: int | None = None,
    affinity_score: float = 1.0,
    universe_id: str | None = None,
) -> MemoryItem:
    """Record a favorite character as a memory item.

    Args:
        db: Database session.
        profile_id: Child profile id.
        name: Character name.
        description: Optional character description.
        source_story_id: Story the character came from, if any.
        affinity_score: How much the child likes the character; clamped to [0, 1].
        universe_id: Optional story-universe scope.

    Returns:
        The created MemoryItem.
    """
    payload = {
        "name": name,
        "description": description or "",
        "source_story_id": source_story_id,
        "affinity_score": min(1.0, max(0.0, affinity_score)),
    }
    return await create_memory(
        db=db,
        profile_id=profile_id,
        memory_type=MemoryType.FAVORITE_CHARACTER,
        value=payload,
        universe_id=universe_id,
    )
async def create_scary_element_memory(
    db: AsyncSession,
    profile_id: str,
    keyword: str,
    category: str = "other",
    source_story_id: int | None = None,
) -> MemoryItem:
    """Record an element the child found scary so future stories avoid it.

    Args:
        db: Database session.
        profile_id: Child profile id.
        keyword: Keyword to avoid.
        category: One of creature/scene/action/other.
        source_story_id: Story the element came from, if any.

    Returns:
        The created MemoryItem.
    """
    payload = {
        "keyword": keyword,
        "category": category,
        "source_story_id": source_story_id,
    }
    return await create_memory(
        db=db,
        profile_id=profile_id,
        memory_type=MemoryType.SCARY_ELEMENT,
        value=payload,
    )
"""Memory service handles memory retrieval, scoring, and prompt injection."""
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models import ChildProfile, MemoryItem, StoryUniverse
logger = get_logger(__name__)
class MemoryType:
"""记忆类型常量及配置。"""
# 基础类型
RECENT_STORY = "recent_story"
FAVORITE_CHARACTER = "favorite_character"
SCARY_ELEMENT = "scary_element"
VOCABULARY_GROWTH = "vocabulary_growth"
EMOTIONAL_HIGHLIGHT = "emotional_highlight"
# Phase 1 新增类型
READING_PREFERENCE = "reading_preference" # 阅读偏好
MILESTONE = "milestone" # 里程碑事件
SKILL_MASTERED = "skill_mastered" # 掌握的技能
# 类型配置: (默认权重, 默认TTL天数, 描述)
CONFIG = {
RECENT_STORY: (1.0, 30, "最近阅读的故事"),
FAVORITE_CHARACTER: (1.5, None, "喜欢的角色"), # None = 永久
SCARY_ELEMENT: (2.0, None, "回避的元素"), # 高权重,永久有效
VOCABULARY_GROWTH: (0.8, 90, "词汇积累"),
EMOTIONAL_HIGHLIGHT: (1.2, 60, "情感高光"),
READING_PREFERENCE: (1.0, None, "阅读偏好"),
MILESTONE: (1.5, None, "里程碑事件"),
SKILL_MASTERED: (1.0, 180, "掌握的技能"),
}
@classmethod
def get_default_weight(cls, memory_type: str) -> float:
"""获取类型的默认权重。"""
config = cls.CONFIG.get(memory_type)
return config[0] if config else 1.0
@classmethod
def get_default_ttl(cls, memory_type: str) -> int | None:
"""获取类型的默认 TTL 天数。"""
config = cls.CONFIG.get(memory_type)
return config[1] if config else None
def _decay_factor(days: float) -> float:
"""计算时间衰减因子。"""
if days <= 7:
return 1.0
if days <= 30:
return 0.7
if days <= 90:
return 0.4
return 0.2
async def build_enhanced_memory_context(
    profile_id: str | None,
    universe_id: str | None,
    db: AsyncSession,
) -> str | None:
    """Build the enhanced memory context as a natural-language prompt.

    Combines three layers — child profile, story universe, and scored
    dynamic memories — into one prompt string; returns None when neither
    id is provided.
    """
    if not profile_id and not universe_id:
        return None
    context_parts: list[str] = []
    # 1. Identity layer: basic child profile.
    if profile_id:
        profile = await db.scalar(select(ChildProfile).where(ChildProfile.id == profile_id))
        if profile:
            context_parts.append(f"【目标读者】\n姓名:{profile.name}")
            if profile.age:
                context_parts.append(f"年龄:{profile.age}岁")
            if profile.interests:
                # NOTE(review): ''.join concatenates items with no separator —
                # a delimiter (e.g. "、") may have been lost; confirm intent.
                context_parts.append(f"兴趣爱好:{''.join(profile.interests)}")
            if profile.growth_themes:
                context_parts.append(f"当前成长关注点:{''.join(profile.growth_themes)}")
            context_parts.append("")  # blank line between sections
    # 2. Universe layer: story-world settings.
    if universe_id:
        universe = await db.scalar(select(StoryUniverse).where(StoryUniverse.id == universe_id))
        if universe:
            context_parts.append("【故事宇宙设定】")
            context_parts.append(f"世界观:{universe.name}")
            # Protagonist
            protagonist = universe.protagonist or {}
            p_desc = f"{protagonist.get('name', '主角')} ({protagonist.get('personality', '')})"
            context_parts.append(f"主角设定:{p_desc}")
            # Recurring characters
            if universe.recurring_characters:
                chars = [f"{c.get('name')} ({c.get('type')})" for c in universe.recurring_characters if isinstance(c, dict)]
                context_parts.append(f"已知伙伴:{''.join(chars)}")
            # Achievements (at most 5 shown)
            if universe.achievements:
                badges = [str(a.get('type')) for a in universe.achievements if isinstance(a, dict)]
                if badges:
                    context_parts.append(f"已获荣誉:{''.join(badges[:5])}")
            context_parts.append("")
    # 3. Working memory: top-scored dynamic memories.
    if profile_id:
        memories = await _fetch_scored_memories(profile_id, universe_id, db)
        if memories:
            memory_text = _format_memories_to_prompt(memories)
            if memory_text:
                context_parts.append("【关键记忆回忆】(请在故事中自然地融入或致敬以下元素)")
                context_parts.append(memory_text)
    return "\n".join(context_parts)
async def _fetch_scored_memories(
    profile_id: str,
    universe_id: str | None,
    db: AsyncSession,
    limit: int = 8
) -> list[MemoryItem]:
    """Fetch memory items, score them by weight x recency decay, return top N.

    When a universe is given, both universe-scoped items and
    universe-agnostic items (universe_id IS NULL) are eligible.
    """
    query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
    if universe_id:
        query = query.where(
            (MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
        )
    # Score only the 50 most recently used/created candidates.
    query = query.order_by(MemoryItem.last_used_at.desc(), MemoryItem.created_at.desc()).limit(50)
    result = await db.execute(query)
    items = result.scalars().all()
    scored: list[tuple[float, MemoryItem]] = []
    now = datetime.now(timezone.utc)
    for item in items:
        # Age is measured from last use, falling back to creation time.
        reference = item.last_used_at or item.created_at or now
        delta_days = max((now - reference).total_seconds() / 86400, 0)
        if item.ttl_days and delta_days > item.ttl_days:
            continue
        score = (item.base_weight or 1.0) * _decay_factor(delta_days)
        if score <= 0.1:  # drop negligible-weight items
            continue
        scored.append((score, item))
    scored.sort(key=lambda x: x[0], reverse=True)
    return [item for _, item in scored[:limit]]
def _format_memories_to_prompt(memories: list[MemoryItem]) -> str:
    """Render memory items as natural-language instructions for the LLM."""
    lines = []
    # Bucket memories by type; other types are ignored here.
    recent_stories = []
    favorites = []
    scary = []
    vocab = []
    for m in memories:
        if m.type == MemoryType.RECENT_STORY:
            recent_stories.append(m)
        elif m.type == MemoryType.FAVORITE_CHARACTER:
            favorites.append(m)
        elif m.type == MemoryType.SCARY_ELEMENT:
            scary.append(m)
        elif m.type == MemoryType.VOCABULARY_GROWTH:
            vocab.append(m)
    # 1. Favorite characters — invite cameos.
    if favorites:
        names = []
        for m in favorites:
            val = m.value
            if isinstance(val, dict):
                names.append(f"{val.get('name')} ({val.get('description', '')})")
        if names:
            lines.append(f"- 孩子特别喜欢这些角色,可以让他们客串出场:{', '.join(names)}")
    # 2. Elements to avoid (hard prohibition).
    if scary:
        items = []
        for m in scary:
            val = m.value
            if isinstance(val, dict):
                items.append(val.get('keyword', ''))
            elif isinstance(val, str):
                items.append(val)
        if items:
            lines.append(f"- 【注意禁止】不要出现以下让孩子害怕的元素:{', '.join(items)}")
    # 3. Recent stories (latest 2 only).
    if recent_stories:
        lines.append("- 近期经历(可作为彩蛋提及):")
        for m in recent_stories[:2]:
            val = m.value
            if isinstance(val, dict):
                title = val.get('title', '未知故事')
                lines.append(f" * 之前读过《{title}》")
    # 4. Vocabulary to reinforce; None/empty words are filtered out.
    if vocab:
        words = []
        for m in vocab:
            val = m.value
            if isinstance(val, dict):
                words.append(val.get('word'))
        if words:
            lines.append(f"- 已掌握词汇(可适当复现以巩固):{', '.join([w for w in words if w])}")
    return "\n".join(lines)
async def prune_expired_memories(db: AsyncSession) -> int:
    """Delete memory items whose TTL has elapsed.

    Age is measured from last_used_at, falling back to created_at.

    Returns:
        Number of rows deleted.
    """
    from sqlalchemy import delete
    now = datetime.now(timezone.utc)
    # Only items with a TTL configured can expire.
    stmt = select(MemoryItem).where(MemoryItem.ttl_days.is_not(None))
    result = await db.execute(stmt)
    candidates = result.scalars().all()
    to_delete_ids = []
    for item in candidates:
        if not item.ttl_days:
            continue
        reference = item.last_used_at or item.created_at or now
        delta_days = (now - reference).total_seconds() / 86400
        if delta_days > item.ttl_days:
            to_delete_ids.append(item.id)
    if not to_delete_ids:
        return 0
    delete_stmt = delete(MemoryItem).where(MemoryItem.id.in_(to_delete_ids))
    await db.execute(delete_stmt)
    await db.commit()
    logger.info("memory_pruned", count=len(to_delete_ids))
    return len(to_delete_ids)
async def create_memory(
    db: AsyncSession,
    profile_id: str,
    memory_type: str,
    value: dict,
    universe_id: str | None = None,
    weight: float | None = None,
    ttl_days: int | None = None,
) -> MemoryItem:
    """Create and persist a new memory item.

    Args:
        db: Database session.
        profile_id: Child profile id.
        memory_type: Memory type (use MemoryType constants).
        value: Memory payload (JSON-serializable dict).
        universe_id: Optional story-universe scope.
        weight: Optional weight; defaults to the type's configured weight.
        ttl_days: Optional expiry in days; defaults to the type's configured TTL.

    Returns:
        The created MemoryItem, refreshed after commit.
    """
    memory = MemoryItem(
        child_profile_id=profile_id,
        universe_id=universe_id,
        type=memory_type,
        # `is not None` so an explicit weight of 0.0 is honored instead of
        # silently falling back to the type default (mirrors ttl_days below).
        base_weight=weight if weight is not None else MemoryType.get_default_weight(memory_type),
        ttl_days=ttl_days if ttl_days is not None else MemoryType.get_default_ttl(memory_type),
        value=value,
    )
    db.add(memory)
    await db.commit()
    await db.refresh(memory)
    logger.info(
        "memory_created",
        memory_id=memory.id,
        profile_id=profile_id,
        type=memory_type,
    )
    return memory
async def update_memory_usage(db: AsyncSession, memory_id: str) -> None:
    """Stamp a memory item's last-used time with the current UTC time.

    Args:
        db: Database session.
        memory_id: Id of the memory item to touch; silently ignored if absent.
    """
    lookup = await db.execute(select(MemoryItem).where(MemoryItem.id == memory_id))
    item = lookup.scalar_one_or_none()
    if item is None:
        return
    item.last_used_at = datetime.now(timezone.utc)
    await db.commit()
    logger.debug("memory_usage_updated", memory_id=memory_id)
async def get_profile_memories(
    db: AsyncSession,
    profile_id: str,
    memory_type: str | None = None,
    universe_id: str | None = None,
    limit: int = 50,
) -> list[MemoryItem]:
    """List a profile's memory items, newest first.

    Args:
        db: Database session.
        profile_id: Child profile id.
        memory_type: Optional filter by memory type.
        universe_id: Optional filter; also includes universe-agnostic items.
        limit: Maximum number of rows returned.

    Returns:
        Matching MemoryItem rows ordered by creation time descending.
    """
    conditions = [MemoryItem.child_profile_id == profile_id]
    if memory_type:
        conditions.append(MemoryItem.type == memory_type)
    if universe_id:
        conditions.append(
            (MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
        )
    stmt = (
        select(MemoryItem)
        .where(*conditions)
        .order_by(MemoryItem.created_at.desc())
        .limit(limit)
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
async def create_story_memory(
    db: AsyncSession,
    profile_id: str,
    story_id: int,
    title: str,
    summary: str | None = None,
    keywords: list[str] | None = None,
    universe_id: str | None = None,
) -> MemoryItem:
    """Convenience wrapper: record a just-read story as a recent_story memory.

    Args:
        db: Database session.
        profile_id: Child profile id.
        story_id: Story id.
        title: Story title.
        summary: Optional story synopsis.
        keywords: Optional keyword list.
        universe_id: Optional story-universe scope.

    Returns:
        The created MemoryItem.
    """
    payload = {
        "story_id": story_id,
        "title": title,
        "summary": summary or "",
        "keywords": keywords or [],
    }
    return await create_memory(
        db=db,
        profile_id=profile_id,
        memory_type=MemoryType.RECENT_STORY,
        value=payload,
        universe_id=universe_id,
    )
async def create_character_memory(
    db: AsyncSession,
    profile_id: str,
    name: str,
    description: str | None = None,
    source_story_id: int | None = None,
    affinity_score: float = 1.0,
    universe_id: str | None = None,
) -> MemoryItem:
    """Record a favorite character as a memory item.

    Args:
        db: Database session.
        profile_id: Child profile id.
        name: Character name.
        description: Optional character description.
        source_story_id: Story the character came from, if any.
        affinity_score: How much the child likes the character; clamped to [0, 1].
        universe_id: Optional story-universe scope.

    Returns:
        The created MemoryItem.
    """
    payload = {
        "name": name,
        "description": description or "",
        "source_story_id": source_story_id,
        "affinity_score": min(1.0, max(0.0, affinity_score)),
    }
    return await create_memory(
        db=db,
        profile_id=profile_id,
        memory_type=MemoryType.FAVORITE_CHARACTER,
        value=payload,
        universe_id=universe_id,
    )
async def create_scary_element_memory(
    db: AsyncSession,
    profile_id: str,
    keyword: str,
    category: str = "other",
    source_story_id: int | None = None,
) -> MemoryItem:
    """Record an element the child found scary so future stories avoid it.

    Args:
        db: Database session.
        profile_id: Child profile id.
        keyword: Keyword to avoid.
        category: One of creature/scene/action/other.
        source_story_id: Story the element came from, if any.

    Returns:
        The created MemoryItem.
    """
    payload = {
        "keyword": keyword,
        "category": category,
        "source_story_id": source_story_id,
    }
    return await create_memory(
        db=db,
        profile_id=profile_id,
        memory_type=MemoryType.SCARY_ELEMENT,
        value=payload,
    )

View File

@@ -1,109 +1,109 @@
"""Redis-backed cache for providers loaded from DB."""
import json
from collections import defaultdict
from typing import Literal
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.db.admin_models import Provider
logger = get_logger(__name__)
# Provider capability categories used to group cached providers.
ProviderType = Literal["text", "image", "tts", "storybook"]
class CachedProvider(BaseModel):
    """Serializable provider configuration matching DB model fields."""

    id: str
    name: str
    # Capability category; expected to be one of the ProviderType values.
    type: str
    # Adapter implementation key used to pick the client class.
    adapter: str
    model: str | None = None
    api_base: str | None = None
    api_key: str | None = None
    timeout_ms: int = 60000
    max_retries: int = 1
    # Routing knobs: providers are sorted by (priority, weight) descending.
    weight: int = 1
    priority: int = 0
    enabled: bool = True
    config_json: dict | None = None
    # Reference to externally stored config — presumably a key into another
    # store; TODO confirm against the Provider model.
    config_ref: str | None = None
# In-process fallback (L1 cache) used when Redis is unavailable.
_local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list)
# Redis key under which the full provider-config snapshot is stored.
CACHE_KEY = "dreamweaver:providers:config"
async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]:
    """Reload enabled providers from the DB, refreshing Redis and the local cache.

    Returns:
        Providers grouped by capability type, sorted by (priority, weight)
        descending within each group.

    Raises:
        Exception: re-raises any DB/Redis failure after logging it.
    """
    try:
        rows = (
            await db.execute(select(Provider).where(Provider.enabled == True))  # noqa: E712
        ).scalars().all()
        snapshot = [
            CachedProvider(
                id=row.id,
                name=row.name,
                type=row.type,
                adapter=row.adapter,
                model=row.model,
                api_base=row.api_base,
                api_key=row.api_key,
                timeout_ms=row.timeout_ms,
                max_retries=row.max_retries,
                weight=row.weight,
                priority=row.priority,
                enabled=row.enabled,
                config_json=row.config_json,
                config_ref=row.config_ref,
            )
            for row in rows
        ]
        grouped: dict[str, list[CachedProvider]] = defaultdict(list)
        for provider in snapshot:
            grouped[provider.type].append(provider)
        # Highest (priority, weight) first within each capability group.
        for bucket in grouped.values():
            bucket.sort(key=lambda item: (item.priority, item.weight), reverse=True)
        # Persist the whole snapshot to Redis for other workers.
        redis = await get_redis()
        payload = {cap: [item.model_dump() for item in items] for cap, items in grouped.items()}
        await redis.set(CACHE_KEY, json.dumps(payload))
        # Refresh the in-process fallback copy.
        _local_cache.clear()
        _local_cache.update(grouped)
        return grouped
    except Exception as e:
        logger.error("failed_to_reload_providers", error=str(e))
        raise
async def get_providers(provider_type: ProviderType) -> list[CachedProvider]:
    """Return cached providers of a given type, preferring Redis over memory."""
    try:
        redis = await get_redis()
        raw = await redis.get(CACHE_KEY)
        if raw:
            entries = json.loads(raw).get(provider_type)
            return [CachedProvider(**entry) for entry in (entries or [])]
    except Exception as e:
        logger.warning("redis_cache_read_failed", error=str(e))
    # Redis empty or unreachable: serve the in-process fallback copy.
    return _local_cache.get(provider_type, [])
"""Redis-backed cache for providers loaded from DB."""
import json
from collections import defaultdict
from typing import Literal
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.db.admin_models import Provider
logger = get_logger(__name__)
# Provider capability categories used to group cached providers.
ProviderType = Literal["text", "image", "tts", "storybook"]
class CachedProvider(BaseModel):
    """Serializable provider configuration matching DB model fields."""

    id: str
    name: str
    # Capability category; expected to be one of the ProviderType values.
    type: str
    # Adapter implementation key used to pick the client class.
    adapter: str
    model: str | None = None
    api_base: str | None = None
    api_key: str | None = None
    timeout_ms: int = 60000
    max_retries: int = 1
    # Routing knobs: providers are sorted by (priority, weight) descending.
    weight: int = 1
    priority: int = 0
    enabled: bool = True
    config_json: dict | None = None
    # Reference to externally stored config — presumably a key into another
    # store; TODO confirm against the Provider model.
    config_ref: str | None = None
# In-process fallback (L1 cache) used when Redis is unavailable.
_local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list)
# Redis key under which the full provider-config snapshot is stored.
CACHE_KEY = "dreamweaver:providers:config"
async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]:
    """Reload enabled providers from the DB, refreshing Redis and the local cache.

    Returns:
        Providers grouped by capability type, sorted by (priority, weight)
        descending within each group.

    Raises:
        Exception: re-raises any DB/Redis failure after logging it.
    """
    try:
        rows = (
            await db.execute(select(Provider).where(Provider.enabled == True))  # noqa: E712
        ).scalars().all()
        snapshot = [
            CachedProvider(
                id=row.id,
                name=row.name,
                type=row.type,
                adapter=row.adapter,
                model=row.model,
                api_base=row.api_base,
                api_key=row.api_key,
                timeout_ms=row.timeout_ms,
                max_retries=row.max_retries,
                weight=row.weight,
                priority=row.priority,
                enabled=row.enabled,
                config_json=row.config_json,
                config_ref=row.config_ref,
            )
            for row in rows
        ]
        grouped: dict[str, list[CachedProvider]] = defaultdict(list)
        for provider in snapshot:
            grouped[provider.type].append(provider)
        # Highest (priority, weight) first within each capability group.
        for bucket in grouped.values():
            bucket.sort(key=lambda item: (item.priority, item.weight), reverse=True)
        # Persist the whole snapshot to Redis for other workers.
        redis = await get_redis()
        payload = {cap: [item.model_dump() for item in items] for cap, items in grouped.items()}
        await redis.set(CACHE_KEY, json.dumps(payload))
        # Refresh the in-process fallback copy.
        _local_cache.clear()
        _local_cache.update(grouped)
        return grouped
    except Exception as e:
        logger.error("failed_to_reload_providers", error=str(e))
        raise
async def get_providers(provider_type: ProviderType) -> list[CachedProvider]:
    """Return cached providers of a given type, preferring Redis over memory."""
    try:
        redis = await get_redis()
        raw = await redis.get(CACHE_KEY)
        if raw:
            entries = json.loads(raw).get(provider_type)
            return [CachedProvider(**entry) for entry in (entries or [])]
    except Exception as e:
        logger.warning("redis_cache_read_failed", error=str(e))
    # Redis empty or unreachable: serve the in-process fallback copy.
    return _local_cache.get(provider_type, [])

View File

@@ -1,248 +1,248 @@
"""供应商指标收集和健康检查服务。"""
from datetime import datetime, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import ProviderHealth, ProviderMetrics
if TYPE_CHECKING:
from app.services.adapters.base import BaseAdapter
logger = get_logger(__name__)
# Circuit breaker: open the circuit after this many consecutive failures.
CIRCUIT_BREAKER_THRESHOLD = 3
# Circuit breaker: seconds to wait before attempting recovery.
CIRCUIT_BREAKER_RECOVERY_SECONDS = 60
class MetricsCollector:
    """Collects per-provider call metrics: success, latency, and cost."""

    async def record_call(
        self,
        db: AsyncSession,
        provider_id: str,
        success: bool,
        latency_ms: int | None = None,
        cost_usd: float | None = None,
        error_message: str | None = None,
        request_id: str | None = None,
    ) -> None:
        """Persist the metrics of a single API call.

        Args:
            db: Async database session.
            provider_id: Provider row id.
            success: Whether the call succeeded.
            latency_ms: Call latency in milliseconds, if measured.
            cost_usd: Estimated call cost in USD, if known.
            error_message: Error description for failed calls.
            request_id: Correlation id, if available.
        """
        metric = ProviderMetrics(
            provider_id=provider_id,
            success=success,
            latency_ms=latency_ms,
            # NOTE(review): truthiness check stores NULL for a 0.0 cost —
            # harmless for SUM queries, but confirm NULL-vs-zero semantics.
            cost_usd=Decimal(str(cost_usd)) if cost_usd else None,
            error_message=error_message,
            request_id=request_id,
        )
        db.add(metric)
        await db.commit()
        logger.debug(
            "metrics_recorded",
            provider_id=provider_id,
            success=success,
            latency_ms=latency_ms,
        )

    async def get_success_rate(
        self,
        db: AsyncSession,
        provider_id: str,
        window_minutes: int = 60,
    ) -> float:
        """Success ratio within the trailing window; 1.0 when no data.

        Timestamps are compared against naive UTC (datetime.utcnow()),
        matching how ProviderMetrics.timestamp is apparently stored —
        TODO confirm the column is naive UTC.
        """
        since = datetime.utcnow() - timedelta(minutes=window_minutes)
        result = await db.execute(
            select(
                func.count().filter(ProviderMetrics.success.is_(True)).label("success_count"),
                func.count().label("total_count"),
            ).where(
                ProviderMetrics.provider_id == provider_id,
                ProviderMetrics.timestamp >= since,
            )
        )
        row = result.one()
        success_count, total_count = row.success_count, row.total_count
        if total_count == 0:
            return 1.0  # assume healthy when there is no data
        return success_count / total_count

    async def get_avg_latency(
        self,
        db: AsyncSession,
        provider_id: str,
        window_minutes: int = 60,
    ) -> float:
        """Average latency in milliseconds within the window; 0.0 when no data."""
        since = datetime.utcnow() - timedelta(minutes=window_minutes)
        result = await db.execute(
            select(func.avg(ProviderMetrics.latency_ms)).where(
                ProviderMetrics.provider_id == provider_id,
                ProviderMetrics.timestamp >= since,
                ProviderMetrics.latency_ms.isnot(None),
            )
        )
        avg = result.scalar()
        return float(avg) if avg else 0.0

    async def get_total_cost(
        self,
        db: AsyncSession,
        provider_id: str,
        window_minutes: int = 60,
    ) -> float:
        """Total cost in USD within the window; 0.0 when no data."""
        since = datetime.utcnow() - timedelta(minutes=window_minutes)
        result = await db.execute(
            select(func.sum(ProviderMetrics.cost_usd)).where(
                ProviderMetrics.provider_id == provider_id,
                ProviderMetrics.timestamp >= since,
            )
        )
        total = result.scalar()
        return float(total) if total else 0.0
class HealthChecker:
    """Provider health checker with a consecutive-failure circuit breaker."""
    async def check_provider(
        self,
        db: AsyncSession,
        provider_id: str,
        adapter: "BaseAdapter",
    ) -> bool:
        """Run the adapter's health probe and persist the outcome.

        Any exception raised by the adapter is treated as an unhealthy
        result. Returns True when the provider answered the probe.
        """
        try:
            is_healthy = await adapter.health_check()
        except Exception as e:
            logger.warning("health_check_failed", provider_id=provider_id, error=str(e))
            is_healthy = False
        await self.update_health_status(
            db,
            provider_id,
            is_healthy,
            error=None if is_healthy else "Health check failed",
        )
        return is_healthy
    async def update_health_status(
        self,
        db: AsyncSession,
        provider_id: str,
        is_healthy: bool,
        error: str | None = None,
    ) -> None:
        """Upsert the provider's health row and apply circuit-breaker logic.

        A success resets the failure counter and closes the breaker; each
        failure increments it, and once it reaches
        CIRCUIT_BREAKER_THRESHOLD the provider is marked unhealthy.
        """
        result = await db.execute(
            select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
        )
        health = result.scalar_one_or_none()
        # Naive UTC timestamp, matching the values already stored in the table.
        now = datetime.utcnow()
        if health is None:
            # First observation for this provider: create the row.
            health = ProviderHealth(
                provider_id=provider_id,
                is_healthy=is_healthy,
                last_check=now,
                consecutive_failures=0 if is_healthy else 1,
                last_error=error,
            )
            db.add(health)
        else:
            health.last_check = now
            if is_healthy:
                # Any single success fully closes the breaker.
                health.is_healthy = True
                health.consecutive_failures = 0
                health.last_error = None
            else:
                health.consecutive_failures += 1
                health.last_error = error
                # Circuit breaker: open after N consecutive failures.
                if health.consecutive_failures >= CIRCUIT_BREAKER_THRESHOLD:
                    health.is_healthy = False
                    logger.warning(
                        "circuit_breaker_triggered",
                        provider_id=provider_id,
                        consecutive_failures=health.consecutive_failures,
                    )
        await db.commit()
    async def record_call_result(
        self,
        db: AsyncSession,
        provider_id: str,
        success: bool,
        error: str | None = None,
    ) -> None:
        """Fold a live call outcome into the health state (same as a probe)."""
        await self.update_health_status(db, provider_id, success, error)
    async def get_healthy_providers(
        self,
        db: AsyncSession,
        provider_ids: list[str],
    ) -> list[str]:
        """Filter *provider_ids* down to those not currently marked unhealthy."""
        if not provider_ids:
            return []
        # Load the recorded health flags for the requested providers.
        result = await db.execute(
            select(ProviderHealth.provider_id, ProviderHealth.is_healthy).where(
                ProviderHealth.provider_id.in_(provider_ids),
            )
        )
        health_map = {row[0]: row[1] for row in result.all()}
        # Unrecorded providers default to healthy; recorded-unhealthy ones are dropped.
        return [
            pid for pid in provider_ids
            if pid not in health_map or health_map[pid]
        ]
    async def is_healthy(
        self,
        db: AsyncSession,
        provider_id: str,
    ) -> bool:
        """Check whether a provider is currently usable.

        An unhealthy provider becomes eligible again (half-open) once
        CIRCUIT_BREAKER_RECOVERY_SECONDS have elapsed since its last check.
        """
        result = await db.execute(
            select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
        )
        health = result.scalar_one_or_none()
        if health is None:
            return True  # no record yet: assume healthy
        # Allow a retry attempt once the recovery window has passed.
        if not health.is_healthy and health.last_check:
            recovery_time = health.last_check + timedelta(seconds=CIRCUIT_BREAKER_RECOVERY_SECONDS)
            if datetime.utcnow() >= recovery_time:
                return True  # half-open: permit one probe attempt
        return health.is_healthy
# 全局单例
metrics_collector = MetricsCollector()
health_checker = HealthChecker()
"""供应商指标收集和健康检查服务。"""
from datetime import datetime, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import ProviderHealth, ProviderMetrics
if TYPE_CHECKING:
from app.services.adapters.base import BaseAdapter
logger = get_logger(__name__)
# 熔断阈值:连续失败次数
CIRCUIT_BREAKER_THRESHOLD = 3
# 熔断恢复时间(秒)
CIRCUIT_BREAKER_RECOVERY_SECONDS = 60
class MetricsCollector:
"""供应商调用指标收集器。"""
async def record_call(
self,
db: AsyncSession,
provider_id: str,
success: bool,
latency_ms: int | None = None,
cost_usd: float | None = None,
error_message: str | None = None,
request_id: str | None = None,
) -> None:
"""记录一次 API 调用。"""
metric = ProviderMetrics(
provider_id=provider_id,
success=success,
latency_ms=latency_ms,
cost_usd=Decimal(str(cost_usd)) if cost_usd else None,
error_message=error_message,
request_id=request_id,
)
db.add(metric)
await db.commit()
logger.debug(
"metrics_recorded",
provider_id=provider_id,
success=success,
latency_ms=latency_ms,
)
async def get_success_rate(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的成功率。"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(
func.count().filter(ProviderMetrics.success.is_(True)).label("success_count"),
func.count().label("total_count"),
).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
)
)
row = result.one()
success_count, total_count = row.success_count, row.total_count
if total_count == 0:
return 1.0 # 无数据时假设健康
return success_count / total_count
async def get_avg_latency(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的平均延迟(毫秒)。"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(func.avg(ProviderMetrics.latency_ms)).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
ProviderMetrics.latency_ms.isnot(None),
)
)
avg = result.scalar()
return float(avg) if avg else 0.0
async def get_total_cost(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的总成本USD"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(func.sum(ProviderMetrics.cost_usd)).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
)
)
total = result.scalar()
return float(total) if total else 0.0
class HealthChecker:
"""供应商健康检查器。"""
async def check_provider(
self,
db: AsyncSession,
provider_id: str,
adapter: "BaseAdapter",
) -> bool:
"""执行健康检查并更新状态。"""
try:
is_healthy = await adapter.health_check()
except Exception as e:
logger.warning("health_check_failed", provider_id=provider_id, error=str(e))
is_healthy = False
await self.update_health_status(
db,
provider_id,
is_healthy,
error=None if is_healthy else "Health check failed",
)
return is_healthy
async def update_health_status(
self,
db: AsyncSession,
provider_id: str,
is_healthy: bool,
error: str | None = None,
) -> None:
"""更新供应商健康状态(含熔断逻辑)。"""
result = await db.execute(
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
)
health = result.scalar_one_or_none()
now = datetime.utcnow()
if health is None:
health = ProviderHealth(
provider_id=provider_id,
is_healthy=is_healthy,
last_check=now,
consecutive_failures=0 if is_healthy else 1,
last_error=error,
)
db.add(health)
else:
health.last_check = now
if is_healthy:
health.is_healthy = True
health.consecutive_failures = 0
health.last_error = None
else:
health.consecutive_failures += 1
health.last_error = error
# 熔断逻辑
if health.consecutive_failures >= CIRCUIT_BREAKER_THRESHOLD:
health.is_healthy = False
logger.warning(
"circuit_breaker_triggered",
provider_id=provider_id,
consecutive_failures=health.consecutive_failures,
)
await db.commit()
async def record_call_result(
self,
db: AsyncSession,
provider_id: str,
success: bool,
error: str | None = None,
) -> None:
"""根据调用结果更新健康状态。"""
await self.update_health_status(db, provider_id, success, error)
async def get_healthy_providers(
self,
db: AsyncSession,
provider_ids: list[str],
) -> list[str]:
"""获取健康的供应商列表。"""
if not provider_ids:
return []
# 查询所有已记录的健康状态
result = await db.execute(
select(ProviderHealth.provider_id, ProviderHealth.is_healthy).where(
ProviderHealth.provider_id.in_(provider_ids),
)
)
health_map = {row[0]: row[1] for row in result.all()}
# 未记录的供应商默认健康,已记录但不健康的排除
return [
pid for pid in provider_ids
if pid not in health_map or health_map[pid]
]
async def is_healthy(
self,
db: AsyncSession,
provider_id: str,
) -> bool:
"""检查供应商是否健康。"""
result = await db.execute(
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
)
health = result.scalar_one_or_none()
if health is None:
return True # 未记录默认健康
# 检查是否可以恢复
if not health.is_healthy and health.last_check:
recovery_time = health.last_check + timedelta(seconds=CIRCUIT_BREAKER_RECOVERY_SECONDS)
if datetime.utcnow() >= recovery_time:
return True # 允许重试
return health.is_healthy
# 全局单例
metrics_collector = MetricsCollector()
health_checker = HealthChecker()

View File

@@ -1,207 +1,207 @@
"""供应商密钥加密存储服务。
使用 Fernet 对称加密,密钥从 SECRET_KEY 派生。
"""
import base64
import hashlib
from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.admin_models import ProviderSecret
if TYPE_CHECKING:
pass
logger = get_logger(__name__)
class SecretEncryptionError(Exception):
    """Raised when encrypting or decrypting a stored secret fails."""
    pass
class SecretService:
    """Encrypted storage for provider secrets.

    Secrets are encrypted with Fernet; the symmetric key is derived from
    ``settings.secret_key`` via SHA-256, so rotating SECRET_KEY invalidates
    every stored ciphertext.
    """

    # Lazily-built Fernet instance shared by all calls.
    _fernet: Fernet | None = None

    @classmethod
    def _get_fernet(cls) -> Fernet:
        """Return the cached Fernet instance, deriving the key on first use."""
        if cls._fernet is None:
            # Derive a 32-byte key from SECRET_KEY; Fernet requires it
            # urlsafe-base64 encoded.
            key_bytes = hashlib.sha256(settings.secret_key.encode()).digest()
            fernet_key = base64.urlsafe_b64encode(key_bytes)
            cls._fernet = Fernet(fernet_key)
        return cls._fernet

    @classmethod
    def encrypt(cls, plaintext: str) -> str:
        """Encrypt *plaintext*, returning a base64 Fernet token.

        An empty/falsy input is returned as "" without encryption.
        """
        if not plaintext:
            return ""
        fernet = cls._get_fernet()
        return fernet.encrypt(plaintext.encode()).decode()

    @classmethod
    def decrypt(cls, ciphertext: str) -> str:
        """Decrypt a Fernet token back to plaintext ("" passes through).

        Raises:
            SecretEncryptionError: the token is invalid, typically because
                SECRET_KEY changed since the secret was stored.
        """
        if not ciphertext:
            return ""
        try:
            fernet = cls._get_fernet()
            return fernet.decrypt(ciphertext.encode()).decode()
        except InvalidToken as e:
            logger.error("secret_decrypt_failed", error=str(e))
            raise SecretEncryptionError("密钥解密失败,可能是 SECRET_KEY 已更改") from e

    @classmethod
    async def get_secret(cls, db: AsyncSession, name: str) -> str | None:
        """Fetch and decrypt the secret *name*; None if it does not exist."""
        result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
        secret = result.scalar_one_or_none()
        if secret is None:
            return None
        return cls.decrypt(secret.encrypted_value)

    @classmethod
    async def set_secret(cls, db: AsyncSession, name: str, value: str) -> ProviderSecret:
        """Create or update the encrypted secret *name* and return its row."""
        encrypted = cls.encrypt(value)
        result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
        secret = result.scalar_one_or_none()
        if secret is None:
            secret = ProviderSecret(name=name, encrypted_value=encrypted)
            db.add(secret)
        else:
            secret.encrypted_value = encrypted
        await db.commit()
        await db.refresh(secret)
        logger.info("secret_stored", name=name)
        return secret

    @classmethod
    async def delete_secret(cls, db: AsyncSession, name: str) -> bool:
        """Delete the secret *name*; return True if a row was removed."""
        result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
        secret = result.scalar_one_or_none()
        if secret is None:
            return False
        await db.delete(secret)
        await db.commit()
        logger.info("secret_deleted", name=name)
        return True

    @classmethod
    async def list_secrets(cls, db: AsyncSession) -> list[str]:
        """List all stored secret names (values are never returned)."""
        result = await db.execute(select(ProviderSecret.name))
        return [row[0] for row in result.fetchall()]

    @classmethod
    async def get_api_key(
        cls,
        db: AsyncSession,
        provider_api_key: str | None,
        config_ref: str | None,
    ) -> str | None:
        """Resolve a provider's API key, trying sources in priority order.

        Priority:
            1. ``provider_api_key`` from the providers table (encrypted or
               plaintext);
            2. the ProviderSecret row named by ``config_ref``;
            3. the settings attribute named by ``config_ref`` (lower-cased),
               i.e. the environment-variable fallback.

        Returns the key, or None when no source yields one.
        """
        # 1. Direct value on the provider row: try to decrypt, fall back to
        #    treating it as plaintext if it is not a valid Fernet token.
        if provider_api_key:
            try:
                decrypted = cls.decrypt(provider_api_key)
                if decrypted:
                    return decrypted
            except SecretEncryptionError:
                pass
            return provider_api_key
        # Both remaining sources need a reference name; the guard also
        # prevents calling .lower() on None below.
        if config_ref:
            # 2. ProviderSecret table.
            secret_value = await cls.get_secret(db, config_ref)
            if secret_value:
                return secret_value
            # 3. Environment variable / settings fallback.
            env_value = getattr(settings, config_ref.lower(), None)
            if env_value:
                return env_value
        return None
"""供应商密钥加密存储服务。
使用 Fernet 对称加密,密钥从 SECRET_KEY 派生。
"""
import base64
import hashlib
from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.admin_models import ProviderSecret
if TYPE_CHECKING:
pass
logger = get_logger(__name__)
class SecretEncryptionError(Exception):
"""密钥加密/解密错误。"""
pass
class SecretService:
"""供应商密钥加密存储服务。"""
_fernet: Fernet | None = None
@classmethod
def _get_fernet(cls) -> Fernet:
"""获取 Fernet 实例,从 SECRET_KEY 派生加密密钥。"""
if cls._fernet is None:
# 从 SECRET_KEY 派生 32 字节密钥
key_bytes = hashlib.sha256(settings.secret_key.encode()).digest()
fernet_key = base64.urlsafe_b64encode(key_bytes)
cls._fernet = Fernet(fernet_key)
return cls._fernet
@classmethod
def encrypt(cls, plaintext: str) -> str:
"""加密明文,返回 base64 编码的密文。
Args:
plaintext: 要加密的明文
Returns:
base64 编码的密文
"""
if not plaintext:
return ""
fernet = cls._get_fernet()
encrypted = fernet.encrypt(plaintext.encode())
return encrypted.decode()
@classmethod
def decrypt(cls, ciphertext: str) -> str:
"""解密密文,返回明文。
Args:
ciphertext: base64 编码的密文
Returns:
解密后的明文
Raises:
SecretEncryptionError: 解密失败
"""
if not ciphertext:
return ""
try:
fernet = cls._get_fernet()
decrypted = fernet.decrypt(ciphertext.encode())
return decrypted.decode()
except InvalidToken as e:
logger.error("secret_decrypt_failed", error=str(e))
raise SecretEncryptionError("密钥解密失败,可能是 SECRET_KEY 已更改") from e
@classmethod
async def get_secret(cls, db: AsyncSession, name: str) -> str | None:
"""从数据库获取并解密密钥。
Args:
db: 数据库会话
name: 密钥名称
Returns:
解密后的密钥值,不存在返回 None
"""
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
return None
return cls.decrypt(secret.encrypted_value)
@classmethod
async def set_secret(cls, db: AsyncSession, name: str, value: str) -> ProviderSecret:
"""存储或更新加密密钥。
Args:
db: 数据库会话
name: 密钥名称
value: 密钥明文值
Returns:
ProviderSecret 实例
"""
encrypted = cls.encrypt(value)
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
secret = ProviderSecret(name=name, encrypted_value=encrypted)
db.add(secret)
else:
secret.encrypted_value = encrypted
await db.commit()
await db.refresh(secret)
logger.info("secret_stored", name=name)
return secret
@classmethod
async def delete_secret(cls, db: AsyncSession, name: str) -> bool:
"""删除密钥。
Args:
db: 数据库会话
name: 密钥名称
Returns:
是否删除成功
"""
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
return False
await db.delete(secret)
await db.commit()
logger.info("secret_deleted", name=name)
return True
@classmethod
async def list_secrets(cls, db: AsyncSession) -> list[str]:
"""列出所有密钥名称(不返回值)。
Args:
db: 数据库会话
Returns:
密钥名称列表
"""
result = await db.execute(select(ProviderSecret.name))
return [row[0] for row in result.fetchall()]
@classmethod
async def get_api_key(
cls,
db: AsyncSession,
provider_api_key: str | None,
config_ref: str | None,
) -> str | None:
"""获取 Provider 的 API Key按优先级查找。
优先级:
1. provider.api_key (数据库明文/加密)
2. provider.config_ref 指向的 ProviderSecret
3. 环境变量 (config_ref 作为变量名)
Args:
db: 数据库会话
provider_api_key: Provider 表中的 api_key 字段
config_ref: Provider 表中的 config_ref 字段
Returns:
API Key 或 None
"""
# 1. 直接使用 provider.api_key
if provider_api_key:
# 尝试解密,如果失败则当作明文
try:
decrypted = cls.decrypt(provider_api_key)
if decrypted:
return decrypted
except SecretEncryptionError:
pass
return provider_api_key
# 2. 从 ProviderSecret 表查找
if config_ref:
secret_value = await cls.get_secret(db, config_ref)
if secret_value:
return secret_value
# 3. 从环境变量查找
env_value = getattr(settings, config_ref.lower(), None)
if env_value:
return env_value
return None

View File

@@ -1,29 +1,29 @@
import asyncio
from app.core.celery_app import celery_app
from app.core.logging import get_logger
from app.db.database import _get_session_factory
from app.services.memory_service import prune_expired_memories
logger = get_logger(__name__)
@celery_app.task
def prune_memories_task():
    """Daily task to prune expired memories."""
    logger.info("prune_memories_task_started")

    async def _prune() -> int:
        # The worker process may not have initialised the engine yet,
        # so resolve the session factory lazily inside the task.
        factory = _get_session_factory()
        async with factory() as session:
            return await prune_expired_memories(session)

    try:
        # Run the coroutine on a fresh event loop owned by this task.
        deleted = asyncio.run(_prune())
        logger.info("prune_memories_task_completed", deleted_count=deleted)
        return f"Deleted {deleted} expired memories"
    except Exception as exc:
        logger.error("prune_memories_task_failed", error=str(exc))
        raise
import asyncio
from app.core.celery_app import celery_app
from app.core.logging import get_logger
from app.db.database import _get_session_factory
from app.services.memory_service import prune_expired_memories
logger = get_logger(__name__)
@celery_app.task
def prune_memories_task():
"""Daily task to prune expired memories."""
logger.info("prune_memories_task_started")
async def _run():
# Ensure engine is initialized in this process
session_factory = _get_session_factory()
async with session_factory() as session:
return await prune_expired_memories(session)
try:
# Create a new event loop for this task execution
count = asyncio.run(_run())
logger.info("prune_memories_task_completed", deleted_count=count)
return f"Deleted {count} expired memories"
except Exception as exc:
logger.error("prune_memories_task_failed", error=str(exc))
raise

View File

@@ -1,14 +0,0 @@
# Code Review Report (2nd follow-up)
## What's fixed
- Provider cache now loads on startup via lifespan (`app/main.py`), so DB providers are honored without manual reload.
- Providers support DB-stored `api_key` precedence (`provider_router.py:77-104`) and Provider model added `api_key` column (`db/admin_models.py:25`).
- Frontend uses `/api/generate/full` and propagates image-failure warning to detail via query flag; StoryDetail displays banner when image generation failed.
- Tests added for full generation, provider failover, config-from-DB, and startup cache load.
## Remaining issue
1) **Missing DB migration for new Provider.api_key column**
- Files updated model (`backend/app/db/admin_models.py:25`) but `backend/alembic/versions/0001_init_providers_and_story_mode.py` lacks this column. Existing databases will not have `api_key`, causing runtime errors when accessing or inserting. Add an Alembic migration to add/drop `api_key` to `providers` table and update sample data if needed.
## Suggested action
- Create and apply an Alembic migration adding `api_key` (String, nullable) to `providers`. After migration, verify admin CRUD works with the new field.

View File

@@ -1,89 +0,0 @@
# HA 部署与验证 RunbookPhase 3 MVP
本文档对应 `docker-compose.ha.yml`,用于本地/测试环境验证高可用基础能力。
## 1. 启动方式
```bash
docker compose -f docker-compose.yml -f docker-compose.ha.yml up -d
```
说明:
- 基础业务服务仍来自 `docker-compose.yml`
- `docker-compose.ha.yml` 覆盖了 `db``redis`,并新增 `db-replica``postgres-backup``redis-replica``redis-sentinel-*`
## 2. 核心环境变量建议
`backend/.env`(或 shell 环境)中至少配置:
```env
# PostgreSQL
POSTGRES_USER=dreamweaver
POSTGRES_PASSWORD=dreamweaver_password
POSTGRES_DB=dreamweaver_db
POSTGRES_REPMGR_PASSWORD=repmgr_password
# Redis Sentinel
REDIS_SENTINEL_ENABLED=true
REDIS_SENTINEL_NODES=redis-sentinel-1:26379,redis-sentinel-2:26379,redis-sentinel-3:26379
REDIS_SENTINEL_MASTER_NAME=mymaster
REDIS_SENTINEL_DB=0
REDIS_SENTINEL_SOCKET_TIMEOUT=0.5
# 可选:若 Sentinel/Redis 设置了密码
REDIS_SENTINEL_PASSWORD=
# 备份周期,默认 86400 秒1 天)
BACKUP_INTERVAL_SECONDS=86400
```
## 3. 健康检查
### 3.1 PostgreSQL 主从
```bash
docker compose -f docker-compose.yml -f docker-compose.ha.yml ps
docker exec -it dreamweaver_db_primary psql -U dreamweaver -d dreamweaver_db -c "select now();"
docker exec -it dreamweaver_db_replica psql -U dreamweaver -d dreamweaver_db -c "select pg_is_in_recovery();"
```
期望:
- 主库可读写;
- 从库 `pg_is_in_recovery()` 返回 `t`
### 3.2 Redis Sentinel
```bash
docker exec -it dreamweaver_redis_sentinel_1 redis-cli -p 26379 sentinel masters
docker exec -it dreamweaver_redis_sentinel_1 redis-cli -p 26379 sentinel replicas mymaster
```
期望:
- `mymaster` 存在;
- 至少 1 个 replica 被发现。
### 3.3 备份任务
```bash
docker exec -it dreamweaver_postgres_backup sh -c "ls -lh /backups"
```
期望:
- `/backups` 下出现 `.dump` 文件;
- 旧于 7 天的备份会被自动清理。
## 4. 故障切换演练(最小)
```bash
# 模拟 Redis 主节点故障
docker stop dreamweaver_redis_master
# 等待 Sentinel 选主后查看
docker exec -it dreamweaver_redis_sentinel_1 redis-cli -p 26379 sentinel get-master-addr-by-name mymaster
```
提示:应用与 Celery 已支持 Sentinel 配置。若未启用 Sentinel仍可回退到 `REDIS_URL` / `CELERY_BROKER_URL` / `CELERY_RESULT_BACKEND` 直连模式。
## 5. 当前已知限制(下一步)
- PostgreSQL 侧当前仅完成主从拓扑读写分离PgBouncer/路由)待后续迭代。

View File

@@ -1,147 +0,0 @@
# 记忆系统开发指南 (Development Guide)
本文档详细说明了 PRD 中定义的记忆系统的技术实现细节。
## 1. 数据库架构变更 (Schema Changes)
目前的 `MemoryItem` 表结构尚可,但需要增强字段以支持丰富的情感和元数据。
### 1.1 `MemoryItem` 表优化
建议使用 Alembic 进行迁移,增加以下字段或在 `value` JSON 中规范化以下结构:
```python
# 建议在 models.py 中明确这些字段,或者严格定义 value 字段的 Schema
class MemoryItem(Base):
# ... 现有字段 ...
# 新增/规范化字段建议
# value 字段的 JSON 结构规范:
# {
# "content": "小兔子战胜了大灰狼", # 记忆的核心文本
# "keywords": ["勇敢", "森林"], # 用于检索的关键词
# "emotion": "positive", # 情感倾向: positive/negative/neutral
# "source_story_id": 123, # 来源故事 ID
# "confidence": 0.85 # 记忆置信度 (如果是 AI 自动提取)
# }
```
### 1.2 `StoryUniverse` 表优化 (成就系统)
目前的成就存储在 `achievements` JSON 字段中。为了支持更复杂的查询(如"获得勇气勋章的所有用户"),建议将其重构为独立关联表,或保持 JSON 但规范化结构。
**当前 JSON 结构规范**:
```json
[
{
"id": "badge_courage_01",
"type": "勇气",
"name": "小小勇士",
"description": "第一次在故事中独自面对困难",
"icon_url": "badges/courage.png",
"obtained_at": "2023-10-27T10:00:00Z",
"source_story_id": 45
}
]
```
---
## 2. 核心逻辑实现
### 2.1 记忆注入逻辑 (Prompt Engineering)
修改 `backend/app/api/stories.py` 中的 `_build_memory_context` 函数。
**目标**: 生成一段自然的、不仅是罗列数据的 Prompt。
**伪代码逻辑**:
```python
def format_memory_for_prompt(memories: list[MemoryItem]) -> str:
"""
将记忆项转换为自然语言 Prompt 片段。
"""
context_parts = []
# 1. 角色记忆
chars = [m for m in memories if m.type == 'favorite_character']
if chars:
names = ", ".join([c.value['name'] for c in chars])
context_parts.append(f"孩子特别喜欢的角色有:{names}。请尝试让他们客串出场。")
# 2. 近期情节
recent_stories = [m for m in memories if m.type == 'recent_story'][:2]
if recent_stories:
for story in recent_stories:
context_parts.append(f"最近发生过:{story.value['summary']}。可以在对话中不经意地提及此事。")
# 3. 避雷区 (负面记忆)
scary = [m for m in memories if m.type == 'scary_element']
if scary:
items = ", ".join([s.value['keyword'] for s in scary])
context_parts.append(f"【绝对禁止】不要出现以下让孩子害怕的元素:{items}")
return "\n".join(context_parts)
```
### 2.2 成就提取与通知流程
当前流程在 `app/tasks/achievements.py`。需要完善闭环。
**改进后的流程**:
1. **Story Generation**: 故事生成成功,存入数据库。
2. **Async Task**: 触发 Celery 任务 `extract_story_achievements`
3. **LLM Analysis**: 调用 Gemini 分析故事,提取成就。
4. **Update DB**: 更新 `StoryUniverse.achievements`
5. **Notification (新增)**:
* 创建一个 `Notification``UserMessage` 记录(需要新建表或使用 Redis Pub/Sub
* 前端轮询或通过 SSE (Server-Sent Events) 接收通知:"🎉 恭喜!在这个故事里,小明获得了[诚实勋章]"
### 2.3 记忆清理与衰减 (Maintenance)
需要一个后台定时任务Cron Job清理无效记忆避免 Prompt 过长。
* **频率**: 每天一次。
* **逻辑**:
* 删除 `ttl_days` 已过期的记录。
*`recent_story` 类型的 `base_weight` 进行每日衰减 update或者只在读取时计算数据库存静态值推荐读取时动态计算以减少写操作
*`MemoryItem` 总数超过 100 条时,触发"记忆总结"任务,将多条旧记忆合并为一条"长期印象" (Long-term Impression)。
---
## 3. API 接口规划
### 3.1 获取成长时间轴
`GET /api/profiles/{id}/timeline`
**Response**:
```json
{
"events": [
{
"date": "2023-10-01",
"type": "milestone",
"title": "初次相遇",
"description": "创建了角色 [小明]"
},
{
"date": "2023-10-05",
"type": "story",
"title": "小明与魔法树",
"image_url": "..."
},
{
"date": "2023-10-05",
"type": "achievement",
"badge": {
"name": "好奇宝宝",
"icon": "..."
}
}
]
}
```
### 3.2 记忆反馈 (人工干预)
`POST /api/memories/{id}/feedback`
允许家长手动删除或修正错误的记忆。
* **Action**: `delete` | `reinforce` (强化,增加权重)

View File

@@ -1,246 +0,0 @@
# Provider 系统开发文档
## 当前版本功能 (v0.2.0)
### 已完成功能
1. **CQTAI nano 图像适配器** (`app/services/adapters/image/cqtai.py`)
- 异步生成 + 轮询获取结果
- 支持 nano-banana / nano-banana-pro 模型
- 支持多种分辨率和画面比例
- 支持图生图 (filesUrl)
2. **密钥加密存储** (`app/services/secret_service.py`)
- Fernet 对称加密,密钥从 SECRET_KEY 派生
- Provider API Key 自动加密存储
- 密钥管理 API (CRUD)
3. **指标收集系统** (`app/services/provider_metrics.py`)
- 调用成功率、延迟、成本统计
- 时间窗口聚合查询
- 已集成到 provider_router
4. **熔断器功能** (`app/services/provider_metrics.py`)
- 连续失败 3 次触发熔断
- 60 秒后自动恢复尝试
- 健康状态持久化到数据库
5. **管理后台前端** (`app/admin_app.py`)
- 独立端口部署 (8001)
- Vue 3 + Tailwind CSS 单页应用
- Provider CRUD 管理
- 密钥管理界面
- Basic Auth 认证
### 配置说明
```bash
# 启动主应用
uvicorn app.main:app --port 8000
# 启动管理后台 (独立端口)
uvicorn app.admin_app:app --port 8001
```
环境变量:
```
CQTAI_API_KEY=your-cqtai-api-key
ENABLE_ADMIN_CONSOLE=true
ADMIN_USERNAME=admin
ADMIN_PASSWORD=your-secure-password
```
---
## 下一版本优化计划 (v0.3.0)
### 高优先级
#### 1. 智能负载分流 (方案 B)
**目标**: 主渠道压力大时自动分流到后备渠道
**实现方案**:
- 监控指标: 并发数、响应延迟、错误率
- 分流阈值配置:
```python
class LoadBalanceConfig:
max_concurrent: int = 10 # 并发超过此值时分流
max_latency_ms: int = 5000 # 延迟超过此值时分流
max_error_rate: float = 0.1 # 错误率超过 10% 时分流
```
- 分流策略: 加权轮询,根据健康度动态调整权重
**涉及文件**:
- `app/services/provider_router.py` - 添加负载均衡逻辑
- `app/services/provider_metrics.py` - 添加并发计数器
- `app/db/admin_models.py` - 添加 LoadBalanceConfig 模型
#### 2. Storybook 适配器
**目标**: 生成可翻页的分页故事书
**实现方案**:
- 参考 Gemini AI Story Generator 格式
- 输出结构:
```python
class StorybookPage:
page_number: int
text: str
image_prompt: str
image_url: str | None
class Storybook:
title: str
pages: list[StorybookPage]
cover_url: str | None
```
- 集成文本 + 图像生成流水线
**涉及文件**:
- `app/services/adapters/storybook/` - 新建目录
- `app/api/stories.py` - 添加 storybook 生成端点
### 中优先级
#### 3. 成本追踪系统
**目标**: 记录实际消费,支持预算控制
**实现方案**:
- 成本记录表:
```python
class CostRecord:
user_id: str
provider_id: str
capability: str # text/image/tts
estimated_cost: Decimal
actual_cost: Decimal | None
timestamp: datetime
```
- 预算配置:
```python
class BudgetConfig:
user_id: str
daily_limit: Decimal
monthly_limit: Decimal
alert_threshold: float = 0.8 # 80% 时告警
```
- 超预算处理: 拒绝请求 / 降级到低成本 provider
**涉及文件**:
- `app/db/admin_models.py` - 添加 CostRecord, BudgetConfig
- `app/services/cost_tracker.py` - 新建
- `app/api/admin_providers.py` - 添加成本查询 API
#### 4. 指标可视化
**目标**: 管理后台展示供应商指标图表
**实现方案**:
- 添加指标查询 API:
- GET /admin/metrics/summary - 汇总统计
- GET /admin/metrics/timeline - 时间线数据
- GET /admin/metrics/providers/{id} - 单个供应商详情
- 前端使用 Chart.js 或 ECharts 展示
### 低优先级
#### 5. 多租户 Provider 配置
**目标**: 每个租户可配置独立 provider 列表和 API Key
**实现方案**:
- 租户配置表:
```python
class TenantProviderConfig:
tenant_id: str
provider_type: str
provider_ids: list[str] # 按优先级排序
api_key_override: str | None # 加密存储
```
- 路由时优先使用租户配置,回退到全局配置
#### 6. Provider 健康检查调度器
**目标**: 定期主动检查 provider 健康状态
**实现方案**:
- Celery Beat 定时任务
- 每 5 分钟检查一次所有启用的 provider
- 更新 ProviderHealth 表
#### 7. 适配器热加载
**目标**: 支持运行时动态加载新适配器
**实现方案**:
- 适配器插件目录: `app/services/adapters/plugins/`
- 启动时扫描并注册
- 提供 API 触发重新扫描
---
## API 变更记录
### v0.2.0 新增
| Method | Route | Description |
|--------|-------|-------------|
| GET | `/admin/secrets` | 列出所有密钥名称 |
| POST | `/admin/secrets` | 创建/更新密钥 |
| DELETE | `/admin/secrets/{name}` | 删除密钥 |
| GET | `/admin/secrets/{name}/verify` | 验证密钥有效性 |
### 计划中 (v0.3.0)
| Method | Route | Description |
|--------|-------|-------------|
| GET | `/admin/metrics/summary` | 指标汇总 |
| GET | `/admin/metrics/timeline` | 时间线数据 |
| POST | `/api/storybook/generate` | 生成分页故事书 |
| GET | `/admin/costs` | 成本统计 |
| POST | `/admin/budgets` | 设置预算 |
---
## 适配器开发指南
### 添加新适配器
1. 创建适配器文件:
```python
# app/services/adapters/image/new_provider.py
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
@AdapterRegistry.register("image", "new_provider")
class NewProviderAdapter(BaseAdapter[str]):
adapter_type = "image"
adapter_name = "new_provider"
async def execute(self, prompt: str, **kwargs) -> str:
# 实现生成逻辑
pass
async def health_check(self) -> bool:
# 实现健康检查
pass
@property
def estimated_cost(self) -> float:
return 0.01 # USD
```
2. 在 `__init__.py` 中导入:
```python
# app/services/adapters/__init__.py
from app.services.adapters.image import new_provider as _new_provider # noqa: F401
```
3. 添加配置:
```python
# app/core/config.py
new_provider_api_key: str = ""
# app/services/provider_router.py
API_KEY_MAP["new_provider"] = "new_provider_api_key"
```
4. 更新 `.env.example`:
```
NEW_PROVIDER_API_KEY=
```

View File

@@ -1,109 +0,0 @@
# DreamWeaver 重构实施计划
## 1. 概述
本文档基于对当前架构的深入分析,制定了从稳定性、可维护性到可扩展性的分阶段重构计划。
**目标**
- **短期**:解决单点故障风险,优化开发体验,清理关键技术债。
- **中期**:提升系统高可用能力,增强监控与可观测性。
- **长期**:架构演进,支持大规模并发与复杂业务场景。
---
## 2. 短期优化计划 (1-2周)
**重点**:消除即时风险,提升部署效率。
### 2.1 统一镜像构建 (High Priority)
目前 `backend`, `backend-admin`, `worker`, `celery-beat` 重复构建 4 次,浪费资源且镜像版本可能不一致。
- **Action Items**:
- [x] 修改 `backend/Dockerfile` 为通用基础镜像。
- [x] 更新 `docker-compose.yml`,定义 `backend-base` 服务或使用 `image` 标签共享镜像。
- [x] 确保所有 Python 服务共用同一构建产物,仅启动命令不同。
### 2.2 修复 Provider 缓存与限流 (High Priority)
内存缓存 (`TTLCache`, `_latency_cache`) 在多进程/多实例下失效。
- **Action Items**:
- [x] 引入 Redis 作为共享缓存后端。
- [x] 重构 `_load_provider_cache`,将 Provider 配置缓存至 Redis。
- [x] 重构 `stories.py` 中的限流逻辑,使用 `redis-cell` 或简单的 Redis 计数器替代 `TTLCache`
### 2.3 拆分 `stories.py` (Medium Priority)
`app/api/stories.py` 超过 600 行,包含 API 定义、业务逻辑、验证逻辑,维护困难。
- **Action Items**:
- [x] 创建 `app/services/story_service.py`迁移生成、润色、PDF生成等核心逻辑。
- [x] 创建 `app/schemas/story_schemas.py`,迁移 Pydantic 模型(`GenerateRequest`, `StoryResponse` 等)。
- [x] API 层 `stories.py` 仅保留路由定义和依赖注入,调用 Service 层。
---
## 3. 中期优化计划 (1-2月)
**重点**:高可用 (HA) 与系统韧性。
### 3.1 数据库高可用 (Critical)
当前 PostgreSQL 为单点,且 Admin/User 混合使用。
- **Action Items**:
- [ ] 部署 PostgreSQL 主从复制 (Master-Slave)。
- [ ] 配置 `PgBouncer` 或 SQLAlchemy 读写分离,减轻主库压力。
- [ ] 实施数据库自动备份策略 (如 `pg_dump` 定时上传 S3)。
### 3.2 消息队列高可用 (Critical)
Redis 单点故障将导致 Celery 任务全盘停摆。
- **Action Items**:
- [ ] 迁移至 Redis Sentinel 或 Redis Cluster 模式。
- [ ] 更新 Celery 配置以支持 Sentinel/Cluster 连接串。
### 3.3 增强可观测性 (Important)
目前仅有简单的日志,缺乏系统级指标。
- **Action Items**:
- [ ] 集成 Prometheus Client暴露 `/metrics` 端点。
- [ ] 部署 Grafana + Prometheus监控 API 延迟、QPS、Celery 队列积压情况。
- [ ] 完善 `ProviderMetrics`,增加可视化大盘,实时监控 AI 供应商的成本与成功率。
### 3.4 Phase 3 最小可执行任务清单 (MVP)
目标:在不大改业务代码的前提下,于一个迭代内完成高可用基础设施闭环。
- [x] PostgreSQL 主从:新增 `docker-compose.ha.yml`,包含 1 主 1 从与健康检查。
- [x] PostgreSQL 备份:新增每日备份任务(`pg_dump`)与 7 天保留策略。
- [x] Redis Sentinel新增 1 主 2 哨兵最小拓扑,并验证故障切换。
- [x] Celery 连接:更新 Celery broker/result backend 配置,支持 Sentinel 连接串。
- [x] 回归验证:执行一次故事生成 + 异步任务链路worker/beat冒烟测试。
- [x] 运行手册补充故障切换与恢复步骤文档PostgreSQL/Redis/Celery
---
## 4. 长期架构演进 (季度规划)
**重点**:业务解耦与规模化。
### 4.1 统一 API 网关
- **当前**前端直连后端端口CORS 配置分散。
- **演进**:引入 Traefik 或 Nginx 作为统一网关管理路由、SSL、全局限流、统一鉴权。
### 4.2 前端工程合并
- **当前**User App 和 Admin Console 是完全独立的两个项目,但在组件和工具链上高度重复。
- **演进**:使用一种 Monorepo 策略或基于路由的单一应用策略,共享组件库和类型定义,减少维护成本。
### 4.3 事件驱动架构完善
- **当前**:部分业务逻辑耦合在 API 中。
- **演进**:扩展事件总线,将“阅读记录”、“成就解锁”、“通知推送”等非核心链路完全异步化,通过 Domain Events 解耦。
---
## 5. 实施路线图
| 阶段 | 时间估算 | 关键里程碑 |
| :--- | :--- | :--- |
| **Phase 1: 基础夯实** | Week 1-2 | Docker 构建优化上线Redis 替代内存缓存。 |
| **Phase 2: 代码重构** | Week 3-4 | `stories.py` 拆分完成Service 层建立。 |
| **Phase 3: 高可用建设** | Month 2 | 数据库与 Redis 实现主备/集群模式。 |
| **Phase 4: 监控体系** | Month 2 | Grafana 监控大盘上线,关键指标报警配置完毕。 |

View File

@@ -1,52 +0,0 @@
# `stories.py` 拆分分析 (Phase 2 准备)
## 当前职责
`app/api/stories.py` (591 行) 承担了以下职责:
| 职责 | 行数 | 描述 |
|---|---|---|
| Pydantic 模型 | ~50 行 | `GenerateRequest`, `StoryResponse`, `FullStoryResponse` 等 |
| 验证逻辑 | ~40 行 | `_validate_profile_and_universe` |
| 路由 + 业务 | ~300 行 | `generate_story`, `generate_story_full`, `generate_story_stream` |
| 绘本逻辑 | ~170 行 | `generate_storybook_api` (含并行图片生成) |
| 成就查询 | ~30 行 | `get_story_achievements` |
## 缺失端点
测试中引用但 **未实现** 的端点(这些应在拆分时一并补充):
- `GET /api/stories` — 故事列表 (分页)
- `GET /api/stories/{id}` — 故事详情
- `DELETE /api/stories/{id}` — 故事删除
- `POST /api/image/generate/{id}` — 封面图片生成
- `GET /api/audio/{id}` — 语音朗读
## 建议拆分结构
```
app/
├── schemas/
│ └── story_schemas.py # [NEW] Pydantic 模型
├── services/
│ └── story_service.py # [NEW] 核心业务逻辑
└── api/
├── stories.py # [SLIM] 路由定义 + 依赖注入
└── stories_storybook.py # [NEW] 绘本相关端点 (可选)
```
### `story_schemas.py`
- 迁移所有 Pydantic 模型
- 包括 `GenerateRequest`, `StoryResponse`, `FullStoryResponse`, `StorybookRequest`, `StorybookResponse`
### `story_service.py`
- `validate_profile_and_universe()` — 验证逻辑
- `create_story()` — 故事入库
- `generate_and_save_story()` — 生成 + 保存联合操作
- `generate_storybook_with_images()` — 绘本并行图片生成
- 补充: `list_stories()`, `get_story()`, `delete_story()`
### `stories.py` (瘦路由层)
- 仅保留 `@router` 装饰器和依赖注入
- 调用 service 层完成业务逻辑
- 预计 150-200 行

View File

@@ -1,27 +1,27 @@
import asyncio
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.db.database import async_engine
from sqlalchemy import text


async def upgrade_db():
    """Ensure the providers.config_json column exists (idempotent, safe to re-run)."""
    print("🚀 Checking database schema...")
    async with async_engine.begin() as conn:
        # information_schema lookup tells us whether the column already exists.
        # NOTE(review): not filtered by table_schema — assumes a single schema; confirm.
        result = await conn.execute(text(
            "SELECT column_name FROM information_schema.columns WHERE table_name='providers' AND column_name='config_json';"
        ))
        if result.scalar():
            print("✅ Column 'config_json' already exists.")
        else:
            print("⚠️ Column 'config_json' missing. Adding it now...")
            await conn.execute(text("ALTER TABLE providers ADD COLUMN config_json JSON;"))
            print("✅ Column 'config_json' added successfully.")
    # Release pooled connections so the one-shot script shuts down cleanly.
    await async_engine.dispose()


if __name__ == "__main__":
    # Windows' default Proactor event loop is incompatible with some async DB drivers.
    if sys.platform == 'win32':
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    asyncio.run(upgrade_db())
import asyncio
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.db.database import async_engine
from sqlalchemy import text


async def upgrade_db():
    """Ensure the providers.config_json column exists (idempotent, safe to re-run)."""
    print("🚀 Checking database schema...")
    async with async_engine.begin() as conn:
        # information_schema lookup tells us whether the column already exists.
        # NOTE(review): not filtered by table_schema — assumes a single schema; confirm.
        result = await conn.execute(text(
            "SELECT column_name FROM information_schema.columns WHERE table_name='providers' AND column_name='config_json';"
        ))
        if result.scalar():
            print("✅ Column 'config_json' already exists.")
        else:
            print("⚠️ Column 'config_json' missing. Adding it now...")
            await conn.execute(text("ALTER TABLE providers ADD COLUMN config_json JSON;"))
            print("✅ Column 'config_json' added successfully.")
    # Release pooled connections so the one-shot script shuts down cleanly.
    await async_engine.dispose()


if __name__ == "__main__":
    # Windows' default Proactor event loop is incompatible with some async DB drivers.
    if sys.platform == 'win32':
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    asyncio.run(upgrade_db())

View File

@@ -1,29 +1,29 @@
import asyncio
import os
import sys

# Add backend to path
sys.path.append(os.path.join(os.getcwd()))

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession

from app.core.config import settings


async def add_column():
    """Idempotently add the providers.config_json JSONB column (PostgreSQL-only DDL)."""
    engine = create_async_engine(settings.database_url)
    async_session = async_sessionmaker(engine, expire_on_commit=False)
    try:
        async with async_session() as session:
            try:
                print("Adding config_json column to providers table...")
                # IF NOT EXISTS makes re-running the script a no-op.
                await session.execute(text("ALTER TABLE providers ADD COLUMN IF NOT EXISTS config_json JSONB DEFAULT '{}'::jsonb"))
                await session.commit()
                print("Successfully added config_json column.")
            except Exception as e:
                print(f"Error adding column: {e}")
                await session.rollback()
    finally:
        # Dispose the ad-hoc engine so the event loop can shut down cleanly.
        await engine.dispose()


if __name__ == "__main__":
    # Windows' default Proactor event loop is incompatible with some async DB drivers.
    if sys.platform == "win32":
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    asyncio.run(add_column())
import asyncio
import os
import sys

# Add backend to path
sys.path.append(os.path.join(os.getcwd()))

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession

from app.core.config import settings


async def add_column():
    """Idempotently add the providers.config_json JSONB column (PostgreSQL-only DDL)."""
    engine = create_async_engine(settings.database_url)
    async_session = async_sessionmaker(engine, expire_on_commit=False)
    try:
        async with async_session() as session:
            try:
                print("Adding config_json column to providers table...")
                # IF NOT EXISTS makes re-running the script a no-op.
                await session.execute(text("ALTER TABLE providers ADD COLUMN IF NOT EXISTS config_json JSONB DEFAULT '{}'::jsonb"))
                await session.commit()
                print("Successfully added config_json column.")
            except Exception as e:
                print(f"Error adding column: {e}")
                await session.rollback()
    finally:
        # Dispose the ad-hoc engine so the event loop can shut down cleanly.
        await engine.dispose()


if __name__ == "__main__":
    # Windows' default Proactor event loop is incompatible with some async DB drivers.
    if sys.platform == "win32":
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    asyncio.run(add_column())

View File

@@ -1,21 +1,21 @@
import asyncio
import sys
import os

# Add backend to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.db.database import init_db
from app.core.logging import setup_logging


async def main():
    """Initialize the database schema, exiting non-zero on failure."""
    setup_logging()
    print("Initializing database...")
    try:
        await init_db()
    except Exception as e:
        print(f"Error initializing database: {e}")
        # Propagate failure to the shell so CI/deploy scripts can detect it;
        # the original swallowed the error and exited 0.
        sys.exit(1)
    else:
        print("Database initialized successfully.")


if __name__ == "__main__":
    asyncio.run(main())
import asyncio
import sys
import os

# Add backend to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.db.database import init_db
from app.core.logging import setup_logging


async def main():
    """Initialize the database schema, exiting non-zero on failure."""
    setup_logging()
    print("Initializing database...")
    try:
        await init_db()
    except Exception as e:
        print(f"Error initializing database: {e}")
        # Propagate failure to the shell so CI/deploy scripts can detect it;
        # the original swallowed the error and exited 0.
        sys.exit(1)
    else:
        print("Database initialized successfully.")


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -1 +1 @@
# Tests package
# Tests package

View File

@@ -1,65 +1,65 @@
"""认证相关测试。"""
import pytest
from fastapi.testclient import TestClient
from app.core.security import create_access_token, decode_access_token
class TestJWT:
    """JWT access-token round-trip and rejection tests."""

    def test_create_and_decode_token(self):
        """A freshly issued token decodes back to its original claims."""
        claims = {"sub": "github:12345"}
        issued = create_access_token(claims)
        restored = decode_access_token(issued)
        assert restored is not None
        assert restored["sub"] == "github:12345"

    def test_decode_invalid_token(self):
        """A malformed token decodes to None instead of raising."""
        assert decode_access_token("invalid-token") is None

    def test_decode_empty_token(self):
        """An empty token string decodes to None."""
        assert decode_access_token("") is None
class TestSession:
    """Tests for the /auth/session endpoint."""

    def test_session_without_auth(self, client: TestClient):
        """Fetching the session while logged out returns a null user."""
        resp = client.get("/auth/session")
        assert resp.status_code == 200
        assert resp.json()["user"] is None

    def test_session_with_auth(self, auth_client: TestClient, test_user):
        """Fetching the session while logged in returns the user's profile."""
        resp = auth_client.get("/auth/session")
        assert resp.status_code == 200
        user = resp.json()["user"]
        assert user is not None
        assert user["id"] == test_user.id
        assert user["name"] == test_user.name

    def test_session_with_invalid_token(self, client: TestClient):
        """A bogus access token cookie is treated the same as no login."""
        client.cookies.set("access_token", "invalid-token")
        resp = client.get("/auth/session")
        assert resp.status_code == 200
        assert resp.json()["user"] is None
class TestSignout:
    """Sign-out endpoint tests."""

    def test_signout(self, auth_client: TestClient):
        """Signing out responds with a 302 redirect (not followed)."""
        resp = auth_client.post("/auth/signout", follow_redirects=False)
        assert resp.status_code == 302
"""认证相关测试。"""
import pytest
from fastapi.testclient import TestClient
from app.core.security import create_access_token, decode_access_token
class TestJWT:
    """JWT access-token round-trip and rejection tests."""

    def test_create_and_decode_token(self):
        """A freshly issued token decodes back to its original claims."""
        claims = {"sub": "github:12345"}
        issued = create_access_token(claims)
        restored = decode_access_token(issued)
        assert restored is not None
        assert restored["sub"] == "github:12345"

    def test_decode_invalid_token(self):
        """A malformed token decodes to None instead of raising."""
        assert decode_access_token("invalid-token") is None

    def test_decode_empty_token(self):
        """An empty token string decodes to None."""
        assert decode_access_token("") is None
class TestSession:
    """Tests for the /auth/session endpoint."""

    def test_session_without_auth(self, client: TestClient):
        """Fetching the session while logged out returns a null user."""
        resp = client.get("/auth/session")
        assert resp.status_code == 200
        assert resp.json()["user"] is None

    def test_session_with_auth(self, auth_client: TestClient, test_user):
        """Fetching the session while logged in returns the user's profile."""
        resp = auth_client.get("/auth/session")
        assert resp.status_code == 200
        user = resp.json()["user"]
        assert user is not None
        assert user["id"] == test_user.id
        assert user["name"] == test_user.name

    def test_session_with_invalid_token(self, client: TestClient):
        """A bogus access token cookie is treated the same as no login."""
        client.cookies.set("access_token", "invalid-token")
        resp = client.get("/auth/session")
        assert resp.status_code == 200
        assert resp.json()["user"] is None
class TestSignout:
    """Sign-out endpoint tests."""

    def test_signout(self, auth_client: TestClient):
        """Signing out responds with a 302 redirect (not followed)."""
        resp = auth_client.post("/auth/signout", follow_redirects=False)
        assert resp.status_code == 302

View File

@@ -1,195 +1,195 @@
"""Provider router 测试 - failover 和配置加载。"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.adapters import AdapterConfig
from app.services.adapters.text.models import StoryOutput
class TestProviderFailover:
    """Provider failover tests."""

    @pytest.mark.asyncio
    async def test_failover_to_second_provider(self):
        """When the first provider errors, the router retries with the second one."""
        from app.services import provider_router

        # Mock two providers — set every attribute explicitly (spec=False) so each
        # field the router reads (priority, adapter, credentials, ...) is present.
        mock_provider_1 = MagicMock()
        mock_provider_1.configure_mock(
            id="provider-1",
            type="text",
            adapter="text_primary",
            api_key="key1",
            api_base=None,
            model=None,
            timeout_ms=60000,
            max_retries=3,
            config_ref=None,
            config_json={},
            priority=10,
            weight=1.0,
            enabled=True,
        )
        mock_provider_2 = MagicMock()
        mock_provider_2.configure_mock(
            id="provider-2",
            type="text",
            adapter="text_primary",
            api_key="key2",
            api_base=None,
            model=None,
            timeout_ms=60000,
            max_retries=3,
            config_ref=None,
            config_json={},
            priority=5,
            weight=1.0,
            enabled=True,
        )
        mock_providers = [mock_provider_1, mock_provider_2]
        mock_result = StoryOutput(
            mode="generated",
            title="测试故事",
            story_text="内容",
            cover_prompt_suggestion="prompt",
        )
        call_count = 0

        # Fail only on the first invocation, forcing exactly one failover.
        async def mock_execute(**kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("First provider failed")
            return mock_result

        with patch.object(provider_router, "get_providers", return_value=mock_providers):
            with patch("app.services.adapters.AdapterRegistry.get") as mock_get:
                mock_adapter_class = MagicMock()
                mock_adapter_instance = MagicMock()
                mock_adapter_instance.execute = mock_execute
                mock_adapter_class.return_value = mock_adapter_instance
                mock_get.return_value = mock_adapter_class
                result = await provider_router.generate_story_content(
                    input_type="keywords",
                    data="测试",
                )
        assert result == mock_result
        assert call_count == 2  # first attempt failed, second succeeded

    @pytest.mark.asyncio
    async def test_all_providers_fail(self):
        """When every provider fails, the router raises a ValueError."""
        from app.services import provider_router

        mock_provider = MagicMock()
        mock_provider.configure_mock(
            id="provider-1",
            type="text",
            adapter="text_primary",
            api_key="key1",
            api_base=None,
            model=None,
            timeout_ms=60000,
            max_retries=3,
            config_ref=None,
            config_json={},
            priority=10,
            weight=1.0,
            enabled=True,
        )
        mock_providers = [mock_provider]

        # Every call fails, exhausting the (single-entry) provider list.
        async def mock_execute(**kwargs):
            raise Exception("Provider failed")

        with patch.object(provider_router, "get_providers", return_value=mock_providers):
            with patch("app.services.adapters.AdapterRegistry.get") as mock_get:
                mock_adapter_class = MagicMock()
                mock_adapter_instance = MagicMock()
                mock_adapter_instance.execute = mock_execute
                mock_adapter_class.return_value = mock_adapter_instance
                mock_get.return_value = mock_adapter_class
                with pytest.raises(ValueError, match="No text provider succeeded"):
                    await provider_router.generate_story_content(
                        input_type="keywords",
                        data="测试",
                    )
class TestProviderConfigFromDB:
    """Tests for building adapter configs from DB-backed provider rows."""

    def test_build_config_from_provider_with_api_key(self):
        """A provider row that carries its own api_key takes precedence over settings."""
        from app.services.provider_router import _build_config_from_provider

        mock_provider = MagicMock()
        mock_provider.adapter = "text_primary"
        mock_provider.api_key = "db-api-key"
        mock_provider.api_base = "https://custom.api.com"
        mock_provider.model = "custom-model"
        mock_provider.timeout_ms = 30000
        mock_provider.max_retries = 5
        mock_provider.config_ref = None
        mock_provider.config_json = {}
        config = _build_config_from_provider(mock_provider)
        # Every DB-supplied field should be carried through verbatim.
        assert config.api_key == "db-api-key"
        assert config.api_base == "https://custom.api.com"
        assert config.model == "custom-model"
        assert config.timeout_ms == 30000
        assert config.max_retries == 5

    def test_build_config_fallback_to_settings(self):
        """A provider row without an api_key falls back to settings via config_ref."""
        from app.services.provider_router import _build_config_from_provider

        mock_provider = MagicMock()
        mock_provider.adapter = "text_primary"
        mock_provider.api_key = None
        mock_provider.api_base = None
        mock_provider.model = None
        mock_provider.timeout_ms = None
        mock_provider.max_retries = None
        mock_provider.config_ref = "text_api_key"
        mock_provider.config_json = {}
        with patch("app.services.provider_router.settings") as mock_settings:
            mock_settings.text_api_key = "settings-api-key"
            mock_settings.text_model = "gemini-2.0-flash"
            config = _build_config_from_provider(mock_provider)
            # config_ref names the settings attribute to read the key from.
            assert config.api_key == "settings-api-key"
class TestProviderCacheStartup:
    """Provider cache startup-loading tests."""

    @pytest.mark.asyncio
    async def test_cache_loaded_on_startup(self):
        """The provider cache is reloaded exactly once during application startup."""
        from app.main import _load_provider_cache

        with patch("app.db.database._get_session_factory") as mock_factory:
            # Make the patched session factory usable as an async context manager.
            mock_session = AsyncMock()
            mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
            mock_factory.return_value.__aexit__ = AsyncMock()
            with patch("app.services.provider_cache.reload_providers", new_callable=AsyncMock) as mock_reload:
                mock_reload.return_value = {"text": [], "image": [], "tts": []}
                await _load_provider_cache()
                mock_reload.assert_called_once()
"""Provider router 测试 - failover 和配置加载。"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.adapters import AdapterConfig
from app.services.adapters.text.models import StoryOutput
class TestProviderFailover:
    """Provider failover tests."""

    @pytest.mark.asyncio
    async def test_failover_to_second_provider(self):
        """When the first provider errors, the router retries with the second one."""
        from app.services import provider_router

        # Mock two providers — set every attribute explicitly (spec=False) so each
        # field the router reads (priority, adapter, credentials, ...) is present.
        mock_provider_1 = MagicMock()
        mock_provider_1.configure_mock(
            id="provider-1",
            type="text",
            adapter="text_primary",
            api_key="key1",
            api_base=None,
            model=None,
            timeout_ms=60000,
            max_retries=3,
            config_ref=None,
            config_json={},
            priority=10,
            weight=1.0,
            enabled=True,
        )
        mock_provider_2 = MagicMock()
        mock_provider_2.configure_mock(
            id="provider-2",
            type="text",
            adapter="text_primary",
            api_key="key2",
            api_base=None,
            model=None,
            timeout_ms=60000,
            max_retries=3,
            config_ref=None,
            config_json={},
            priority=5,
            weight=1.0,
            enabled=True,
        )
        mock_providers = [mock_provider_1, mock_provider_2]
        mock_result = StoryOutput(
            mode="generated",
            title="测试故事",
            story_text="内容",
            cover_prompt_suggestion="prompt",
        )
        call_count = 0

        # Fail only on the first invocation, forcing exactly one failover.
        async def mock_execute(**kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("First provider failed")
            return mock_result

        with patch.object(provider_router, "get_providers", return_value=mock_providers):
            with patch("app.services.adapters.AdapterRegistry.get") as mock_get:
                mock_adapter_class = MagicMock()
                mock_adapter_instance = MagicMock()
                mock_adapter_instance.execute = mock_execute
                mock_adapter_class.return_value = mock_adapter_instance
                mock_get.return_value = mock_adapter_class
                result = await provider_router.generate_story_content(
                    input_type="keywords",
                    data="测试",
                )
        assert result == mock_result
        assert call_count == 2  # first attempt failed, second succeeded

    @pytest.mark.asyncio
    async def test_all_providers_fail(self):
        """When every provider fails, the router raises a ValueError."""
        from app.services import provider_router

        mock_provider = MagicMock()
        mock_provider.configure_mock(
            id="provider-1",
            type="text",
            adapter="text_primary",
            api_key="key1",
            api_base=None,
            model=None,
            timeout_ms=60000,
            max_retries=3,
            config_ref=None,
            config_json={},
            priority=10,
            weight=1.0,
            enabled=True,
        )
        mock_providers = [mock_provider]

        # Every call fails, exhausting the (single-entry) provider list.
        async def mock_execute(**kwargs):
            raise Exception("Provider failed")

        with patch.object(provider_router, "get_providers", return_value=mock_providers):
            with patch("app.services.adapters.AdapterRegistry.get") as mock_get:
                mock_adapter_class = MagicMock()
                mock_adapter_instance = MagicMock()
                mock_adapter_instance.execute = mock_execute
                mock_adapter_class.return_value = mock_adapter_instance
                mock_get.return_value = mock_adapter_class
                with pytest.raises(ValueError, match="No text provider succeeded"):
                    await provider_router.generate_story_content(
                        input_type="keywords",
                        data="测试",
                    )
class TestProviderConfigFromDB:
    """Tests for building adapter configs from DB-backed provider rows."""

    def test_build_config_from_provider_with_api_key(self):
        """A provider row that carries its own api_key takes precedence over settings."""
        from app.services.provider_router import _build_config_from_provider

        mock_provider = MagicMock()
        mock_provider.adapter = "text_primary"
        mock_provider.api_key = "db-api-key"
        mock_provider.api_base = "https://custom.api.com"
        mock_provider.model = "custom-model"
        mock_provider.timeout_ms = 30000
        mock_provider.max_retries = 5
        mock_provider.config_ref = None
        mock_provider.config_json = {}
        config = _build_config_from_provider(mock_provider)
        # Every DB-supplied field should be carried through verbatim.
        assert config.api_key == "db-api-key"
        assert config.api_base == "https://custom.api.com"
        assert config.model == "custom-model"
        assert config.timeout_ms == 30000
        assert config.max_retries == 5

    def test_build_config_fallback_to_settings(self):
        """A provider row without an api_key falls back to settings via config_ref."""
        from app.services.provider_router import _build_config_from_provider

        mock_provider = MagicMock()
        mock_provider.adapter = "text_primary"
        mock_provider.api_key = None
        mock_provider.api_base = None
        mock_provider.model = None
        mock_provider.timeout_ms = None
        mock_provider.max_retries = None
        mock_provider.config_ref = "text_api_key"
        mock_provider.config_json = {}
        with patch("app.services.provider_router.settings") as mock_settings:
            mock_settings.text_api_key = "settings-api-key"
            mock_settings.text_model = "gemini-2.0-flash"
            config = _build_config_from_provider(mock_provider)
            # config_ref names the settings attribute to read the key from.
            assert config.api_key == "settings-api-key"
class TestProviderCacheStartup:
    """Provider cache startup-loading tests."""

    @pytest.mark.asyncio
    async def test_cache_loaded_on_startup(self):
        """The provider cache is reloaded exactly once during application startup."""
        from app.main import _load_provider_cache

        with patch("app.db.database._get_session_factory") as mock_factory:
            # Make the patched session factory usable as an async context manager.
            mock_session = AsyncMock()
            mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
            mock_factory.return_value.__aexit__ = AsyncMock()
            with patch("app.services.provider_cache.reload_providers", new_callable=AsyncMock) as mock_reload:
                mock_reload.return_value = {"text": [], "image": [], "tts": []}
                await _load_provider_cache()
                mock_reload.assert_called_once()

View File

@@ -65,4 +65,4 @@ def test_add_achievement(auth_client):
)
assert response.status_code == 200
data = response.json()
assert {"type": "勇气", "description": "克服黑暗"} in data["achievements"]
assert {"type": "勇气", "description": "克服黑暗"} in data["achievements"]