Initial commit: clean project structure

- Backend: FastAPI + SQLAlchemy + Celery (Python 3.11+)
- Frontend: Vue 3 + TypeScript + Pinia + Tailwind
- Admin Frontend: separate Vue 3 app for management
- Docker Compose: 9 services orchestration
- Specs: design prototypes, memory system PRD, product roadmap

Cleanup performed:
- Removed temporary debug scripts from backend root
- Removed deprecated admin_app.py (embedded UI)
- Removed duplicate docs from admin-frontend
- Updated .gitignore for Vite cache and egg-info
This commit is contained in:
zhangtuo
2026-01-20 18:20:03 +08:00
commit e9d7f8832a
241 changed files with 33070 additions and 0 deletions

View File

View File

@@ -0,0 +1,85 @@
"""Achievement extraction service."""
import json
import re
import httpx
from app.core.config import settings
from app.core.logging import get_logger
from app.core.prompts import ACHIEVEMENT_EXTRACTION_PROMPT
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
async def extract_achievements(story_text: str) -> list[dict]:
"""Extract achievements from story text using LLM."""
if not settings.text_api_key:
logger.warning("achievement_extraction_skipped", reason="missing_text_api_key")
return []
model = settings.text_model or "gemini-2.0-flash"
url = f"{TEXT_API_BASE}/{model}:generateContent"
prompt = ACHIEVEMENT_EXTRACTION_PROMPT.format(story_text=story_text)
payload = {
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.2,
"topP": 0.9,
},
}
async with httpx.AsyncClient(timeout=30) as client:
response = await client.post(
url,
json=payload,
headers={"x-goog-api-key": settings.text_api_key},
)
response.raise_for_status()
result = response.json()
candidates = result.get("candidates") or []
if not candidates:
logger.warning("achievement_extraction_empty")
return []
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
logger.warning("achievement_extraction_missing_text")
return []
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError:
logger.warning("achievement_extraction_parse_failed")
return []
achievements = parsed.get("achievements")
if not isinstance(achievements, list):
return []
normalized: list[dict] = []
for item in achievements:
if not isinstance(item, dict):
continue
a_type = str(item.get("type", "")).strip()
description = str(item.get("description", "")).strip()
score = item.get("score", 0)
if not a_type or not description:
continue
normalized.append({
"type": a_type,
"description": description,
"score": score
})
return normalized

View File

@@ -0,0 +1,21 @@
"""适配器模块 - 供应商平台化架构核心。"""
from app.services.adapters.base import AdapterConfig, BaseAdapter
# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.registry import AdapterRegistry
# Storybook adapters
from app.services.adapters.storybook import primary as _storybook_primary # noqa: F401
from app.services.adapters.text import gemini as _text_gemini_adapter # noqa: F401
# 导入所有适配器以触发注册
# Text adapters
from app.services.adapters.text import openai as _text_openai_adapter # noqa: F401
# TTS adapters
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
__all__ = ["AdapterConfig", "BaseAdapter", "AdapterRegistry"]

View File

@@ -0,0 +1,46 @@
"""适配器基类定义。"""
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from pydantic import BaseModel
T = TypeVar("T")
class AdapterConfig(BaseModel):
"""适配器配置基类。"""
api_key: str
api_base: str | None = None
model: str | None = None
timeout_ms: int = 60000
max_retries: int = 3
extra_config: dict = {}
class BaseAdapter(ABC, Generic[T]):
"""适配器基类,所有供应商适配器必须继承此类。"""
# 子类必须定义
adapter_type: str # text / image / tts
adapter_name: str # text_primary / image_primary / tts_primary
def __init__(self, config: AdapterConfig):
self.config = config
@abstractmethod
async def execute(self, **kwargs) -> T:
"""执行适配器逻辑,返回结果。"""
pass
@abstractmethod
async def health_check(self) -> bool:
"""健康检查,返回是否可用。"""
pass
@property
@abstractmethod
def estimated_cost(self) -> float:
"""预估单次调用成本 (USD)。"""
pass

View File

@@ -0,0 +1,3 @@
"""图像生成适配器。"""# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.image import antigravity as _image_antigravity_adapter # noqa: F401

View File

@@ -0,0 +1,214 @@
"""Antigravity 图像生成适配器。
使用 OpenAI 兼容 API 生成图像。
支持 gemini-3-pro-image 等模型。
"""
import base64
import time
from typing import Any
from openai import AsyncOpenAI
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认配置
DEFAULT_API_BASE = "http://127.0.0.1:8045/v1"
DEFAULT_MODEL = "gemini-3-pro-image"
DEFAULT_SIZE = "1024x1024"
# 支持的尺寸映射
SUPPORTED_SIZES = {
"1024x1024": "1:1",
"1280x720": "16:9",
"720x1280": "9:16",
"1216x896": "4:3",
}
@AdapterRegistry.register("image", "antigravity")
class AntigravityImageAdapter(BaseAdapter[str]):
"""Antigravity 图像生成适配器 (OpenAI 兼容 API)。
特点:
- 使用 OpenAI 兼容的 chat.completions 端点
- 通过 extra_body.size 指定图像尺寸
- 支持 gemini-3-pro-image 等模型
- 返回图片 URL 或 base64
"""
adapter_type = "image"
adapter_name = "antigravity"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or DEFAULT_API_BASE
self.client = AsyncOpenAI(
base_url=self.api_base,
api_key=config.api_key,
timeout=config.timeout_ms / 1000,
)
async def execute(
self,
prompt: str,
model: str | None = None,
size: str | None = None,
num_images: int = 1,
**kwargs,
) -> str | list[str]:
"""根据提示词生成图片,返回 URL 或 base64。
Args:
prompt: 图片描述提示词
model: 模型名称 (gemini-3-pro-image / gemini-3-pro-image-16-9 等)
size: 图像尺寸 (1024x1024, 1280x720, 720x1280, 1216x896)
num_images: 生成图片数量 (暂只支持 1)
Returns:
图片 URL 或 base64 字符串
"""
# 优先使用传入参数,其次使用 Adapter 配置,最后使用默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
size = size or cfg.get("size") or DEFAULT_SIZE
start_time = time.time()
logger.info(
"antigravity_generate_start",
prompt_length=len(prompt),
model=model,
size=size,
)
# 调用 API
image_url = await self._generate_image(prompt, model, size)
elapsed = time.time() - start_time
logger.info(
"antigravity_generate_success",
elapsed_seconds=round(elapsed, 2),
model=model,
)
return image_url
async def health_check(self) -> bool:
"""检查 Antigravity API 是否可用。"""
try:
# 简单测试连通性
response = await self.client.chat.completions.create(
model=self.config.model or DEFAULT_MODEL,
messages=[{"role": "user", "content": "test"}],
max_tokens=1,
)
return True
except Exception as e:
logger.warning("antigravity_health_check_failed", error=str(e))
return False
@property
def estimated_cost(self) -> float:
"""预估每张图片成本 (USD)。
Antigravity 使用 Gemini 模型,成本约 $0.02/张。
"""
return 0.02
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((Exception,)),
reraise=True,
)
async def _generate_image(
self,
prompt: str,
model: str,
size: str,
) -> str:
"""调用 Antigravity API 生成图像。
Returns:
图片 URL 或 base64 data URI
"""
try:
response = await self.client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
extra_body={"size": size},
)
# 解析响应
content = response.choices[0].message.content
if not content:
raise ValueError("Antigravity 未返回内容")
# 尝试解析为图片 URL 或 base64
# 响应可能是纯 URL、base64 或 markdown 格式的图片
image_url = self._extract_image_url(content)
if image_url:
return image_url
raise ValueError(f"Antigravity 响应无法解析为图片: {content[:200]}")
except Exception as e:
logger.error(
"antigravity_generate_error",
error=str(e),
model=model,
)
raise
def _extract_image_url(self, content: str) -> str | None:
"""从响应内容中提取图片 URL。
支持多种格式:
- 纯 URL: https://...
- Markdown: ![...](https://...)
- Base64 data URI: data:image/...
- 纯 base64 字符串
"""
content = content.strip()
# 1. 检查是否为 data URI
if content.startswith("data:image/"):
return content
# 2. 检查是否为纯 URL
if content.startswith("http://") or content.startswith("https://"):
# 可能有多行,取第一行
return content.split("\n")[0].strip()
# 3. 检查 Markdown 图片格式 ![...](url)
import re
md_match = re.search(r"!\[.*?\]\((https?://[^\)]+)\)", content)
if md_match:
return md_match.group(1)
# 4. 检查是否像 base64 编码的图片数据
if self._looks_like_base64(content):
# 假设是 PNG
return f"data:image/png;base64,{content}"
return None
def _looks_like_base64(self, s: str) -> bool:
"""判断字符串是否看起来像 base64 编码。"""
# Base64 只包含 A-Z, a-z, 0-9, +, /, =
# 且长度通常较长
if len(s) < 100:
return False
import re
return bool(re.match(r"^[A-Za-z0-9+/=]+$", s.replace("\n", "")))

View File

@@ -0,0 +1,252 @@
"""CQTAI nano 图像生成适配器。
支持异步生成 + 轮询获取结果。
API 文档: https://api.cqtai.com
"""
import asyncio
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认配置
DEFAULT_API_BASE = "https://api.cqtai.com"
DEFAULT_MODEL = "nano-banana"
DEFAULT_RESOLUTION = "2K"
DEFAULT_ASPECT_RATIO = "1:1"
POLL_INTERVAL_SECONDS = 2
MAX_POLL_ATTEMPTS = 60 # 最多轮询 2 分钟
@AdapterRegistry.register("image", "cqtai")
class CQTAIImageAdapter(BaseAdapter[str]):
"""CQTAI nano 图像生成适配器,返回图片 URL。
特点:
- 异步生成 + 轮询获取结果
- 支持 nano-banana (标准) 和 nano-banana-pro (高画质)
- 支持多种分辨率和画面比例
- 支持图生图 (filesUrl)
"""
adapter_type = "image"
adapter_name = "cqtai"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or DEFAULT_API_BASE
async def execute(
self,
prompt: str,
model: str | None = None,
resolution: str | None = None,
aspect_ratio: str | None = None,
num_images: int = 1,
files_url: list[str] | None = None,
**kwargs,
) -> str | list[str]:
"""根据提示词生成图片,返回 URL 或 URL 列表。
Args:
prompt: 图片描述提示词
model: 模型名称 (nano-banana / nano-banana-pro)
resolution: 分辨率 (1K / 2K / 4K)
aspect_ratio: 画面比例 (1:1, 16:9, 9:16, 4:3, 3:4 等)
num_images: 生成图片数量 (1-4)
files_url: 输入图片 URL 列表 (图生图)
Returns:
单张图片返回 str多张返回 list[str]
"""
# 1. 优先使用传入参数
# 2. 其次使用 Adapter 配置里的 default (extra_config)
# 3. 最后使用系统默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
resolution = resolution or cfg.get("resolution") or DEFAULT_RESOLUTION
aspect_ratio = aspect_ratio or cfg.get("aspect_ratio") or DEFAULT_ASPECT_RATIO
num_images = min(max(num_images, 1), 4) # 限制 1-4
start_time = time.time()
logger.info(
"cqtai_generate_start",
prompt_length=len(prompt),
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
num_images=num_images,
)
# 1. 提交生成任务
task_id = await self._submit_task(
prompt=prompt,
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
num_images=num_images,
files_url=files_url or [],
)
logger.info("cqtai_task_submitted", task_id=task_id)
# 2. 轮询获取结果
result = await self._poll_result(task_id)
elapsed = time.time() - start_time
logger.info(
"cqtai_generate_success",
task_id=task_id,
elapsed_seconds=round(elapsed, 2),
image_count=len(result) if isinstance(result, list) else 1,
)
# 单张图片返回字符串,多张返回列表
if num_images == 1 and isinstance(result, list) and len(result) == 1:
return result[0]
return result
async def health_check(self) -> bool:
"""检查 CQTAI API 是否可用。"""
try:
async with httpx.AsyncClient(timeout=10) as client:
# 简单的连通性测试
response = await client.get(
f"{self.api_base}/api/cqt/info/nano",
params={"id": "health_check_test"},
headers={"Authorization": self.config.api_key},
)
# 即使返回错误也说明服务可达
return response.status_code in (200, 400, 401, 403, 404)
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估每张图片成本 (USD)。
nano-banana: ¥0.1 ≈ $0.014
nano-banana-pro: ¥0.2 ≈ $0.028
"""
model = self.config.model or DEFAULT_MODEL
if model == "nano-banana-pro":
return 0.028
return 0.014
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _submit_task(
self,
prompt: str,
model: str,
resolution: str,
aspect_ratio: str,
num_images: int,
files_url: list[str],
) -> str:
"""提交图像生成任务,返回任务 ID。"""
timeout = self.config.timeout_ms / 1000
payload = {
"prompt": prompt,
"numImages": num_images,
"aspectRatio": aspect_ratio,
"filesUrl": files_url,
}
# 可选参数,不传则使用默认值
if model != DEFAULT_MODEL:
payload["model"] = model
if resolution != DEFAULT_RESOLUTION:
payload["resolution"] = resolution
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
f"{self.api_base}/api/cqt/generator/nano",
json=payload,
headers={
"Authorization": self.config.api_key,
"Content-Type": "application/json",
},
)
response.raise_for_status()
data = response.json()
if data.get("code") != 200:
raise ValueError(f"CQTAI 任务提交失败: {data.get('msg', '未知错误')}")
task_id = data.get("data")
if not task_id:
raise ValueError("CQTAI 未返回任务 ID")
return task_id
async def _poll_result(self, task_id: str) -> list[str]:
"""轮询获取生成结果。
Returns:
图片 URL 列表
"""
timeout = self.config.timeout_ms / 1000
for attempt in range(MAX_POLL_ATTEMPTS):
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(
f"{self.api_base}/api/cqt/info/nano",
params={"id": task_id},
headers={"Authorization": self.config.api_key},
)
response.raise_for_status()
data = response.json()
if data.get("code") != 200:
raise ValueError(f"CQTAI 查询失败: {data.get('msg', '未知错误')}")
result_data = data.get("data", {})
status = result_data.get("status")
if status == "completed":
# 提取图片 URL
images = result_data.get("images", [])
if not images:
# 兼容不同返回格式
image_url = result_data.get("imageUrl") or result_data.get("url")
if image_url:
images = [image_url]
if not images:
raise ValueError("CQTAI 未返回图片 URL")
return images
elif status == "failed":
error_msg = result_data.get("error", "生成失败")
raise ValueError(f"CQTAI 图像生成失败: {error_msg}")
# 继续等待
logger.debug(
"cqtai_poll_waiting",
task_id=task_id,
attempt=attempt + 1,
status=status,
)
await asyncio.sleep(POLL_INTERVAL_SECONDS)
raise TimeoutError(f"CQTAI 任务超时: {task_id}")

View File

@@ -0,0 +1,73 @@
"""适配器注册表 - 支持动态注册和工厂创建。"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.services.adapters.base import AdapterConfig, BaseAdapter
class AdapterRegistry:
"""适配器注册表,管理所有已注册的适配器类。"""
_adapters: dict[str, type["BaseAdapter"]] = {}
@classmethod
def register(cls, adapter_type: str, adapter_name: str):
"""装饰器:注册适配器类。
用法:
@AdapterRegistry.register("text", "text_primary")
class TextPrimaryAdapter(BaseAdapter[StoryOutput]):
...
"""
def decorator(adapter_class: type["BaseAdapter"]):
key = f"{adapter_type}:{adapter_name}"
cls._adapters[key] = adapter_class
# 自动设置类属性
adapter_class.adapter_type = adapter_type
adapter_class.adapter_name = adapter_name
return adapter_class
return decorator
@classmethod
def get(cls, adapter_type: str, adapter_name: str) -> type["BaseAdapter"] | None:
"""获取已注册的适配器类。"""
key = f"{adapter_type}:{adapter_name}"
return cls._adapters.get(key)
@classmethod
def list_adapters(cls, adapter_type: str | None = None) -> list[str]:
"""列出所有已注册的适配器。
Args:
adapter_type: 可选,筛选特定类型 (text/image/tts)
Returns:
适配器键列表,格式为 "type:name"
"""
if adapter_type:
return [k for k in cls._adapters if k.startswith(f"{adapter_type}:")]
return list(cls._adapters.keys())
@classmethod
def create(
cls,
adapter_type: str,
adapter_name: str,
config: "AdapterConfig",
) -> "BaseAdapter":
"""工厂方法:创建适配器实例。
Raises:
ValueError: 适配器未注册
"""
adapter_class = cls.get(adapter_type, adapter_name)
if not adapter_class:
available = cls.list_adapters(adapter_type)
raise ValueError(
f"适配器 '{adapter_type}:{adapter_name}' 未注册。"
f"可用: {available}"
)
return adapter_class(config)

View File

@@ -0,0 +1 @@
"""Storybook 适配器模块。"""

View File

@@ -0,0 +1,195 @@
"""Storybook 适配器 - 生成可翻页的分页故事书。"""
import json
import random
import re
import time
from dataclasses import dataclass, field
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_STORYBOOK,
USER_PROMPT_STORYBOOK,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@dataclass
class StorybookPage:
"""故事书单页。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
@dataclass
class Storybook:
"""故事书输出。"""
title: str
main_character: str
art_style: str
pages: list[StorybookPage] = field(default_factory=list)
cover_prompt: str = ""
cover_url: str | None = None
@AdapterRegistry.register("storybook", "storybook_primary")
class StorybookPrimaryAdapter(BaseAdapter[Storybook]):
"""Storybook 生成适配器(默认)。
生成分页故事书结构,包含每页文字和图像提示词。
图像生成需要单独调用 image adapter。
"""
adapter_type = "storybook"
adapter_name = "storybook_primary"
async def execute(
self,
keywords: str,
page_count: int = 6,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> Storybook:
"""生成分页故事书。
Args:
keywords: 故事关键词
page_count: 页数 (4-12)
education_theme: 教育主题
memory_context: 记忆上下文
Returns:
Storybook 对象,包含标题、页面列表和封面提示词
"""
start_time = time.time()
page_count = max(4, min(page_count, 12)) # 限制 4-12 页
logger.info(
"storybook_generate_start",
keywords=keywords,
page_count=page_count,
has_memory=bool(memory_context),
)
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
prompt = USER_PROMPT_STORYBOOK.format(
keywords=keywords,
education_theme=theme,
random_element=random_element,
page_count=page_count,
memory_context=memory_context or "",
)
payload = {
"system_instruction": {"parts": [{"text": SYSTEM_INSTRUCTION_STORYBOOK}]},
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.95,
"topP": 0.9,
},
}
result = await self._call_api(payload)
candidates = result.get("candidates") or []
if not candidates:
raise ValueError("Storybook 服务未返回内容")
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
raise ValueError("Storybook 服务响应缺少文本")
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"Storybook JSON 解析失败: {exc}")
# 构建 Storybook 对象
pages = [
StorybookPage(
page_number=p.get("page_number", i + 1),
text=p.get("text", ""),
image_prompt=p.get("image_prompt", ""),
)
for i, p in enumerate(parsed.get("pages", []))
]
storybook = Storybook(
title=parsed.get("title", "未命名故事"),
main_character=parsed.get("main_character", ""),
art_style=parsed.get("art_style", ""),
pages=pages,
cover_prompt=parsed.get("cover_prompt", ""),
)
elapsed = time.time() - start_time
logger.info(
"storybook_generate_success",
elapsed_seconds=round(elapsed, 2),
title=storybook.title,
page_count=len(pages),
)
return storybook
async def health_check(self) -> bool:
"""检查 API 是否可用。"""
try:
payload = {
"contents": [{"parts": [{"text": "Hi"}]}],
"generationConfig": {"maxOutputTokens": 10},
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估成本(仅文本生成,不含图像)。"""
return 0.002 # 比普通故事稍贵,因为输出更长
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, payload: dict) -> dict:
"""调用 API带重试机制。"""
model = self.config.model or "gemini-2.0-flash"
url = f"{TEXT_API_BASE}/{model}:generateContent?key={self.config.api_key}"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
return response.json()

View File

@@ -0,0 +1 @@
"""文本生成适配器。"""

View File

@@ -0,0 +1,164 @@
"""文本生成适配器 (Google Gemini)。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@AdapterRegistry.register("text", "gemini")
class GeminiTextAdapter(BaseAdapter[StoryOutput]):
"""Google Gemini 文本生成适配器。"""
adapter_type = "text"
adapter_name = "gemini"
async def execute(
self,
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> StoryOutput:
"""生成或润色故事。"""
start_time = time.time()
logger.info("request_start", adapter="gemini", input_type=input_type, data_length=len(data))
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
if input_type == "keywords":
system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
prompt = USER_PROMPT_GENERATION.format(
keywords=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
else:
system_instruction = SYSTEM_INSTRUCTION_ENHANCER
prompt = USER_PROMPT_ENHANCEMENT.format(
full_story=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
# Gemini API Payload supports 'system_instruction'
payload = {
"system_instruction": {"parts": [{"text": system_instruction}]},
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.95,
"topP": 0.9,
},
}
result = await self._call_api(payload)
candidates = result.get("candidates") or []
if not candidates:
raise ValueError("Gemini 未返回内容")
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
raise ValueError("Gemini 响应缺少文本")
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"Gemini 输出 JSON 解析失败: {exc}")
required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
if any(field not in parsed for field in required_fields):
raise ValueError("Gemini 输出缺少必要字段")
elapsed = time.time() - start_time
logger.info(
"request_success",
adapter="gemini",
elapsed_seconds=round(elapsed, 2),
title=parsed["title"],
)
return StoryOutput(
mode=parsed["mode"],
title=parsed["title"],
story_text=parsed["story_text"],
cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
)
async def health_check(self) -> bool:
"""检查 Gemini API 是否可用。"""
try:
payload = {
"contents": [{"parts": [{"text": "Hi"}]}],
"generationConfig": {"maxOutputTokens": 10},
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
return 0.001
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, payload: dict) -> dict:
"""调用 Gemini API。"""
model = self.config.model or "gemini-2.0-flash"
base_url = self.config.api_base or TEXT_API_BASE
# 智能补全:
# 1. 如果用户填了完整路径 (以 /models 结尾),就直接用 (支持 v1 或 v1beta)
if self.config.api_base and base_url.rstrip("/").endswith("/models"):
pass
# 2. 如果没填路径 (只是域名),默认补全代码适配的 /v1beta/models
elif self.config.api_base:
base_url = f"{base_url.rstrip('/')}/v1beta/models"
url = f"{base_url}/{model}:generateContent?key={self.config.api_key}"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
return response.json()

View File

@@ -0,0 +1,11 @@
from dataclasses import dataclass
from typing import Literal
@dataclass
class StoryOutput:
"""故事生成输出。"""
mode: Literal["generated", "enhanced"]
title: str
story_text: str
cover_prompt_suggestion: str

View File

@@ -0,0 +1,172 @@
"""OpenAI 文本生成适配器。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
OPENAI_API_BASE = "https://api.openai.com/v1/chat/completions"
@AdapterRegistry.register("text", "openai")
class OpenAITextAdapter(BaseAdapter[StoryOutput]):
"""OpenAI 文本生成适配器。"""
adapter_type = "text"
adapter_name = "openai"
async def execute(
self,
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> StoryOutput:
"""生成或润色故事。"""
start_time = time.time()
logger.info("openai_text_request_start", input_type=input_type, data_length=len(data))
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
if input_type == "keywords":
system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
prompt = USER_PROMPT_GENERATION.format(
keywords=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
else:
system_instruction = SYSTEM_INSTRUCTION_ENHANCER
prompt = USER_PROMPT_ENHANCEMENT.format(
full_story=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
model = self.config.model or "gpt-4o-mini"
payload = {
"model": model,
"messages": [
{
"role": "system",
"content": system_instruction,
},
{"role": "user", "content": prompt},
],
"response_format": {"type": "json_object"},
"temperature": 0.95,
"top_p": 0.9,
}
result = await self._call_api(payload)
choices = result.get("choices") or []
if not choices:
raise ValueError("OpenAI 未返回内容")
response_text = choices[0].get("message", {}).get("content", "")
if not response_text:
raise ValueError("OpenAI 响应缺少文本")
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"OpenAI 输出 JSON 解析失败: {exc}")
required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
if any(field not in parsed for field in required_fields):
raise ValueError("OpenAI 输出缺少必要字段")
elapsed = time.time() - start_time
logger.info(
"openai_text_request_success",
elapsed_seconds=round(elapsed, 2),
title=parsed["title"],
mode=parsed["mode"],
)
return StoryOutput(
mode=parsed["mode"],
title=parsed["title"],
story_text=parsed["story_text"],
cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(httpx.HTTPStatusError),
)
async def _call_api(self, payload: dict) -> dict:
"""调用 OpenAI API带重试机制。"""
url = self.config.api_base or OPENAI_API_BASE
# 智能补全: 如果用户只填了 Base URL自动补全路径
if self.config.api_base and not url.endswith("/chat/completions"):
base = url.rstrip("/")
url = f"{base}/chat/completions"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"Authorization": f"Bearer {self.config.api_key}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
return response.json()
async def health_check(self) -> bool:
"""检查 OpenAI API 是否可用。"""
try:
payload = {
"model": self.config.model or "gpt-4o-mini",
"messages": [{"role": "user", "content": "Hi"}],
"max_tokens": 5,
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估文本生成成本 (USD)。"""
return 0.01

View File

@@ -0,0 +1,5 @@
"""TTS 语音合成适配器。"""
from app.services.adapters.tts import edge_tts as _tts_edge_tts_adapter # noqa: F401
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401

View File

@@ -0,0 +1,66 @@
"""EdgeTTS 免费语音生成适配器。"""
import time
import edge_tts
from app.core.logging import get_logger
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认中文女声 (晓晓)
DEFAULT_VOICE = "zh-CN-XiaoxiaoNeural"
@AdapterRegistry.register("tts", "edge_tts")
class EdgeTTSAdapter(BaseAdapter[bytes]):
"""EdgeTTS 语音生成适配器 (Free)。
不需要 API Key。
"""
adapter_type = "tts"
adapter_name = "edge_tts"
async def execute(self, text: str, **kwargs) -> bytes:
"""生成语音。"""
# 支持动态指定音色
voice = kwargs.get("voice") or self.config.model or DEFAULT_VOICE
start_time = time.time()
logger.info("edge_tts_generate_start", text_length=len(text), voice=voice)
# EdgeTTS 只能输出到文件,我们需要用临时文件周转一下
# 或者直接 capture stream (communicate) 但 edge-tts 库主要面向文件
# 优化: 使用 communicate 直接获取 bytes无需磁盘IO
communicate = edge_tts.Communicate(text, voice)
audio_data = b""
async for chunk in communicate.stream():
if chunk["type"] == "audio":
audio_data += chunk["data"]
elapsed = time.time() - start_time
logger.info(
"edge_tts_generate_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_data),
)
return audio_data
async def health_check(self) -> bool:
"""检查 EdgeTTS 是否可用 (网络连通性)。"""
try:
# 简单生成一个词
await self.execute("Hi")
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
return 0.0 # Free!

View File

@@ -0,0 +1,104 @@
"""ElevenLabs TTS 语音合成适配器。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
ELEVENLABS_API_BASE = "https://api.elevenlabs.io/v1"
DEFAULT_VOICE_ID = "21m00Tcm4TlvDq8ikWAM" # Rachel
@AdapterRegistry.register("tts", "elevenlabs")
class ElevenLabsTtsAdapter(BaseAdapter[bytes]):
"""ElevenLabs TTS 语音合成适配器,返回 MP3 bytes。"""
adapter_type = "tts"
adapter_name = "elevenlabs"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or ELEVENLABS_API_BASE
async def execute(self, text: str, **kwargs) -> bytes:
"""将文本转换为语音 MP3 bytes。"""
start_time = time.time()
logger.info("elevenlabs_tts_start", text_length=len(text))
voice_id = kwargs.get("voice_id") or DEFAULT_VOICE_ID
model_id = kwargs.get("model") or self.config.model or "eleven_multilingual_v2"
stability = kwargs.get("stability", 0.5)
similarity_boost = kwargs.get("similarity_boost", 0.75)
url = f"{self.api_base}/text-to-speech/{voice_id}"
payload = {
"text": text,
"model_id": model_id,
"voice_settings": {
"stability": stability,
"similarity_boost": similarity_boost,
},
}
audio_bytes = await self._call_api(url, payload)
elapsed = time.time() - start_time
logger.info(
"elevenlabs_tts_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_bytes),
)
return audio_bytes
async def health_check(self) -> bool:
"""检查 ElevenLabs API 是否可用。"""
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(
f"{self.api_base}/voices",
headers={"xi-api-key": self.config.api_key},
)
return response.status_code == 200
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估每千字符成本 (USD)。"""
return 0.03
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, url: str, payload: dict) -> bytes:
"""调用 ElevenLabs API带重试机制。"""
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"xi-api-key": self.config.api_key,
"Content-Type": "application/json",
"Accept": "audio/mpeg",
},
)
response.raise_for_status()
return response.content

View File

@@ -0,0 +1,149 @@
"""MiniMax 语音生成适配器 (T2A V2)。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# MiniMax API 配置
DEFAULT_API_URL = "https://api.minimaxi.com/v1/t2a_v2"
DEFAULT_MODEL = "speech-2.6-turbo"
@AdapterRegistry.register("tts", "minimax")
class MiniMaxTTSAdapter(BaseAdapter[bytes]):
"""MiniMax 语音生成适配器。
需要配置:
- api_key: MiniMax API Key
- minimax_group_id: 可选 (取决于使用的模型/账户类型)
"""
adapter_type = "tts"
adapter_name = "minimax"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_url = DEFAULT_API_URL
async def execute(
self,
text: str,
voice_id: str | None = None,
model: str | None = None,
speed: float | None = None,
vol: float | None = None,
pitch: int | None = None,
emotion: str | None = None,
**kwargs,
) -> bytes:
"""生成语音。"""
# 1. 优先使用传入参数
# 2. 其次使用 Adapter 配置里的 default
# 3. 最后使用系统默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
voice_id = voice_id or cfg.get("voice_id") or "male-qn-qingse"
speed = speed if speed is not None else (cfg.get("speed") or 1.0)
vol = vol if vol is not None else (cfg.get("vol") or 1.0)
pitch = pitch if pitch is not None else (cfg.get("pitch") or 0)
emotion = emotion or cfg.get("emotion")
group_id = kwargs.get("group_id") or settings.minimax_group_id
url = self.api_url
if group_id:
url = f"{self.api_url}?GroupId={group_id}"
payload = {
"model": model,
"text": text,
"stream": False,
"voice_setting": {
"voice_id": voice_id,
"speed": speed,
"vol": vol,
"pitch": pitch,
},
"audio_setting": {
"sample_rate": 32000,
"bitrate": 128000,
"format": "mp3",
"channel": 1
}
}
if emotion:
payload["voice_setting"]["emotion"] = emotion
start_time = time.time()
logger.info("minimax_generate_start", text_length=len(text), model=model)
result = await self._call_api(url, payload)
# 错误处理
if result.get("base_resp", {}).get("status_code") != 0:
error_msg = result.get("base_resp", {}).get("status_msg", "未知错误")
raise ValueError(f"MiniMax API 错误: {error_msg}")
# Hex 解码 (关键逻辑,从 primary.py 迁移)
hex_audio = result.get("data", {}).get("audio")
if not hex_audio:
raise ValueError("API 响应中未找到音频数据 (data.audio)")
try:
audio_bytes = bytes.fromhex(hex_audio)
except ValueError:
raise ValueError("MiniMax 返回的音频数据不是有效的 Hex 字符串")
elapsed = time.time() - start_time
logger.info(
"minimax_generate_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_bytes),
)
return audio_bytes
async def health_check(self) -> bool:
"""检查 Minimax API 是否可用。"""
try:
# 尝试生成极短文本
await self.execute("Hi")
return True
except Exception:
return False
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, url: str, payload: dict) -> dict:
"""调用 API带重试机制。"""
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"Authorization": f"Bearer {self.config.api_key}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
return response.json()

View File

@@ -0,0 +1,196 @@
"""成本追踪服务。
记录 API 调用成本,支持预算控制。
"""
from datetime import datetime, timedelta
from decimal import Decimal
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import CostRecord, UserBudget
logger = get_logger(__name__)
class BudgetExceededError(Exception):
"""预算超限错误。"""
def __init__(self, limit_type: str, used: Decimal, limit: Decimal):
self.limit_type = limit_type
self.used = used
self.limit = limit
super().__init__(f"{limit_type} 预算已超限: {used}/{limit} USD")
class CostTracker:
"""成本追踪器。"""
async def record_cost(
self,
db: AsyncSession,
user_id: str,
provider_name: str,
capability: str,
estimated_cost: float,
provider_id: str | None = None,
) -> CostRecord:
"""记录一次 API 调用成本。"""
record = CostRecord(
user_id=user_id,
provider_id=provider_id,
provider_name=provider_name,
capability=capability,
estimated_cost=Decimal(str(estimated_cost)),
)
db.add(record)
await db.commit()
logger.debug(
"cost_recorded",
user_id=user_id,
provider=provider_name,
capability=capability,
cost=estimated_cost,
)
return record
async def get_daily_cost(self, db: AsyncSession, user_id: str) -> Decimal:
"""获取用户今日成本。"""
today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
result = await db.execute(
select(func.sum(CostRecord.estimated_cost)).where(
CostRecord.user_id == user_id,
CostRecord.timestamp >= today_start,
)
)
total = result.scalar()
return Decimal(str(total)) if total else Decimal("0")
async def get_monthly_cost(self, db: AsyncSession, user_id: str) -> Decimal:
"""获取用户本月成本。"""
now = datetime.utcnow()
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
result = await db.execute(
select(func.sum(CostRecord.estimated_cost)).where(
CostRecord.user_id == user_id,
CostRecord.timestamp >= month_start,
)
)
total = result.scalar()
return Decimal(str(total)) if total else Decimal("0")
async def get_cost_by_capability(
self,
db: AsyncSession,
user_id: str,
days: int = 30,
) -> dict[str, Decimal]:
"""按能力类型统计成本。"""
since = datetime.utcnow() - timedelta(days=days)
result = await db.execute(
select(CostRecord.capability, func.sum(CostRecord.estimated_cost))
.where(CostRecord.user_id == user_id, CostRecord.timestamp >= since)
.group_by(CostRecord.capability)
)
return {row[0]: Decimal(str(row[1])) for row in result.all()}
async def check_budget(
self,
db: AsyncSession,
user_id: str,
estimated_cost: float,
) -> bool:
"""检查预算是否允许此次调用。
Returns:
True 如果允许,否则抛出 BudgetExceededError
"""
budget = await self.get_user_budget(db, user_id)
if not budget or not budget.enabled:
return True
# 检查日预算
daily_cost = await self.get_daily_cost(db, user_id)
if daily_cost + Decimal(str(estimated_cost)) > budget.daily_limit_usd:
raise BudgetExceededError("", daily_cost, budget.daily_limit_usd)
# 检查月预算
monthly_cost = await self.get_monthly_cost(db, user_id)
if monthly_cost + Decimal(str(estimated_cost)) > budget.monthly_limit_usd:
raise BudgetExceededError("", monthly_cost, budget.monthly_limit_usd)
return True
async def get_user_budget(self, db: AsyncSession, user_id: str) -> UserBudget | None:
"""获取用户预算配置。"""
result = await db.execute(
select(UserBudget).where(UserBudget.user_id == user_id)
)
return result.scalar_one_or_none()
async def set_user_budget(
self,
db: AsyncSession,
user_id: str,
daily_limit: float | None = None,
monthly_limit: float | None = None,
alert_threshold: float | None = None,
enabled: bool | None = None,
) -> UserBudget:
"""设置用户预算。"""
budget = await self.get_user_budget(db, user_id)
if budget is None:
budget = UserBudget(user_id=user_id)
db.add(budget)
if daily_limit is not None:
budget.daily_limit_usd = Decimal(str(daily_limit))
if monthly_limit is not None:
budget.monthly_limit_usd = Decimal(str(monthly_limit))
if alert_threshold is not None:
budget.alert_threshold = Decimal(str(alert_threshold))
if enabled is not None:
budget.enabled = enabled
await db.commit()
await db.refresh(budget)
return budget
async def get_cost_summary(
self,
db: AsyncSession,
user_id: str,
) -> dict:
"""获取用户成本摘要。"""
daily = await self.get_daily_cost(db, user_id)
monthly = await self.get_monthly_cost(db, user_id)
by_capability = await self.get_cost_by_capability(db, user_id)
budget = await self.get_user_budget(db, user_id)
return {
"daily_cost_usd": float(daily),
"monthly_cost_usd": float(monthly),
"by_capability": {k: float(v) for k, v in by_capability.items()},
"budget": {
"daily_limit_usd": float(budget.daily_limit_usd) if budget else None,
"monthly_limit_usd": float(budget.monthly_limit_usd) if budget else None,
"daily_usage_percent": float(daily / budget.daily_limit_usd * 100)
if budget and budget.daily_limit_usd
else None,
"monthly_usage_percent": float(monthly / budget.monthly_limit_usd * 100)
if budget and budget.monthly_limit_usd
else None,
"enabled": budget.enabled if budget else False,
},
}
# 全局单例
cost_tracker = CostTracker()

View File

@@ -0,0 +1,471 @@
"""Memory service handles memory retrieval, scoring, and prompt injection."""
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models import ChildProfile, MemoryItem, StoryUniverse
logger = get_logger(__name__)
class MemoryType:
"""记忆类型常量及配置。"""
# 基础类型
RECENT_STORY = "recent_story"
FAVORITE_CHARACTER = "favorite_character"
SCARY_ELEMENT = "scary_element"
VOCABULARY_GROWTH = "vocabulary_growth"
EMOTIONAL_HIGHLIGHT = "emotional_highlight"
# Phase 1 新增类型
READING_PREFERENCE = "reading_preference" # 阅读偏好
MILESTONE = "milestone" # 里程碑事件
SKILL_MASTERED = "skill_mastered" # 掌握的技能
# 类型配置: (默认权重, 默认TTL天数, 描述)
CONFIG = {
RECENT_STORY: (1.0, 30, "最近阅读的故事"),
FAVORITE_CHARACTER: (1.5, None, "喜欢的角色"), # None = 永久
SCARY_ELEMENT: (2.0, None, "回避的元素"), # 高权重,永久有效
VOCABULARY_GROWTH: (0.8, 90, "词汇积累"),
EMOTIONAL_HIGHLIGHT: (1.2, 60, "情感高光"),
READING_PREFERENCE: (1.0, None, "阅读偏好"),
MILESTONE: (1.5, None, "里程碑事件"),
SKILL_MASTERED: (1.0, 180, "掌握的技能"),
}
@classmethod
def get_default_weight(cls, memory_type: str) -> float:
"""获取类型的默认权重。"""
config = cls.CONFIG.get(memory_type)
return config[0] if config else 1.0
@classmethod
def get_default_ttl(cls, memory_type: str) -> int | None:
"""获取类型的默认 TTL 天数。"""
config = cls.CONFIG.get(memory_type)
return config[1] if config else None
def _decay_factor(days: float) -> float:
"""计算时间衰减因子。"""
if days <= 7:
return 1.0
if days <= 30:
return 0.7
if days <= 90:
return 0.4
return 0.2
async def build_enhanced_memory_context(
profile_id: str | None,
universe_id: str | None,
db: AsyncSession,
) -> str | None:
"""构建增强版记忆上下文(自然语言 Prompt"""
if not profile_id and not universe_id:
return None
context_parts: list[str] = []
# 1. 基础档案 (Identity Layer)
if profile_id:
profile = await db.scalar(select(ChildProfile).where(ChildProfile.id == profile_id))
if profile:
context_parts.append(f"【目标读者】\n姓名:{profile.name}")
if profile.age:
context_parts.append(f"年龄:{profile.age}")
if profile.interests:
context_parts.append(f"兴趣爱好:{''.join(profile.interests)}")
if profile.growth_themes:
context_parts.append(f"当前成长关注点:{''.join(profile.growth_themes)}")
context_parts.append("") # 空行
# 2. 故事宇宙 (Universe Layer)
if universe_id:
universe = await db.scalar(select(StoryUniverse).where(StoryUniverse.id == universe_id))
if universe:
context_parts.append("【故事宇宙设定】")
context_parts.append(f"世界观:{universe.name}")
# 主角
protagonist = universe.protagonist or {}
p_desc = f"{protagonist.get('name', '主角')} ({protagonist.get('personality', '')})"
context_parts.append(f"主角设定:{p_desc}")
# 常驻角色
if universe.recurring_characters:
chars = [f"{c.get('name')} ({c.get('type')})" for c in universe.recurring_characters if isinstance(c, dict)]
context_parts.append(f"已知伙伴:{''.join(chars)}")
# 成就
if universe.achievements:
badges = [str(a.get('type')) for a in universe.achievements if isinstance(a, dict)]
if badges:
context_parts.append(f"已获荣誉:{''.join(badges[:5])}")
context_parts.append("")
# 3. 动态记忆 (Working Memory)
if profile_id:
memories = await _fetch_scored_memories(profile_id, universe_id, db)
if memories:
memory_text = _format_memories_to_prompt(memories)
if memory_text:
context_parts.append("【关键记忆回忆】(请在故事中自然地融入或致敬以下元素)")
context_parts.append(memory_text)
return "\n".join(context_parts)
async def _fetch_scored_memories(
profile_id: str,
universe_id: str | None,
db: AsyncSession,
limit: int = 8
) -> list[MemoryItem]:
"""获取并评分记忆项,返回 Top N。"""
query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
if universe_id:
query = query.where(
(MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
)
# 取最近 50 条进行评分
query = query.order_by(MemoryItem.last_used_at.desc(), MemoryItem.created_at.desc()).limit(50)
result = await db.execute(query)
items = result.scalars().all()
scored: list[tuple[float, MemoryItem]] = []
now = datetime.now(timezone.utc)
for item in items:
reference = item.last_used_at or item.created_at or now
delta_days = max((now - reference).total_seconds() / 86400, 0)
if item.ttl_days and delta_days > item.ttl_days:
continue
score = (item.base_weight or 1.0) * _decay_factor(delta_days)
if score <= 0.1: # 忽略低权重
continue
scored.append((score, item))
scored.sort(key=lambda x: x[0], reverse=True)
return [item for _, item in scored[:limit]]
def _format_memories_to_prompt(memories: list[MemoryItem]) -> str:
"""将记忆项转换为自然语言指令。"""
lines = []
# 分类处理
recent_stories = []
favorites = []
scary = []
vocab = []
for m in memories:
if m.type == MemoryType.RECENT_STORY:
recent_stories.append(m)
elif m.type == MemoryType.FAVORITE_CHARACTER:
favorites.append(m)
elif m.type == MemoryType.SCARY_ELEMENT:
scary.append(m)
elif m.type == MemoryType.VOCABULARY_GROWTH:
vocab.append(m)
# 1. 喜欢的角色
if favorites:
names = []
for m in favorites:
val = m.value
if isinstance(val, dict):
names.append(f"{val.get('name')} ({val.get('description', '')})")
if names:
lines.append(f"- 孩子特别喜欢这些角色,可以让他们客串出场:{', '.join(names)}")
# 2. 避雷区
if scary:
items = []
for m in scary:
val = m.value
if isinstance(val, dict):
items.append(val.get('keyword', ''))
elif isinstance(val, str):
items.append(val)
if items:
lines.append(f"- 【注意禁止】不要出现以下让孩子害怕的元素:{', '.join(items)}")
# 3. 近期故事 (取最近 2 个)
if recent_stories:
lines.append("- 近期经历(可作为彩蛋提及):")
for m in recent_stories[:2]:
val = m.value
if isinstance(val, dict):
title = val.get('title', '未知故事')
lines.append(f" * 之前读过《{title}")
# 4. 词汇积累
if vocab:
words = []
for m in vocab:
val = m.value
if isinstance(val, dict):
words.append(val.get('word'))
if words:
lines.append(f"- 已掌握词汇(可适当复现以巩固):{', '.join([w for w in words if w])}")
return "\n".join(lines)
async def prune_expired_memories(db: AsyncSession) -> int:
"""清理过期的记忆项。
Returns:
删除的记录数量
"""
from sqlalchemy import delete
now = datetime.now(timezone.utc)
# 查找所有设置了 TTL 的项目
stmt = select(MemoryItem).where(MemoryItem.ttl_days.is_not(None))
result = await db.execute(stmt)
candidates = result.scalars().all()
to_delete_ids = []
for item in candidates:
if not item.ttl_days:
continue
reference = item.last_used_at or item.created_at or now
delta_days = (now - reference).total_seconds() / 86400
if delta_days > item.ttl_days:
to_delete_ids.append(item.id)
if not to_delete_ids:
return 0
delete_stmt = delete(MemoryItem).where(MemoryItem.id.in_(to_delete_ids))
await db.execute(delete_stmt)
await db.commit()
logger.info("memory_pruned", count=len(to_delete_ids))
return len(to_delete_ids)
async def create_memory(
db: AsyncSession,
profile_id: str,
memory_type: str,
value: dict,
universe_id: str | None = None,
weight: float | None = None,
ttl_days: int | None = None,
) -> MemoryItem:
"""创建新的记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
memory_type: 记忆类型 (使用 MemoryType 常量)
value: 记忆内容 (JSON 格式)
universe_id: 可选,关联的故事宇宙 ID
weight: 可选,权重 (默认使用类型配置)
ttl_days: 可选,过期天数 (默认使用类型配置)
Returns:
创建的 MemoryItem
"""
memory = MemoryItem(
child_profile_id=profile_id,
universe_id=universe_id,
type=memory_type,
value=value,
base_weight=weight or MemoryType.get_default_weight(memory_type),
ttl_days=ttl_days if ttl_days is not None else MemoryType.get_default_ttl(memory_type),
)
db.add(memory)
await db.commit()
await db.refresh(memory)
logger.info(
"memory_created",
memory_id=memory.id,
profile_id=profile_id,
type=memory_type,
)
return memory
async def update_memory_usage(db: AsyncSession, memory_id: str) -> None:
"""更新记忆的最后使用时间。
Args:
db: 数据库会话
memory_id: 记忆项 ID
"""
result = await db.execute(select(MemoryItem).where(MemoryItem.id == memory_id))
memory = result.scalar_one_or_none()
if memory:
memory.last_used_at = datetime.now(timezone.utc)
await db.commit()
logger.debug("memory_usage_updated", memory_id=memory_id)
async def get_profile_memories(
db: AsyncSession,
profile_id: str,
memory_type: str | None = None,
universe_id: str | None = None,
limit: int = 50,
) -> list[MemoryItem]:
"""获取档案的记忆列表。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
memory_type: 可选,按类型筛选
universe_id: 可选,按宇宙筛选
limit: 返回数量限制
Returns:
MemoryItem 列表
"""
query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
if memory_type:
query = query.where(MemoryItem.type == memory_type)
if universe_id:
query = query.where(
(MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
)
query = query.order_by(MemoryItem.created_at.desc()).limit(limit)
result = await db.execute(query)
return list(result.scalars().all())
async def create_story_memory(
db: AsyncSession,
profile_id: str,
story_id: int,
title: str,
summary: str | None = None,
keywords: list[str] | None = None,
universe_id: str | None = None,
) -> MemoryItem:
"""为故事创建记忆项。
这是一个便捷函数,专门用于在故事阅读后创建 recent_story 类型的记忆。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
story_id: 故事 ID
title: 故事标题
summary: 故事梗概
keywords: 关键词列表
universe_id: 可选,关联的故事宇宙 ID
Returns:
创建的 MemoryItem
"""
value = {
"story_id": story_id,
"title": title,
"summary": summary or "",
"keywords": keywords or [],
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.RECENT_STORY,
value=value,
universe_id=universe_id,
)
async def create_character_memory(
db: AsyncSession,
profile_id: str,
name: str,
description: str | None = None,
source_story_id: int | None = None,
affinity_score: float = 1.0,
universe_id: str | None = None,
) -> MemoryItem:
"""为喜欢的角色创建记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
name: 角色名称
description: 角色描述
source_story_id: 来源故事 ID
affinity_score: 喜爱程度 (0.0-1.0)
universe_id: 可选,关联的故事宇宙 ID
Returns:
创建的 MemoryItem
"""
value = {
"name": name,
"description": description or "",
"source_story_id": source_story_id,
"affinity_score": min(1.0, max(0.0, affinity_score)),
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.FAVORITE_CHARACTER,
value=value,
universe_id=universe_id,
)
async def create_scary_element_memory(
db: AsyncSession,
profile_id: str,
keyword: str,
category: str = "other",
source_story_id: int | None = None,
) -> MemoryItem:
"""为回避元素创建记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
keyword: 回避的关键词
category: 分类 (creature/scene/action/other)
source_story_id: 来源故事 ID
Returns:
创建的 MemoryItem
"""
value = {
"keyword": keyword,
"category": category,
"source_story_id": source_story_id,
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.SCARY_ELEMENT,
value=value,
)

View File

@@ -0,0 +1,31 @@
"""In-memory cache for providers loaded from DB."""
from collections import defaultdict
from typing import Literal
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.admin_models import Provider
ProviderType = Literal["text", "image", "tts", "storybook"]
_cache: dict[ProviderType, list[Provider]] = defaultdict(list)
async def reload_providers(db: AsyncSession):
result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712
providers = result.scalars().all()
grouped: dict[ProviderType, list[Provider]] = defaultdict(list)
for p in providers:
grouped[p.type].append(p)
# sort by priority desc, then weight desc
for k in grouped:
grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
_cache.clear()
_cache.update(grouped)
return _cache
def get_providers(provider_type: ProviderType) -> list[Provider]:
return _cache.get(provider_type, [])

View File

@@ -0,0 +1,248 @@
"""供应商指标收集和健康检查服务。"""
from datetime import datetime, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import ProviderHealth, ProviderMetrics
if TYPE_CHECKING:
from app.services.adapters.base import BaseAdapter
logger = get_logger(__name__)
# 熔断阈值:连续失败次数
CIRCUIT_BREAKER_THRESHOLD = 3
# 熔断恢复时间(秒)
CIRCUIT_BREAKER_RECOVERY_SECONDS = 60
class MetricsCollector:
"""供应商调用指标收集器。"""
async def record_call(
self,
db: AsyncSession,
provider_id: str,
success: bool,
latency_ms: int | None = None,
cost_usd: float | None = None,
error_message: str | None = None,
request_id: str | None = None,
) -> None:
"""记录一次 API 调用。"""
metric = ProviderMetrics(
provider_id=provider_id,
success=success,
latency_ms=latency_ms,
cost_usd=Decimal(str(cost_usd)) if cost_usd else None,
error_message=error_message,
request_id=request_id,
)
db.add(metric)
await db.commit()
logger.debug(
"metrics_recorded",
provider_id=provider_id,
success=success,
latency_ms=latency_ms,
)
async def get_success_rate(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的成功率。"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(
func.count().filter(ProviderMetrics.success.is_(True)).label("success_count"),
func.count().label("total_count"),
).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
)
)
row = result.one()
success_count, total_count = row.success_count, row.total_count
if total_count == 0:
return 1.0 # 无数据时假设健康
return success_count / total_count
async def get_avg_latency(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的平均延迟(毫秒)。"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(func.avg(ProviderMetrics.latency_ms)).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
ProviderMetrics.latency_ms.isnot(None),
)
)
avg = result.scalar()
return float(avg) if avg else 0.0
async def get_total_cost(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的总成本USD"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(func.sum(ProviderMetrics.cost_usd)).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
)
)
total = result.scalar()
return float(total) if total else 0.0
class HealthChecker:
"""供应商健康检查器。"""
async def check_provider(
self,
db: AsyncSession,
provider_id: str,
adapter: "BaseAdapter",
) -> bool:
"""执行健康检查并更新状态。"""
try:
is_healthy = await adapter.health_check()
except Exception as e:
logger.warning("health_check_failed", provider_id=provider_id, error=str(e))
is_healthy = False
await self.update_health_status(
db,
provider_id,
is_healthy,
error=None if is_healthy else "Health check failed",
)
return is_healthy
async def update_health_status(
self,
db: AsyncSession,
provider_id: str,
is_healthy: bool,
error: str | None = None,
) -> None:
"""更新供应商健康状态(含熔断逻辑)。"""
result = await db.execute(
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
)
health = result.scalar_one_or_none()
now = datetime.utcnow()
if health is None:
health = ProviderHealth(
provider_id=provider_id,
is_healthy=is_healthy,
last_check=now,
consecutive_failures=0 if is_healthy else 1,
last_error=error,
)
db.add(health)
else:
health.last_check = now
if is_healthy:
health.is_healthy = True
health.consecutive_failures = 0
health.last_error = None
else:
health.consecutive_failures += 1
health.last_error = error
# 熔断逻辑
if health.consecutive_failures >= CIRCUIT_BREAKER_THRESHOLD:
health.is_healthy = False
logger.warning(
"circuit_breaker_triggered",
provider_id=provider_id,
consecutive_failures=health.consecutive_failures,
)
await db.commit()
async def record_call_result(
self,
db: AsyncSession,
provider_id: str,
success: bool,
error: str | None = None,
) -> None:
"""根据调用结果更新健康状态。"""
await self.update_health_status(db, provider_id, success, error)
async def get_healthy_providers(
self,
db: AsyncSession,
provider_ids: list[str],
) -> list[str]:
"""获取健康的供应商列表。"""
if not provider_ids:
return []
# 查询所有已记录的健康状态
result = await db.execute(
select(ProviderHealth.provider_id, ProviderHealth.is_healthy).where(
ProviderHealth.provider_id.in_(provider_ids),
)
)
health_map = {row[0]: row[1] for row in result.all()}
# 未记录的供应商默认健康,已记录但不健康的排除
return [
pid for pid in provider_ids
if pid not in health_map or health_map[pid]
]
async def is_healthy(
self,
db: AsyncSession,
provider_id: str,
) -> bool:
"""检查供应商是否健康。"""
result = await db.execute(
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
)
health = result.scalar_one_or_none()
if health is None:
return True # 未记录默认健康
# 检查是否可以恢复
if not health.is_healthy and health.last_check:
recovery_time = health.last_check + timedelta(seconds=CIRCUIT_BREAKER_RECOVERY_SECONDS)
if datetime.utcnow() >= recovery_time:
return True # 允许重试
return health.is_healthy
# 全局单例
metrics_collector = MetricsCollector()
health_checker = HealthChecker()

View File

@@ -0,0 +1,432 @@
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
import time
from enum import Enum
from typing import TYPE_CHECKING, Literal, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters import AdapterConfig, AdapterRegistry
from app.services.adapters.text.models import StoryOutput
from app.services.cost_tracker import cost_tracker
from app.services.provider_cache import get_providers
from app.services.provider_metrics import health_checker, metrics_collector
if TYPE_CHECKING:
from app.db.admin_models import Provider
logger = get_logger(__name__)
T = TypeVar("T")
ProviderType = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
"""路由策略枚举。"""
PRIORITY = "priority" # 按优先级排序(默认)
COST = "cost" # 按成本排序
LATENCY = "latency" # 按延迟排序
ROUND_ROBIN = "round_robin" # 轮询
# 默认配置映射(当 DB 无配置时使用)
# 默认配置映射(当 DB 无配置时使用)
# 这是“代码级”的默认策略,对应 .env 为空的情况
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
"text": ["gemini", "openai"],
"image": ["cqtai"],
"tts": ["minimax", "elevenlabs", "edge_tts"],
"storybook": ["gemini"],
}
# API Key 映射adapter_name -> settings 属性名
API_KEY_MAP: dict[str, str] = {
# Text
"gemini": "text_api_key", # Gemini 还是复用 text_api_key 字段
"text_primary": "text_api_key", # 兼容旧别名
"openai": "openai_api_key",
# Image
"cqtai": "cqtai_api_key",
"image_primary": "image_api_key", # 兼容旧别名
# TTS
"minimax": "minimax_api_key",
"elevenlabs": "elevenlabs_api_key",
"edge_tts": "tts_api_key", # EdgeTTS 复用 tts_api_key (通常为空)
"tts_primary": "tts_api_key", # 兼容旧别名
}
# 轮询计数器
_round_robin_counters: dict[ProviderType, int] = {
"text": 0,
"image": 0,
"tts": 0,
}
# 延迟缓存(内存中,简化实现)
_latency_cache: dict[str, float] = {}
def _get_api_key(config_ref: str | None, adapter_name: str) -> str:
"""根据 config_ref 或适配器名称获取 API Key。"""
# 优先使用 config_ref
key_attr = API_KEY_MAP.get(config_ref or adapter_name, None)
if key_attr:
return getattr(settings, key_attr, "")
# 回退到适配器名称
key_attr = API_KEY_MAP.get(adapter_name, None)
if key_attr:
return getattr(settings, key_attr, "")
return ""
def _get_default_config(adapter_name: str) -> AdapterConfig | None:
"""获取适配器的默认配置(无 DB 记录时使用)。返回 None 表示未知适配器。"""
# --- Text Defaults ---
if adapter_name in ("gemini", "text_primary"):
return AdapterConfig(
api_key=settings.text_api_key,
model=settings.text_model or "gemini-2.0-flash",
timeout_ms=60000,
)
if adapter_name == "openai":
return AdapterConfig(
api_key=getattr(settings, "openai_api_key", ""),
model="gpt-4o-mini", # 这里可以从 settings 读取,看需求
timeout_ms=60000,
)
# --- Image Defaults ---
if adapter_name in ("cqtai"):
return AdapterConfig(
api_key=getattr(settings, "cqtai_api_key", ""),
model="nano-banana-pro", # 默认使用 Pro
timeout_ms=120000,
)
if adapter_name == "image_primary":
# 如果还有地方在用 image_primary暂时映射到快或者其他
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
return AdapterConfig(
api_key=settings.image_api_key,
timeout_ms=120000
)
# --- TTS Defaults ---
if adapter_name == "minimax":
# 传递 group_id 到 Adapter
# 目前 AdapterConfig 没有 group_id 字段,我们暂时不改 Base
# 而是假设 Adapter 会从 config (通过 kwargs 或其他方式) 拿。
# 实际上我们的 MiniMaxTTSAdapter 还没有处理 group_id。
# 最简单的方法:把 group_id 藏在 api_base 里或者让 Adapter 自己去 settings 拿。
# 鉴于 _build_config_from_provider 里我们无法传递额外参数给 Adapter.__init__
# 我们这里暂时返回基础配置。
return AdapterConfig(
api_key=getattr(settings, "minimax_api_key", ""),
model="speech-2.6-turbo",
timeout_ms=60000,
)
if adapter_name == "elevenlabs":
return AdapterConfig(
api_key=getattr(settings, "elevenlabs_api_key", ""),
timeout_ms=120000,
)
if adapter_name in ("edge_tts", "tts_primary"):
return AdapterConfig(
api_key=settings.tts_api_key,
api_base=settings.tts_api_base,
model=settings.tts_model or "zh-CN-XiaoxiaoNeural",
timeout_ms=120000,
)
# --- Others ---
if adapter_name in ("storybook_primary", "storybook_gemini"):
return AdapterConfig(
api_key=settings.text_api_key, # 复用 Gemini key
model=settings.text_model,
timeout_ms=120000,
)
# 未知适配器返回 None
return None
def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
"""从 DB Provider 记录构建 AdapterConfig。"""
api_key = getattr(provider, "api_key", None) or ""
if not api_key:
api_key = _get_api_key(provider.config_ref, provider.adapter)
default = _get_default_config(provider.adapter)
if default is None:
default = AdapterConfig(api_key="", timeout_ms=60000)
return AdapterConfig(
api_key=api_key or default.api_key,
api_base=provider.api_base or default.api_base,
model=provider.model or default.model,
timeout_ms=provider.timeout_ms or default.timeout_ms,
max_retries=provider.max_retries or default.max_retries,
extra_config=provider.config_json or {},
)
def _get_providers_with_config(
provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
"""获取供应商列表及其配置。
Returns:
[(adapter_name, config, provider_or_none), ...] 按优先级排序
"""
db_providers = get_providers(provider_type)
if db_providers:
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
settings_map = {
"text": settings.text_providers,
"image": settings.image_providers,
"tts": settings.tts_providers,
}
names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]
result = []
for name in names:
config = _get_default_config(name)
if config is None:
logger.warning("unknown_adapter_skipped", adapter=name, provider_type=provider_type)
continue
result.append((name, config, None))
return result
def _sort_by_strategy(
providers: list[tuple[str, AdapterConfig, "Provider | None"]],
strategy: RoutingStrategy,
provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
"""按策略排序供应商列表。"""
if strategy == RoutingStrategy.PRIORITY:
# 按 priority 降序, weight 降序
return sorted(
providers,
key=lambda x: (-(x[2].priority if x[2] else 0), -(x[2].weight if x[2] else 1)),
)
if strategy == RoutingStrategy.COST:
# 按预估成本升序
def get_cost(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
adapter_class = AdapterRegistry.get(provider_type, item[0])
if adapter_class:
try:
adapter = adapter_class(item[1])
return adapter.estimated_cost
except Exception:
pass
return float("inf")
return sorted(providers, key=get_cost)
if strategy == RoutingStrategy.LATENCY:
# 按历史延迟升序
def get_latency(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
return _latency_cache.get(item[0], float("inf"))
return sorted(providers, key=get_latency)
if strategy == RoutingStrategy.ROUND_ROBIN:
# 轮询:旋转列表
counter = _round_robin_counters[provider_type]
_round_robin_counters[provider_type] = (counter + 1) % max(len(providers), 1)
return providers[counter:] + providers[:counter]
return providers
async def _route_with_failover(
provider_type: ProviderType,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
user_id: str | None = None,
**kwargs,
) -> T:
"""通用 provider failover 路由。
Args:
provider_type: 供应商类型 (text/image/tts/storybook)
strategy: 路由策略
db: 数据库会话(可选,用于指标收集和熔断检查)
user_id: 用户 ID可选用于成本追踪和预算检查
**kwargs: 传递给适配器的参数
"""
providers = _get_providers_with_config(provider_type)
if not providers:
raise ValueError(f"No {provider_type} providers configured.")
# 按策略排序
sorted_providers = _sort_by_strategy(providers, strategy, provider_type)
# 如果有 db 会话,过滤掉熔断的供应商
if db:
healthy_providers = []
for item in sorted_providers:
name, config, db_provider = item
provider_id = db_provider.id if db_provider else name
if await health_checker.is_healthy(db, provider_id):
healthy_providers.append(item)
else:
logger.debug("provider_circuit_open", adapter=name, provider_id=provider_id)
# 如果所有供应商都熔断,仍然尝试第一个(允许恢复)
if not healthy_providers:
healthy_providers = sorted_providers[:1]
sorted_providers = healthy_providers
errors: list[str] = []
for name, config, db_provider in sorted_providers:
adapter_class = AdapterRegistry.get(provider_type, name)
if not adapter_class:
errors.append(f"{name}: 适配器未注册")
continue
provider_id = db_provider.id if db_provider else name
try:
logger.debug(
"provider_attempt",
provider_type=provider_type,
adapter=name,
strategy=strategy.value,
)
adapter = adapter_class(config)
# 执行并计时
start_time = time.time()
result = await adapter.execute(**kwargs)
latency_ms = int((time.time() - start_time) * 1000)
# 更新延迟缓存
_latency_cache[name] = latency_ms
# 记录成功指标
if db:
await metrics_collector.record_call(
db,
provider_id=provider_id,
success=True,
latency_ms=latency_ms,
cost_usd=adapter.estimated_cost,
)
await health_checker.record_call_result(db, provider_id, success=True)
# 记录用户成本
if user_id:
await cost_tracker.record_cost(
db,
user_id=user_id,
provider_name=name,
capability=provider_type,
estimated_cost=adapter.estimated_cost,
provider_id=provider_id if db_provider else None,
)
logger.info(
"provider_success",
provider_type=provider_type,
adapter=name,
latency_ms=latency_ms,
)
return result
except Exception as exc:
error_msg = str(exc)
logger.warning(
"provider_failed",
provider_type=provider_type,
adapter=name,
error=error_msg,
)
errors.append(f"{name}: {exc}")
# 记录失败指标
if db:
await metrics_collector.record_call(
db,
provider_id=provider_id,
success=False,
error_message=error_msg,
)
await health_checker.record_call_result(
db, provider_id, success=False, error=error_msg
)
raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
async def generate_story_content(
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
) -> StoryOutput:
"""生成或润色故事,支持 failover。"""
return await _route_with_failover(
"text",
strategy=strategy,
db=db,
input_type=input_type,
data=data,
education_theme=education_theme,
memory_context=memory_context,
)
async def generate_image(
prompt: str,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
**kwargs,
) -> str:
"""生成图片,返回 URL支持 failover。"""
return await _route_with_failover("image", strategy=strategy, db=db, prompt=prompt, **kwargs)
async def text_to_speech(
text: str,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
) -> bytes:
"""文本转语音,返回 MP3 bytes支持 failover。"""
return await _route_with_failover("tts", strategy=strategy, db=db, text=text)
async def generate_storybook(
keywords: str,
page_count: int = 6,
education_theme: str | None = None,
memory_context: str | None = None,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
):
"""生成分页故事书,支持 failover。"""
from app.services.adapters.storybook.primary import Storybook
result: Storybook = await _route_with_failover(
"storybook",
strategy=strategy,
db=db,
keywords=keywords,
page_count=page_count,
education_theme=education_theme,
memory_context=memory_context,
)
return result

View File

@@ -0,0 +1,207 @@
"""供应商密钥加密存储服务。
使用 Fernet 对称加密,密钥从 SECRET_KEY 派生。
"""
import base64
import hashlib
from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.admin_models import ProviderSecret
if TYPE_CHECKING:
pass
logger = get_logger(__name__)
class SecretEncryptionError(Exception):
"""密钥加密/解密错误。"""
pass
class SecretService:
"""供应商密钥加密存储服务。"""
_fernet: Fernet | None = None
@classmethod
def _get_fernet(cls) -> Fernet:
"""获取 Fernet 实例,从 SECRET_KEY 派生加密密钥。"""
if cls._fernet is None:
# 从 SECRET_KEY 派生 32 字节密钥
key_bytes = hashlib.sha256(settings.secret_key.encode()).digest()
fernet_key = base64.urlsafe_b64encode(key_bytes)
cls._fernet = Fernet(fernet_key)
return cls._fernet
@classmethod
def encrypt(cls, plaintext: str) -> str:
"""加密明文,返回 base64 编码的密文。
Args:
plaintext: 要加密的明文
Returns:
base64 编码的密文
"""
if not plaintext:
return ""
fernet = cls._get_fernet()
encrypted = fernet.encrypt(plaintext.encode())
return encrypted.decode()
@classmethod
def decrypt(cls, ciphertext: str) -> str:
"""解密密文,返回明文。
Args:
ciphertext: base64 编码的密文
Returns:
解密后的明文
Raises:
SecretEncryptionError: 解密失败
"""
if not ciphertext:
return ""
try:
fernet = cls._get_fernet()
decrypted = fernet.decrypt(ciphertext.encode())
return decrypted.decode()
except InvalidToken as e:
logger.error("secret_decrypt_failed", error=str(e))
raise SecretEncryptionError("密钥解密失败,可能是 SECRET_KEY 已更改") from e
@classmethod
async def get_secret(cls, db: AsyncSession, name: str) -> str | None:
"""从数据库获取并解密密钥。
Args:
db: 数据库会话
name: 密钥名称
Returns:
解密后的密钥值,不存在返回 None
"""
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
return None
return cls.decrypt(secret.encrypted_value)
@classmethod
async def set_secret(cls, db: AsyncSession, name: str, value: str) -> ProviderSecret:
"""存储或更新加密密钥。
Args:
db: 数据库会话
name: 密钥名称
value: 密钥明文值
Returns:
ProviderSecret 实例
"""
encrypted = cls.encrypt(value)
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
secret = ProviderSecret(name=name, encrypted_value=encrypted)
db.add(secret)
else:
secret.encrypted_value = encrypted
await db.commit()
await db.refresh(secret)
logger.info("secret_stored", name=name)
return secret
@classmethod
async def delete_secret(cls, db: AsyncSession, name: str) -> bool:
"""删除密钥。
Args:
db: 数据库会话
name: 密钥名称
Returns:
是否删除成功
"""
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
return False
await db.delete(secret)
await db.commit()
logger.info("secret_deleted", name=name)
return True
@classmethod
async def list_secrets(cls, db: AsyncSession) -> list[str]:
"""列出所有密钥名称(不返回值)。
Args:
db: 数据库会话
Returns:
密钥名称列表
"""
result = await db.execute(select(ProviderSecret.name))
return [row[0] for row in result.fetchall()]
@classmethod
async def get_api_key(
cls,
db: AsyncSession,
provider_api_key: str | None,
config_ref: str | None,
) -> str | None:
"""获取 Provider 的 API Key按优先级查找。
优先级:
1. provider.api_key (数据库明文/加密)
2. provider.config_ref 指向的 ProviderSecret
3. 环境变量 (config_ref 作为变量名)
Args:
db: 数据库会话
provider_api_key: Provider 表中的 api_key 字段
config_ref: Provider 表中的 config_ref 字段
Returns:
API Key 或 None
"""
# 1. 直接使用 provider.api_key
if provider_api_key:
# 尝试解密,如果失败则当作明文
try:
decrypted = cls.decrypt(provider_api_key)
if decrypted:
return decrypted
except SecretEncryptionError:
pass
return provider_api_key
# 2. 从 ProviderSecret 表查找
if config_ref:
secret_value = await cls.get_secret(db, config_ref)
if secret_value:
return secret_value
# 3. 从环境变量查找
env_value = getattr(settings, config_ref.lower(), None)
if env_value:
return env_value
return None