wip: snapshot full local workspace state
Some checks are pending
Build and Push Docker Images / changes (push) Waiting to run
Build and Push Docker Images / build-backend (push) Blocked by required conditions
Build and Push Docker Images / build-frontend (push) Blocked by required conditions
Build and Push Docker Images / build-admin-frontend (push) Blocked by required conditions

This commit is contained in:
2026-04-17 18:58:11 +08:00
parent fea4ef012f
commit b8d3cb4644
181 changed files with 16964 additions and 17486 deletions

View File

@@ -1,21 +1,21 @@
"""适配器模块 - 供应商平台化架构核心。"""
from app.services.adapters.base import AdapterConfig, BaseAdapter
# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.registry import AdapterRegistry
# Storybook adapters
from app.services.adapters.storybook import primary as _storybook_primary # noqa: F401
from app.services.adapters.text import gemini as _text_gemini_adapter # noqa: F401
# 导入所有适配器以触发注册
# Text adapters
from app.services.adapters.text import openai as _text_openai_adapter # noqa: F401
# TTS adapters
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
__all__ = ["AdapterConfig", "BaseAdapter", "AdapterRegistry"]
"""适配器模块 - 供应商平台化架构核心。"""
from app.services.adapters.base import AdapterConfig, BaseAdapter
# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.registry import AdapterRegistry
# Storybook adapters
from app.services.adapters.storybook import primary as _storybook_primary # noqa: F401
from app.services.adapters.text import gemini as _text_gemini_adapter # noqa: F401
# 导入所有适配器以触发注册
# Text adapters
from app.services.adapters.text import openai as _text_openai_adapter # noqa: F401
# TTS adapters
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
__all__ = ["AdapterConfig", "BaseAdapter", "AdapterRegistry"]

View File

@@ -1,46 +1,46 @@
"""适配器基类定义。"""
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from pydantic import BaseModel
T = TypeVar("T")
class AdapterConfig(BaseModel):
"""适配器配置基类。"""
api_key: str
api_base: str | None = None
model: str | None = None
timeout_ms: int = 60000
max_retries: int = 3
extra_config: dict = {}
class BaseAdapter(ABC, Generic[T]):
"""适配器基类,所有供应商适配器必须继承此类。"""
# 子类必须定义
adapter_type: str # text / image / tts
adapter_name: str # text_primary / image_primary / tts_primary
def __init__(self, config: AdapterConfig):
self.config = config
@abstractmethod
async def execute(self, **kwargs) -> T:
"""执行适配器逻辑,返回结果。"""
pass
@abstractmethod
async def health_check(self) -> bool:
"""健康检查,返回是否可用。"""
pass
@property
@abstractmethod
def estimated_cost(self) -> float:
"""预估单次调用成本 (USD)。"""
pass
"""适配器基类定义。"""
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from pydantic import BaseModel
T = TypeVar("T")
class AdapterConfig(BaseModel):
"""适配器配置基类。"""
api_key: str
api_base: str | None = None
model: str | None = None
timeout_ms: int = 60000
max_retries: int = 3
extra_config: dict = {}
class BaseAdapter(ABC, Generic[T]):
"""适配器基类,所有供应商适配器必须继承此类。"""
# 子类必须定义
adapter_type: str # text / image / tts
adapter_name: str # text_primary / image_primary / tts_primary
def __init__(self, config: AdapterConfig):
self.config = config
@abstractmethod
async def execute(self, **kwargs) -> T:
"""执行适配器逻辑,返回结果。"""
pass
@abstractmethod
async def health_check(self) -> bool:
"""健康检查,返回是否可用。"""
pass
@property
@abstractmethod
def estimated_cost(self) -> float:
"""预估单次调用成本 (USD)。"""
pass

View File

@@ -1,3 +1,3 @@
"""图像生成适配器。"""# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.image import antigravity as _image_antigravity_adapter # noqa: F401
"""图像生成适配器。"""# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.image import antigravity as _image_antigravity_adapter # noqa: F401

View File

@@ -1,214 +1,214 @@
"""Antigravity 图像生成适配器。
使用 OpenAI 兼容 API 生成图像。
支持 gemini-3-pro-image 等模型。
"""
import base64
import time
from typing import Any
from openai import AsyncOpenAI
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认配置
DEFAULT_API_BASE = "http://127.0.0.1:8045/v1"
DEFAULT_MODEL = "gemini-3-pro-image"
DEFAULT_SIZE = "1024x1024"
# 支持的尺寸映射
SUPPORTED_SIZES = {
"1024x1024": "1:1",
"1280x720": "16:9",
"720x1280": "9:16",
"1216x896": "4:3",
}
@AdapterRegistry.register("image", "antigravity")
class AntigravityImageAdapter(BaseAdapter[str]):
"""Antigravity 图像生成适配器 (OpenAI 兼容 API)。
特点:
- 使用 OpenAI 兼容的 chat.completions 端点
- 通过 extra_body.size 指定图像尺寸
- 支持 gemini-3-pro-image 等模型
- 返回图片 URL 或 base64
"""
adapter_type = "image"
adapter_name = "antigravity"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or DEFAULT_API_BASE
self.client = AsyncOpenAI(
base_url=self.api_base,
api_key=config.api_key,
timeout=config.timeout_ms / 1000,
)
async def execute(
self,
prompt: str,
model: str | None = None,
size: str | None = None,
num_images: int = 1,
**kwargs,
) -> str | list[str]:
"""根据提示词生成图片,返回 URL 或 base64。
Args:
prompt: 图片描述提示词
model: 模型名称 (gemini-3-pro-image / gemini-3-pro-image-16-9 等)
size: 图像尺寸 (1024x1024, 1280x720, 720x1280, 1216x896)
num_images: 生成图片数量 (暂只支持 1)
Returns:
图片 URL 或 base64 字符串
"""
# 优先使用传入参数,其次使用 Adapter 配置,最后使用默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
size = size or cfg.get("size") or DEFAULT_SIZE
start_time = time.time()
logger.info(
"antigravity_generate_start",
prompt_length=len(prompt),
model=model,
size=size,
)
# 调用 API
image_url = await self._generate_image(prompt, model, size)
elapsed = time.time() - start_time
logger.info(
"antigravity_generate_success",
elapsed_seconds=round(elapsed, 2),
model=model,
)
return image_url
async def health_check(self) -> bool:
"""检查 Antigravity API 是否可用。"""
try:
# 简单测试连通性
response = await self.client.chat.completions.create(
model=self.config.model or DEFAULT_MODEL,
messages=[{"role": "user", "content": "test"}],
max_tokens=1,
)
return True
except Exception as e:
logger.warning("antigravity_health_check_failed", error=str(e))
return False
@property
def estimated_cost(self) -> float:
"""预估每张图片成本 (USD)。
Antigravity 使用 Gemini 模型,成本约 $0.02/张。
"""
return 0.02
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((Exception,)),
reraise=True,
)
async def _generate_image(
self,
prompt: str,
model: str,
size: str,
) -> str:
"""调用 Antigravity API 生成图像。
Returns:
图片 URL 或 base64 data URI
"""
try:
response = await self.client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
extra_body={"size": size},
)
# 解析响应
content = response.choices[0].message.content
if not content:
raise ValueError("Antigravity 未返回内容")
# 尝试解析为图片 URL 或 base64
# 响应可能是纯 URL、base64 或 markdown 格式的图片
image_url = self._extract_image_url(content)
if image_url:
return image_url
raise ValueError(f"Antigravity 响应无法解析为图片: {content[:200]}")
except Exception as e:
logger.error(
"antigravity_generate_error",
error=str(e),
model=model,
)
raise
def _extract_image_url(self, content: str) -> str | None:
"""从响应内容中提取图片 URL。
支持多种格式:
- 纯 URL: https://...
- Markdown: ![...](https://...)
- Base64 data URI: data:image/...
- 纯 base64 字符串
"""
content = content.strip()
# 1. 检查是否为 data URI
if content.startswith("data:image/"):
return content
# 2. 检查是否为纯 URL
if content.startswith("http://") or content.startswith("https://"):
# 可能有多行,取第一行
return content.split("\n")[0].strip()
# 3. 检查 Markdown 图片格式 ![...](url)
import re
md_match = re.search(r"!\[.*?\]\((https?://[^\)]+)\)", content)
if md_match:
return md_match.group(1)
# 4. 检查是否像 base64 编码的图片数据
if self._looks_like_base64(content):
# 假设是 PNG
return f"data:image/png;base64,{content}"
return None
def _looks_like_base64(self, s: str) -> bool:
"""判断字符串是否看起来像 base64 编码。"""
# Base64 只包含 A-Z, a-z, 0-9, +, /, =
# 且长度通常较长
if len(s) < 100:
return False
import re
return bool(re.match(r"^[A-Za-z0-9+/=]+$", s.replace("\n", "")))
"""Antigravity 图像生成适配器。
使用 OpenAI 兼容 API 生成图像。
支持 gemini-3-pro-image 等模型。
"""
import base64
import time
from typing import Any
from openai import AsyncOpenAI
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认配置
DEFAULT_API_BASE = "http://127.0.0.1:8045/v1"
DEFAULT_MODEL = "gemini-3-pro-image"
DEFAULT_SIZE = "1024x1024"
# 支持的尺寸映射
SUPPORTED_SIZES = {
"1024x1024": "1:1",
"1280x720": "16:9",
"720x1280": "9:16",
"1216x896": "4:3",
}
@AdapterRegistry.register("image", "antigravity")
class AntigravityImageAdapter(BaseAdapter[str]):
"""Antigravity 图像生成适配器 (OpenAI 兼容 API)。
特点:
- 使用 OpenAI 兼容的 chat.completions 端点
- 通过 extra_body.size 指定图像尺寸
- 支持 gemini-3-pro-image 等模型
- 返回图片 URL 或 base64
"""
adapter_type = "image"
adapter_name = "antigravity"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or DEFAULT_API_BASE
self.client = AsyncOpenAI(
base_url=self.api_base,
api_key=config.api_key,
timeout=config.timeout_ms / 1000,
)
async def execute(
self,
prompt: str,
model: str | None = None,
size: str | None = None,
num_images: int = 1,
**kwargs,
) -> str | list[str]:
"""根据提示词生成图片,返回 URL 或 base64。
Args:
prompt: 图片描述提示词
model: 模型名称 (gemini-3-pro-image / gemini-3-pro-image-16-9 等)
size: 图像尺寸 (1024x1024, 1280x720, 720x1280, 1216x896)
num_images: 生成图片数量 (暂只支持 1)
Returns:
图片 URL 或 base64 字符串
"""
# 优先使用传入参数,其次使用 Adapter 配置,最后使用默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
size = size or cfg.get("size") or DEFAULT_SIZE
start_time = time.time()
logger.info(
"antigravity_generate_start",
prompt_length=len(prompt),
model=model,
size=size,
)
# 调用 API
image_url = await self._generate_image(prompt, model, size)
elapsed = time.time() - start_time
logger.info(
"antigravity_generate_success",
elapsed_seconds=round(elapsed, 2),
model=model,
)
return image_url
async def health_check(self) -> bool:
"""检查 Antigravity API 是否可用。"""
try:
# 简单测试连通性
response = await self.client.chat.completions.create(
model=self.config.model or DEFAULT_MODEL,
messages=[{"role": "user", "content": "test"}],
max_tokens=1,
)
return True
except Exception as e:
logger.warning("antigravity_health_check_failed", error=str(e))
return False
@property
def estimated_cost(self) -> float:
"""预估每张图片成本 (USD)。
Antigravity 使用 Gemini 模型,成本约 $0.02/张。
"""
return 0.02
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((Exception,)),
reraise=True,
)
async def _generate_image(
self,
prompt: str,
model: str,
size: str,
) -> str:
"""调用 Antigravity API 生成图像。
Returns:
图片 URL 或 base64 data URI
"""
try:
response = await self.client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
extra_body={"size": size},
)
# 解析响应
content = response.choices[0].message.content
if not content:
raise ValueError("Antigravity 未返回内容")
# 尝试解析为图片 URL 或 base64
# 响应可能是纯 URL、base64 或 markdown 格式的图片
image_url = self._extract_image_url(content)
if image_url:
return image_url
raise ValueError(f"Antigravity 响应无法解析为图片: {content[:200]}")
except Exception as e:
logger.error(
"antigravity_generate_error",
error=str(e),
model=model,
)
raise
def _extract_image_url(self, content: str) -> str | None:
"""从响应内容中提取图片 URL。
支持多种格式:
- 纯 URL: https://...
- Markdown: ![...](https://...)
- Base64 data URI: data:image/...
- 纯 base64 字符串
"""
content = content.strip()
# 1. 检查是否为 data URI
if content.startswith("data:image/"):
return content
# 2. 检查是否为纯 URL
if content.startswith("http://") or content.startswith("https://"):
# 可能有多行,取第一行
return content.split("\n")[0].strip()
# 3. 检查 Markdown 图片格式 ![...](url)
import re
md_match = re.search(r"!\[.*?\]\((https?://[^\)]+)\)", content)
if md_match:
return md_match.group(1)
# 4. 检查是否像 base64 编码的图片数据
if self._looks_like_base64(content):
# 假设是 PNG
return f"data:image/png;base64,{content}"
return None
def _looks_like_base64(self, s: str) -> bool:
"""判断字符串是否看起来像 base64 编码。"""
# Base64 只包含 A-Z, a-z, 0-9, +, /, =
# 且长度通常较长
if len(s) < 100:
return False
import re
return bool(re.match(r"^[A-Za-z0-9+/=]+$", s.replace("\n", "")))

View File

@@ -1,252 +1,252 @@
"""CQTAI nano 图像生成适配器。
支持异步生成 + 轮询获取结果。
API 文档: https://api.cqtai.com
"""
import asyncio
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认配置
DEFAULT_API_BASE = "https://api.cqtai.com"
DEFAULT_MODEL = "nano-banana"
DEFAULT_RESOLUTION = "2K"
DEFAULT_ASPECT_RATIO = "1:1"
POLL_INTERVAL_SECONDS = 2
MAX_POLL_ATTEMPTS = 60 # 最多轮询 2 分钟
@AdapterRegistry.register("image", "cqtai")
class CQTAIImageAdapter(BaseAdapter[str]):
"""CQTAI nano 图像生成适配器,返回图片 URL。
特点:
- 异步生成 + 轮询获取结果
- 支持 nano-banana (标准) 和 nano-banana-pro (高画质)
- 支持多种分辨率和画面比例
- 支持图生图 (filesUrl)
"""
adapter_type = "image"
adapter_name = "cqtai"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or DEFAULT_API_BASE
async def execute(
self,
prompt: str,
model: str | None = None,
resolution: str | None = None,
aspect_ratio: str | None = None,
num_images: int = 1,
files_url: list[str] | None = None,
**kwargs,
) -> str | list[str]:
"""根据提示词生成图片,返回 URL 或 URL 列表。
Args:
prompt: 图片描述提示词
model: 模型名称 (nano-banana / nano-banana-pro)
resolution: 分辨率 (1K / 2K / 4K)
aspect_ratio: 画面比例 (1:1, 16:9, 9:16, 4:3, 3:4 等)
num_images: 生成图片数量 (1-4)
files_url: 输入图片 URL 列表 (图生图)
Returns:
单张图片返回 str多张返回 list[str]
"""
# 1. 优先使用传入参数
# 2. 其次使用 Adapter 配置里的 default (extra_config)
# 3. 最后使用系统默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
resolution = resolution or cfg.get("resolution") or DEFAULT_RESOLUTION
aspect_ratio = aspect_ratio or cfg.get("aspect_ratio") or DEFAULT_ASPECT_RATIO
num_images = min(max(num_images, 1), 4) # 限制 1-4
start_time = time.time()
logger.info(
"cqtai_generate_start",
prompt_length=len(prompt),
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
num_images=num_images,
)
# 1. 提交生成任务
task_id = await self._submit_task(
prompt=prompt,
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
num_images=num_images,
files_url=files_url or [],
)
logger.info("cqtai_task_submitted", task_id=task_id)
# 2. 轮询获取结果
result = await self._poll_result(task_id)
elapsed = time.time() - start_time
logger.info(
"cqtai_generate_success",
task_id=task_id,
elapsed_seconds=round(elapsed, 2),
image_count=len(result) if isinstance(result, list) else 1,
)
# 单张图片返回字符串,多张返回列表
if num_images == 1 and isinstance(result, list) and len(result) == 1:
return result[0]
return result
async def health_check(self) -> bool:
"""检查 CQTAI API 是否可用。"""
try:
async with httpx.AsyncClient(timeout=10) as client:
# 简单的连通性测试
response = await client.get(
f"{self.api_base}/api/cqt/info/nano",
params={"id": "health_check_test"},
headers={"Authorization": self.config.api_key},
)
# 即使返回错误也说明服务可达
return response.status_code in (200, 400, 401, 403, 404)
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估每张图片成本 (USD)。
nano-banana: ¥0.1 ≈ $0.014
nano-banana-pro: ¥0.2 ≈ $0.028
"""
model = self.config.model or DEFAULT_MODEL
if model == "nano-banana-pro":
return 0.028
return 0.014
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _submit_task(
self,
prompt: str,
model: str,
resolution: str,
aspect_ratio: str,
num_images: int,
files_url: list[str],
) -> str:
"""提交图像生成任务,返回任务 ID。"""
timeout = self.config.timeout_ms / 1000
payload = {
"prompt": prompt,
"numImages": num_images,
"aspectRatio": aspect_ratio,
"filesUrl": files_url,
}
# 可选参数,不传则使用默认值
if model != DEFAULT_MODEL:
payload["model"] = model
if resolution != DEFAULT_RESOLUTION:
payload["resolution"] = resolution
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
f"{self.api_base}/api/cqt/generator/nano",
json=payload,
headers={
"Authorization": self.config.api_key,
"Content-Type": "application/json",
},
)
response.raise_for_status()
data = response.json()
if data.get("code") != 200:
raise ValueError(f"CQTAI 任务提交失败: {data.get('msg', '未知错误')}")
task_id = data.get("data")
if not task_id:
raise ValueError("CQTAI 未返回任务 ID")
return task_id
async def _poll_result(self, task_id: str) -> list[str]:
"""轮询获取生成结果。
Returns:
图片 URL 列表
"""
timeout = self.config.timeout_ms / 1000
for attempt in range(MAX_POLL_ATTEMPTS):
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(
f"{self.api_base}/api/cqt/info/nano",
params={"id": task_id},
headers={"Authorization": self.config.api_key},
)
response.raise_for_status()
data = response.json()
if data.get("code") != 200:
raise ValueError(f"CQTAI 查询失败: {data.get('msg', '未知错误')}")
result_data = data.get("data", {})
status = result_data.get("status")
if status == "completed":
# 提取图片 URL
images = result_data.get("images", [])
if not images:
# 兼容不同返回格式
image_url = result_data.get("imageUrl") or result_data.get("url")
if image_url:
images = [image_url]
if not images:
raise ValueError("CQTAI 未返回图片 URL")
return images
elif status == "failed":
error_msg = result_data.get("error", "生成失败")
raise ValueError(f"CQTAI 图像生成失败: {error_msg}")
# 继续等待
logger.debug(
"cqtai_poll_waiting",
task_id=task_id,
attempt=attempt + 1,
status=status,
)
await asyncio.sleep(POLL_INTERVAL_SECONDS)
raise TimeoutError(f"CQTAI 任务超时: {task_id}")
"""CQTAI nano 图像生成适配器。
支持异步生成 + 轮询获取结果。
API 文档: https://api.cqtai.com
"""
import asyncio
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认配置
DEFAULT_API_BASE = "https://api.cqtai.com"
DEFAULT_MODEL = "nano-banana"
DEFAULT_RESOLUTION = "2K"
DEFAULT_ASPECT_RATIO = "1:1"
POLL_INTERVAL_SECONDS = 2
MAX_POLL_ATTEMPTS = 60 # 最多轮询 2 分钟
@AdapterRegistry.register("image", "cqtai")
class CQTAIImageAdapter(BaseAdapter[str]):
"""CQTAI nano 图像生成适配器,返回图片 URL。
特点:
- 异步生成 + 轮询获取结果
- 支持 nano-banana (标准) 和 nano-banana-pro (高画质)
- 支持多种分辨率和画面比例
- 支持图生图 (filesUrl)
"""
adapter_type = "image"
adapter_name = "cqtai"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or DEFAULT_API_BASE
async def execute(
self,
prompt: str,
model: str | None = None,
resolution: str | None = None,
aspect_ratio: str | None = None,
num_images: int = 1,
files_url: list[str] | None = None,
**kwargs,
) -> str | list[str]:
"""根据提示词生成图片,返回 URL 或 URL 列表。
Args:
prompt: 图片描述提示词
model: 模型名称 (nano-banana / nano-banana-pro)
resolution: 分辨率 (1K / 2K / 4K)
aspect_ratio: 画面比例 (1:1, 16:9, 9:16, 4:3, 3:4 等)
num_images: 生成图片数量 (1-4)
files_url: 输入图片 URL 列表 (图生图)
Returns:
单张图片返回 str多张返回 list[str]
"""
# 1. 优先使用传入参数
# 2. 其次使用 Adapter 配置里的 default (extra_config)
# 3. 最后使用系统默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
resolution = resolution or cfg.get("resolution") or DEFAULT_RESOLUTION
aspect_ratio = aspect_ratio or cfg.get("aspect_ratio") or DEFAULT_ASPECT_RATIO
num_images = min(max(num_images, 1), 4) # 限制 1-4
start_time = time.time()
logger.info(
"cqtai_generate_start",
prompt_length=len(prompt),
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
num_images=num_images,
)
# 1. 提交生成任务
task_id = await self._submit_task(
prompt=prompt,
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
num_images=num_images,
files_url=files_url or [],
)
logger.info("cqtai_task_submitted", task_id=task_id)
# 2. 轮询获取结果
result = await self._poll_result(task_id)
elapsed = time.time() - start_time
logger.info(
"cqtai_generate_success",
task_id=task_id,
elapsed_seconds=round(elapsed, 2),
image_count=len(result) if isinstance(result, list) else 1,
)
# 单张图片返回字符串,多张返回列表
if num_images == 1 and isinstance(result, list) and len(result) == 1:
return result[0]
return result
async def health_check(self) -> bool:
"""检查 CQTAI API 是否可用。"""
try:
async with httpx.AsyncClient(timeout=10) as client:
# 简单的连通性测试
response = await client.get(
f"{self.api_base}/api/cqt/info/nano",
params={"id": "health_check_test"},
headers={"Authorization": self.config.api_key},
)
# 即使返回错误也说明服务可达
return response.status_code in (200, 400, 401, 403, 404)
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估每张图片成本 (USD)。
nano-banana: ¥0.1 ≈ $0.014
nano-banana-pro: ¥0.2 ≈ $0.028
"""
model = self.config.model or DEFAULT_MODEL
if model == "nano-banana-pro":
return 0.028
return 0.014
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _submit_task(
self,
prompt: str,
model: str,
resolution: str,
aspect_ratio: str,
num_images: int,
files_url: list[str],
) -> str:
"""提交图像生成任务,返回任务 ID。"""
timeout = self.config.timeout_ms / 1000
payload = {
"prompt": prompt,
"numImages": num_images,
"aspectRatio": aspect_ratio,
"filesUrl": files_url,
}
# 可选参数,不传则使用默认值
if model != DEFAULT_MODEL:
payload["model"] = model
if resolution != DEFAULT_RESOLUTION:
payload["resolution"] = resolution
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
f"{self.api_base}/api/cqt/generator/nano",
json=payload,
headers={
"Authorization": self.config.api_key,
"Content-Type": "application/json",
},
)
response.raise_for_status()
data = response.json()
if data.get("code") != 200:
raise ValueError(f"CQTAI 任务提交失败: {data.get('msg', '未知错误')}")
task_id = data.get("data")
if not task_id:
raise ValueError("CQTAI 未返回任务 ID")
return task_id
async def _poll_result(self, task_id: str) -> list[str]:
"""轮询获取生成结果。
Returns:
图片 URL 列表
"""
timeout = self.config.timeout_ms / 1000
for attempt in range(MAX_POLL_ATTEMPTS):
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(
f"{self.api_base}/api/cqt/info/nano",
params={"id": task_id},
headers={"Authorization": self.config.api_key},
)
response.raise_for_status()
data = response.json()
if data.get("code") != 200:
raise ValueError(f"CQTAI 查询失败: {data.get('msg', '未知错误')}")
result_data = data.get("data", {})
status = result_data.get("status")
if status == "completed":
# 提取图片 URL
images = result_data.get("images", [])
if not images:
# 兼容不同返回格式
image_url = result_data.get("imageUrl") or result_data.get("url")
if image_url:
images = [image_url]
if not images:
raise ValueError("CQTAI 未返回图片 URL")
return images
elif status == "failed":
error_msg = result_data.get("error", "生成失败")
raise ValueError(f"CQTAI 图像生成失败: {error_msg}")
# 继续等待
logger.debug(
"cqtai_poll_waiting",
task_id=task_id,
attempt=attempt + 1,
status=status,
)
await asyncio.sleep(POLL_INTERVAL_SECONDS)
raise TimeoutError(f"CQTAI 任务超时: {task_id}")

View File

@@ -1,73 +1,73 @@
"""适配器注册表 - 支持动态注册和工厂创建。"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.services.adapters.base import AdapterConfig, BaseAdapter
class AdapterRegistry:
"""适配器注册表,管理所有已注册的适配器类。"""
_adapters: dict[str, type["BaseAdapter"]] = {}
@classmethod
def register(cls, adapter_type: str, adapter_name: str):
"""装饰器:注册适配器类。
用法:
"""适配器注册表 - 支持动态注册和工厂创建。"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.services.adapters.base import AdapterConfig, BaseAdapter
class AdapterRegistry:
"""适配器注册表,管理所有已注册的适配器类。"""
_adapters: dict[str, type["BaseAdapter"]] = {}
@classmethod
def register(cls, adapter_type: str, adapter_name: str):
"""装饰器:注册适配器类。
用法:
@AdapterRegistry.register("text", "text_primary")
class TextPrimaryAdapter(BaseAdapter[StoryOutput]):
...
"""
def decorator(adapter_class: type["BaseAdapter"]):
key = f"{adapter_type}:{adapter_name}"
cls._adapters[key] = adapter_class
# 自动设置类属性
adapter_class.adapter_type = adapter_type
adapter_class.adapter_name = adapter_name
return adapter_class
return decorator
@classmethod
def get(cls, adapter_type: str, adapter_name: str) -> type["BaseAdapter"] | None:
"""获取已注册的适配器类。"""
key = f"{adapter_type}:{adapter_name}"
return cls._adapters.get(key)
@classmethod
def list_adapters(cls, adapter_type: str | None = None) -> list[str]:
"""列出所有已注册的适配器。
Args:
adapter_type: 可选,筛选特定类型 (text/image/tts)
Returns:
适配器键列表,格式为 "type:name"
"""
if adapter_type:
return [k for k in cls._adapters if k.startswith(f"{adapter_type}:")]
return list(cls._adapters.keys())
@classmethod
def create(
cls,
adapter_type: str,
adapter_name: str,
config: "AdapterConfig",
) -> "BaseAdapter":
"""工厂方法:创建适配器实例。
Raises:
ValueError: 适配器未注册
"""
adapter_class = cls.get(adapter_type, adapter_name)
if not adapter_class:
available = cls.list_adapters(adapter_type)
raise ValueError(
f"适配器 '{adapter_type}:{adapter_name}' 未注册。"
f"可用: {available}"
)
return adapter_class(config)
...
"""
def decorator(adapter_class: type["BaseAdapter"]):
key = f"{adapter_type}:{adapter_name}"
cls._adapters[key] = adapter_class
# 自动设置类属性
adapter_class.adapter_type = adapter_type
adapter_class.adapter_name = adapter_name
return adapter_class
return decorator
@classmethod
def get(cls, adapter_type: str, adapter_name: str) -> type["BaseAdapter"] | None:
"""获取已注册的适配器类。"""
key = f"{adapter_type}:{adapter_name}"
return cls._adapters.get(key)
@classmethod
def list_adapters(cls, adapter_type: str | None = None) -> list[str]:
"""列出所有已注册的适配器。
Args:
adapter_type: 可选,筛选特定类型 (text/image/tts)
Returns:
适配器键列表,格式为 "type:name"
"""
if adapter_type:
return [k for k in cls._adapters if k.startswith(f"{adapter_type}:")]
return list(cls._adapters.keys())
@classmethod
def create(
cls,
adapter_type: str,
adapter_name: str,
config: "AdapterConfig",
) -> "BaseAdapter":
"""工厂方法:创建适配器实例。
Raises:
ValueError: 适配器未注册
"""
adapter_class = cls.get(adapter_type, adapter_name)
if not adapter_class:
available = cls.list_adapters(adapter_type)
raise ValueError(
f"适配器 '{adapter_type}:{adapter_name}' 未注册。"
f"可用: {available}"
)
return adapter_class(config)

View File

@@ -1 +1 @@
"""Storybook 适配器模块。"""
"""Storybook 适配器模块。"""

View File

@@ -1,195 +1,195 @@
"""Storybook 适配器 - 生成可翻页的分页故事书。"""
import json
import random
import re
import time
from dataclasses import dataclass, field
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_STORYBOOK,
USER_PROMPT_STORYBOOK,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@dataclass
class StorybookPage:
"""故事书单页。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
@dataclass
class Storybook:
"""故事书输出。"""
title: str
main_character: str
art_style: str
pages: list[StorybookPage] = field(default_factory=list)
cover_prompt: str = ""
cover_url: str | None = None
@AdapterRegistry.register("storybook", "storybook_primary")
class StorybookPrimaryAdapter(BaseAdapter[Storybook]):
"""Storybook 生成适配器(默认)。
生成分页故事书结构,包含每页文字和图像提示词。
图像生成需要单独调用 image adapter。
"""
adapter_type = "storybook"
adapter_name = "storybook_primary"
async def execute(
self,
keywords: str,
page_count: int = 6,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> Storybook:
"""生成分页故事书。
Args:
keywords: 故事关键词
page_count: 页数 (4-12)
education_theme: 教育主题
memory_context: 记忆上下文
Returns:
Storybook 对象,包含标题、页面列表和封面提示词
"""
start_time = time.time()
page_count = max(4, min(page_count, 12)) # 限制 4-12 页
logger.info(
"storybook_generate_start",
keywords=keywords,
page_count=page_count,
has_memory=bool(memory_context),
)
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
prompt = USER_PROMPT_STORYBOOK.format(
keywords=keywords,
education_theme=theme,
random_element=random_element,
page_count=page_count,
memory_context=memory_context or "",
)
payload = {
"system_instruction": {"parts": [{"text": SYSTEM_INSTRUCTION_STORYBOOK}]},
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.95,
"topP": 0.9,
},
}
result = await self._call_api(payload)
candidates = result.get("candidates") or []
if not candidates:
raise ValueError("Storybook 服务未返回内容")
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
raise ValueError("Storybook 服务响应缺少文本")
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"Storybook JSON 解析失败: {exc}")
# 构建 Storybook 对象
pages = [
StorybookPage(
page_number=p.get("page_number", i + 1),
text=p.get("text", ""),
image_prompt=p.get("image_prompt", ""),
)
for i, p in enumerate(parsed.get("pages", []))
]
storybook = Storybook(
title=parsed.get("title", "未命名故事"),
main_character=parsed.get("main_character", ""),
art_style=parsed.get("art_style", ""),
pages=pages,
cover_prompt=parsed.get("cover_prompt", ""),
)
elapsed = time.time() - start_time
logger.info(
"storybook_generate_success",
elapsed_seconds=round(elapsed, 2),
title=storybook.title,
page_count=len(pages),
)
return storybook
async def health_check(self) -> bool:
"""检查 API 是否可用。"""
try:
payload = {
"contents": [{"parts": [{"text": "Hi"}]}],
"generationConfig": {"maxOutputTokens": 10},
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估成本(仅文本生成,不含图像)。"""
return 0.002 # 比普通故事稍贵,因为输出更长
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, payload: dict) -> dict:
"""调用 API带重试机制。"""
model = self.config.model or "gemini-2.0-flash"
url = f"{TEXT_API_BASE}/{model}:generateContent?key={self.config.api_key}"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
return response.json()
"""Storybook 适配器 - 生成可翻页的分页故事书。"""
import json
import random
import re
import time
from dataclasses import dataclass, field
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_STORYBOOK,
USER_PROMPT_STORYBOOK,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@dataclass
class StorybookPage:
"""故事书单页。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
@dataclass
class Storybook:
"""故事书输出。"""
title: str
main_character: str
art_style: str
pages: list[StorybookPage] = field(default_factory=list)
cover_prompt: str = ""
cover_url: str | None = None
@AdapterRegistry.register("storybook", "storybook_primary")
class StorybookPrimaryAdapter(BaseAdapter[Storybook]):
"""Storybook 生成适配器(默认)。
生成分页故事书结构,包含每页文字和图像提示词。
图像生成需要单独调用 image adapter。
"""
adapter_type = "storybook"
adapter_name = "storybook_primary"
async def execute(
self,
keywords: str,
page_count: int = 6,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> Storybook:
"""生成分页故事书。
Args:
keywords: 故事关键词
page_count: 页数 (4-12)
education_theme: 教育主题
memory_context: 记忆上下文
Returns:
Storybook 对象,包含标题、页面列表和封面提示词
"""
start_time = time.time()
page_count = max(4, min(page_count, 12)) # 限制 4-12 页
logger.info(
"storybook_generate_start",
keywords=keywords,
page_count=page_count,
has_memory=bool(memory_context),
)
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
prompt = USER_PROMPT_STORYBOOK.format(
keywords=keywords,
education_theme=theme,
random_element=random_element,
page_count=page_count,
memory_context=memory_context or "",
)
payload = {
"system_instruction": {"parts": [{"text": SYSTEM_INSTRUCTION_STORYBOOK}]},
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.95,
"topP": 0.9,
},
}
result = await self._call_api(payload)
candidates = result.get("candidates") or []
if not candidates:
raise ValueError("Storybook 服务未返回内容")
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
raise ValueError("Storybook 服务响应缺少文本")
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"Storybook JSON 解析失败: {exc}")
# 构建 Storybook 对象
pages = [
StorybookPage(
page_number=p.get("page_number", i + 1),
text=p.get("text", ""),
image_prompt=p.get("image_prompt", ""),
)
for i, p in enumerate(parsed.get("pages", []))
]
storybook = Storybook(
title=parsed.get("title", "未命名故事"),
main_character=parsed.get("main_character", ""),
art_style=parsed.get("art_style", ""),
pages=pages,
cover_prompt=parsed.get("cover_prompt", ""),
)
elapsed = time.time() - start_time
logger.info(
"storybook_generate_success",
elapsed_seconds=round(elapsed, 2),
title=storybook.title,
page_count=len(pages),
)
return storybook
async def health_check(self) -> bool:
"""检查 API 是否可用。"""
try:
payload = {
"contents": [{"parts": [{"text": "Hi"}]}],
"generationConfig": {"maxOutputTokens": 10},
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估成本(仅文本生成,不含图像)。"""
return 0.002 # 比普通故事稍贵,因为输出更长
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, payload: dict) -> dict:
"""调用 API带重试机制。"""
model = self.config.model or "gemini-2.0-flash"
url = f"{TEXT_API_BASE}/{model}:generateContent?key={self.config.api_key}"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
return response.json()

View File

@@ -1 +1 @@
"""文本生成适配器。"""
"""文本生成适配器。"""

View File

@@ -1,164 +1,164 @@
"""文本生成适配器 (Google Gemini)。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@AdapterRegistry.register("text", "gemini")
class GeminiTextAdapter(BaseAdapter[StoryOutput]):
"""Google Gemini 文本生成适配器。"""
adapter_type = "text"
adapter_name = "gemini"
async def execute(
self,
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> StoryOutput:
"""生成或润色故事。"""
start_time = time.time()
logger.info("request_start", adapter="gemini", input_type=input_type, data_length=len(data))
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
if input_type == "keywords":
system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
prompt = USER_PROMPT_GENERATION.format(
keywords=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
else:
system_instruction = SYSTEM_INSTRUCTION_ENHANCER
prompt = USER_PROMPT_ENHANCEMENT.format(
full_story=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
# Gemini API Payload supports 'system_instruction'
payload = {
"system_instruction": {"parts": [{"text": system_instruction}]},
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.95,
"topP": 0.9,
},
}
result = await self._call_api(payload)
candidates = result.get("candidates") or []
if not candidates:
raise ValueError("Gemini 未返回内容")
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
raise ValueError("Gemini 响应缺少文本")
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"Gemini 输出 JSON 解析失败: {exc}")
required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
if any(field not in parsed for field in required_fields):
raise ValueError("Gemini 输出缺少必要字段")
elapsed = time.time() - start_time
logger.info(
"request_success",
adapter="gemini",
elapsed_seconds=round(elapsed, 2),
title=parsed["title"],
)
return StoryOutput(
mode=parsed["mode"],
title=parsed["title"],
story_text=parsed["story_text"],
cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
)
async def health_check(self) -> bool:
"""检查 Gemini API 是否可用。"""
try:
payload = {
"contents": [{"parts": [{"text": "Hi"}]}],
"generationConfig": {"maxOutputTokens": 10},
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
return 0.001
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, payload: dict) -> dict:
"""调用 Gemini API。"""
model = self.config.model or "gemini-2.0-flash"
base_url = self.config.api_base or TEXT_API_BASE
# 智能补全:
# 1. 如果用户填了完整路径 (以 /models 结尾),就直接用 (支持 v1 或 v1beta)
if self.config.api_base and base_url.rstrip("/").endswith("/models"):
pass
# 2. 如果没填路径 (只是域名),默认补全代码适配的 /v1beta/models
elif self.config.api_base:
base_url = f"{base_url.rstrip('/')}/v1beta/models"
url = f"{base_url}/{model}:generateContent?key={self.config.api_key}"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
return response.json()
"""文本生成适配器 (Google Gemini)。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
@AdapterRegistry.register("text", "gemini")
class GeminiTextAdapter(BaseAdapter[StoryOutput]):
"""Google Gemini 文本生成适配器。"""
adapter_type = "text"
adapter_name = "gemini"
async def execute(
self,
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> StoryOutput:
"""生成或润色故事。"""
start_time = time.time()
logger.info("request_start", adapter="gemini", input_type=input_type, data_length=len(data))
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
if input_type == "keywords":
system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
prompt = USER_PROMPT_GENERATION.format(
keywords=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
else:
system_instruction = SYSTEM_INSTRUCTION_ENHANCER
prompt = USER_PROMPT_ENHANCEMENT.format(
full_story=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
# Gemini API Payload supports 'system_instruction'
payload = {
"system_instruction": {"parts": [{"text": system_instruction}]},
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"responseMimeType": "application/json",
"temperature": 0.95,
"topP": 0.9,
},
}
result = await self._call_api(payload)
candidates = result.get("candidates") or []
if not candidates:
raise ValueError("Gemini 未返回内容")
parts = candidates[0].get("content", {}).get("parts") or []
if not parts or "text" not in parts[0]:
raise ValueError("Gemini 响应缺少文本")
response_text = parts[0]["text"]
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"Gemini 输出 JSON 解析失败: {exc}")
required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
if any(field not in parsed for field in required_fields):
raise ValueError("Gemini 输出缺少必要字段")
elapsed = time.time() - start_time
logger.info(
"request_success",
adapter="gemini",
elapsed_seconds=round(elapsed, 2),
title=parsed["title"],
)
return StoryOutput(
mode=parsed["mode"],
title=parsed["title"],
story_text=parsed["story_text"],
cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
)
async def health_check(self) -> bool:
"""检查 Gemini API 是否可用。"""
try:
payload = {
"contents": [{"parts": [{"text": "Hi"}]}],
"generationConfig": {"maxOutputTokens": 10},
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
return 0.001
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, payload: dict) -> dict:
"""调用 Gemini API。"""
model = self.config.model or "gemini-2.0-flash"
base_url = self.config.api_base or TEXT_API_BASE
# 智能补全:
# 1. 如果用户填了完整路径 (以 /models 结尾),就直接用 (支持 v1 或 v1beta)
if self.config.api_base and base_url.rstrip("/").endswith("/models"):
pass
# 2. 如果没填路径 (只是域名),默认补全代码适配的 /v1beta/models
elif self.config.api_base:
base_url = f"{base_url.rstrip('/')}/v1beta/models"
url = f"{base_url}/{model}:generateContent?key={self.config.api_key}"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
return response.json()

View File

@@ -1,11 +1,11 @@
from dataclasses import dataclass
from typing import Literal
@dataclass
class StoryOutput:
"""故事生成输出。"""
mode: Literal["generated", "enhanced"]
title: str
story_text: str
cover_prompt_suggestion: str
from dataclasses import dataclass
from typing import Literal
@dataclass
class StoryOutput:
"""故事生成输出。"""
mode: Literal["generated", "enhanced"]
title: str
story_text: str
cover_prompt_suggestion: str

View File

@@ -1,172 +1,172 @@
"""OpenAI 文本生成适配器。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
OPENAI_API_BASE = "https://api.openai.com/v1/chat/completions"
@AdapterRegistry.register("text", "openai")
class OpenAITextAdapter(BaseAdapter[StoryOutput]):
"""OpenAI 文本生成适配器。"""
adapter_type = "text"
adapter_name = "openai"
async def execute(
self,
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> StoryOutput:
"""生成或润色故事。"""
start_time = time.time()
logger.info("openai_text_request_start", input_type=input_type, data_length=len(data))
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
if input_type == "keywords":
system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
prompt = USER_PROMPT_GENERATION.format(
keywords=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
else:
system_instruction = SYSTEM_INSTRUCTION_ENHANCER
prompt = USER_PROMPT_ENHANCEMENT.format(
full_story=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
model = self.config.model or "gpt-4o-mini"
payload = {
"model": model,
"messages": [
{
"role": "system",
"content": system_instruction,
},
{"role": "user", "content": prompt},
],
"response_format": {"type": "json_object"},
"temperature": 0.95,
"top_p": 0.9,
}
result = await self._call_api(payload)
choices = result.get("choices") or []
if not choices:
raise ValueError("OpenAI 未返回内容")
response_text = choices[0].get("message", {}).get("content", "")
if not response_text:
raise ValueError("OpenAI 响应缺少文本")
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"OpenAI 输出 JSON 解析失败: {exc}")
required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
if any(field not in parsed for field in required_fields):
raise ValueError("OpenAI 输出缺少必要字段")
elapsed = time.time() - start_time
logger.info(
"openai_text_request_success",
elapsed_seconds=round(elapsed, 2),
title=parsed["title"],
mode=parsed["mode"],
)
return StoryOutput(
mode=parsed["mode"],
title=parsed["title"],
story_text=parsed["story_text"],
cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(httpx.HTTPStatusError),
)
async def _call_api(self, payload: dict) -> dict:
"""调用 OpenAI API带重试机制。"""
url = self.config.api_base or OPENAI_API_BASE
# 智能补全: 如果用户只填了 Base URL自动补全路径
if self.config.api_base and not url.endswith("/chat/completions"):
base = url.rstrip("/")
url = f"{base}/chat/completions"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"Authorization": f"Bearer {self.config.api_key}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
return response.json()
async def health_check(self) -> bool:
"""检查 OpenAI API 是否可用。"""
try:
payload = {
"model": self.config.model or "gpt-4o-mini",
"messages": [{"role": "user", "content": "Hi"}],
"max_tokens": 5,
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估文本生成成本 (USD)。"""
return 0.01
"""OpenAI 文本生成适配器。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
logger = get_logger(__name__)
OPENAI_API_BASE = "https://api.openai.com/v1/chat/completions"
@AdapterRegistry.register("text", "openai")
class OpenAITextAdapter(BaseAdapter[StoryOutput]):
"""OpenAI 文本生成适配器。"""
adapter_type = "text"
adapter_name = "openai"
async def execute(
self,
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
**kwargs,
) -> StoryOutput:
"""生成或润色故事。"""
start_time = time.time()
logger.info("openai_text_request_start", input_type=input_type, data_length=len(data))
theme = education_theme or "成长"
random_element = random.choice(RANDOM_ELEMENTS)
if input_type == "keywords":
system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
prompt = USER_PROMPT_GENERATION.format(
keywords=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
else:
system_instruction = SYSTEM_INSTRUCTION_ENHANCER
prompt = USER_PROMPT_ENHANCEMENT.format(
full_story=data,
education_theme=theme,
random_element=random_element,
memory_context=memory_context or "",
)
model = self.config.model or "gpt-4o-mini"
payload = {
"model": model,
"messages": [
{
"role": "system",
"content": system_instruction,
},
{"role": "user", "content": prompt},
],
"response_format": {"type": "json_object"},
"temperature": 0.95,
"top_p": 0.9,
}
result = await self._call_api(payload)
choices = result.get("choices") or []
if not choices:
raise ValueError("OpenAI 未返回内容")
response_text = choices[0].get("message", {}).get("content", "")
if not response_text:
raise ValueError("OpenAI 响应缺少文本")
clean_json = response_text
if response_text.startswith("```json"):
clean_json = re.sub(r"^```json\n|```$", "", response_text)
try:
parsed = json.loads(clean_json)
except json.JSONDecodeError as exc:
raise ValueError(f"OpenAI 输出 JSON 解析失败: {exc}")
required_fields = ["mode", "title", "story_text", "cover_prompt_suggestion"]
if any(field not in parsed for field in required_fields):
raise ValueError("OpenAI 输出缺少必要字段")
elapsed = time.time() - start_time
logger.info(
"openai_text_request_success",
elapsed_seconds=round(elapsed, 2),
title=parsed["title"],
mode=parsed["mode"],
)
return StoryOutput(
mode=parsed["mode"],
title=parsed["title"],
story_text=parsed["story_text"],
cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(httpx.HTTPStatusError),
)
async def _call_api(self, payload: dict) -> dict:
"""调用 OpenAI API带重试机制。"""
url = self.config.api_base or OPENAI_API_BASE
# 智能补全: 如果用户只填了 Base URL自动补全路径
if self.config.api_base and not url.endswith("/chat/completions"):
base = url.rstrip("/")
url = f"{base}/chat/completions"
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"Authorization": f"Bearer {self.config.api_key}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
return response.json()
async def health_check(self) -> bool:
"""检查 OpenAI API 是否可用。"""
try:
payload = {
"model": self.config.model or "gpt-4o-mini",
"messages": [{"role": "user", "content": "Hi"}],
"max_tokens": 5,
}
await self._call_api(payload)
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估文本生成成本 (USD)。"""
return 0.01

View File

@@ -1,5 +1,5 @@
"""TTS 语音合成适配器。"""
from app.services.adapters.tts import edge_tts as _tts_edge_tts_adapter # noqa: F401
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
"""TTS 语音合成适配器。"""
from app.services.adapters.tts import edge_tts as _tts_edge_tts_adapter # noqa: F401
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401

View File

@@ -1,66 +1,66 @@
"""EdgeTTS 免费语音生成适配器。"""
import time
import edge_tts
from app.core.logging import get_logger
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认中文女声 (晓晓)
DEFAULT_VOICE = "zh-CN-XiaoxiaoNeural"
@AdapterRegistry.register("tts", "edge_tts")
class EdgeTTSAdapter(BaseAdapter[bytes]):
"""EdgeTTS 语音生成适配器 (Free)。
不需要 API Key。
"""
adapter_type = "tts"
adapter_name = "edge_tts"
async def execute(self, text: str, **kwargs) -> bytes:
"""生成语音。"""
# 支持动态指定音色
voice = kwargs.get("voice") or self.config.model or DEFAULT_VOICE
start_time = time.time()
logger.info("edge_tts_generate_start", text_length=len(text), voice=voice)
# EdgeTTS 只能输出到文件,我们需要用临时文件周转一下
# 或者直接 capture stream (communicate) 但 edge-tts 库主要面向文件
# 优化: 使用 communicate 直接获取 bytes无需磁盘IO
communicate = edge_tts.Communicate(text, voice)
audio_data = b""
async for chunk in communicate.stream():
if chunk["type"] == "audio":
audio_data += chunk["data"]
elapsed = time.time() - start_time
logger.info(
"edge_tts_generate_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_data),
)
return audio_data
async def health_check(self) -> bool:
"""检查 EdgeTTS 是否可用 (网络连通性)。"""
try:
# 简单生成一个词
await self.execute("Hi")
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
return 0.0 # Free!
"""EdgeTTS 免费语音生成适配器。"""
import time
import edge_tts
from app.core.logging import get_logger
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# 默认中文女声 (晓晓)
DEFAULT_VOICE = "zh-CN-XiaoxiaoNeural"
@AdapterRegistry.register("tts", "edge_tts")
class EdgeTTSAdapter(BaseAdapter[bytes]):
"""EdgeTTS 语音生成适配器 (Free)。
不需要 API Key。
"""
adapter_type = "tts"
adapter_name = "edge_tts"
async def execute(self, text: str, **kwargs) -> bytes:
"""生成语音。"""
# 支持动态指定音色
voice = kwargs.get("voice") or self.config.model or DEFAULT_VOICE
start_time = time.time()
logger.info("edge_tts_generate_start", text_length=len(text), voice=voice)
# EdgeTTS 只能输出到文件,我们需要用临时文件周转一下
# 或者直接 capture stream (communicate) 但 edge-tts 库主要面向文件
# 优化: 使用 communicate 直接获取 bytes无需磁盘IO
communicate = edge_tts.Communicate(text, voice)
audio_data = b""
async for chunk in communicate.stream():
if chunk["type"] == "audio":
audio_data += chunk["data"]
elapsed = time.time() - start_time
logger.info(
"edge_tts_generate_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_data),
)
return audio_data
async def health_check(self) -> bool:
"""检查 EdgeTTS 是否可用 (网络连通性)。"""
try:
# 简单生成一个词
await self.execute("Hi")
return True
except Exception:
return False
@property
def estimated_cost(self) -> float:
return 0.0 # Free!

View File

@@ -1,104 +1,104 @@
"""ElevenLabs TTS 语音合成适配器。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
ELEVENLABS_API_BASE = "https://api.elevenlabs.io/v1"
DEFAULT_VOICE_ID = "21m00Tcm4TlvDq8ikWAM" # Rachel
@AdapterRegistry.register("tts", "elevenlabs")
class ElevenLabsTtsAdapter(BaseAdapter[bytes]):
"""ElevenLabs TTS 语音合成适配器,返回 MP3 bytes。"""
adapter_type = "tts"
adapter_name = "elevenlabs"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or ELEVENLABS_API_BASE
async def execute(self, text: str, **kwargs) -> bytes:
"""将文本转换为语音 MP3 bytes。"""
start_time = time.time()
logger.info("elevenlabs_tts_start", text_length=len(text))
voice_id = kwargs.get("voice_id") or DEFAULT_VOICE_ID
model_id = kwargs.get("model") or self.config.model or "eleven_multilingual_v2"
stability = kwargs.get("stability", 0.5)
similarity_boost = kwargs.get("similarity_boost", 0.75)
url = f"{self.api_base}/text-to-speech/{voice_id}"
payload = {
"text": text,
"model_id": model_id,
"voice_settings": {
"stability": stability,
"similarity_boost": similarity_boost,
},
}
audio_bytes = await self._call_api(url, payload)
elapsed = time.time() - start_time
logger.info(
"elevenlabs_tts_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_bytes),
)
return audio_bytes
async def health_check(self) -> bool:
"""检查 ElevenLabs API 是否可用。"""
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(
f"{self.api_base}/voices",
headers={"xi-api-key": self.config.api_key},
)
return response.status_code == 200
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估每千字符成本 (USD)。"""
return 0.03
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, url: str, payload: dict) -> bytes:
"""调用 ElevenLabs API带重试机制。"""
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"xi-api-key": self.config.api_key,
"Content-Type": "application/json",
"Accept": "audio/mpeg",
},
)
response.raise_for_status()
return response.content
"""ElevenLabs TTS 语音合成适配器。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
ELEVENLABS_API_BASE = "https://api.elevenlabs.io/v1"
DEFAULT_VOICE_ID = "21m00Tcm4TlvDq8ikWAM" # Rachel
@AdapterRegistry.register("tts", "elevenlabs")
class ElevenLabsTtsAdapter(BaseAdapter[bytes]):
"""ElevenLabs TTS 语音合成适配器,返回 MP3 bytes。"""
adapter_type = "tts"
adapter_name = "elevenlabs"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_base = config.api_base or ELEVENLABS_API_BASE
async def execute(self, text: str, **kwargs) -> bytes:
"""将文本转换为语音 MP3 bytes。"""
start_time = time.time()
logger.info("elevenlabs_tts_start", text_length=len(text))
voice_id = kwargs.get("voice_id") or DEFAULT_VOICE_ID
model_id = kwargs.get("model") or self.config.model or "eleven_multilingual_v2"
stability = kwargs.get("stability", 0.5)
similarity_boost = kwargs.get("similarity_boost", 0.75)
url = f"{self.api_base}/text-to-speech/{voice_id}"
payload = {
"text": text,
"model_id": model_id,
"voice_settings": {
"stability": stability,
"similarity_boost": similarity_boost,
},
}
audio_bytes = await self._call_api(url, payload)
elapsed = time.time() - start_time
logger.info(
"elevenlabs_tts_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_bytes),
)
return audio_bytes
async def health_check(self) -> bool:
"""检查 ElevenLabs API 是否可用。"""
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(
f"{self.api_base}/voices",
headers={"xi-api-key": self.config.api_key},
)
return response.status_code == 200
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估每千字符成本 (USD)。"""
return 0.03
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, url: str, payload: dict) -> bytes:
"""调用 ElevenLabs API带重试机制。"""
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"xi-api-key": self.config.api_key,
"Content-Type": "application/json",
"Accept": "audio/mpeg",
},
)
response.raise_for_status()
return response.content

View File

@@ -1,149 +1,149 @@
"""MiniMax 语音生成适配器 (T2A V2)。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# MiniMax API 配置
DEFAULT_API_URL = "https://api.minimaxi.com/v1/t2a_v2"
DEFAULT_MODEL = "speech-2.6-turbo"
@AdapterRegistry.register("tts", "minimax")
class MiniMaxTTSAdapter(BaseAdapter[bytes]):
"""MiniMax 语音生成适配器。
需要配置:
- api_key: MiniMax API Key
- minimax_group_id: 可选 (取决于使用的模型/账户类型)
"""
adapter_type = "tts"
adapter_name = "minimax"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_url = DEFAULT_API_URL
async def execute(
self,
text: str,
voice_id: str | None = None,
model: str | None = None,
speed: float | None = None,
vol: float | None = None,
pitch: int | None = None,
emotion: str | None = None,
**kwargs,
) -> bytes:
"""生成语音。"""
# 1. 优先使用传入参数
# 2. 其次使用 Adapter 配置里的 default
# 3. 最后使用系统默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
voice_id = voice_id or cfg.get("voice_id") or "male-qn-qingse"
speed = speed if speed is not None else (cfg.get("speed") or 1.0)
vol = vol if vol is not None else (cfg.get("vol") or 1.0)
pitch = pitch if pitch is not None else (cfg.get("pitch") or 0)
emotion = emotion or cfg.get("emotion")
group_id = kwargs.get("group_id") or settings.minimax_group_id
url = self.api_url
if group_id:
url = f"{self.api_url}?GroupId={group_id}"
payload = {
"model": model,
"text": text,
"stream": False,
"voice_setting": {
"voice_id": voice_id,
"speed": speed,
"vol": vol,
"pitch": pitch,
},
"audio_setting": {
"sample_rate": 32000,
"bitrate": 128000,
"format": "mp3",
"channel": 1
}
}
if emotion:
payload["voice_setting"]["emotion"] = emotion
start_time = time.time()
logger.info("minimax_generate_start", text_length=len(text), model=model)
result = await self._call_api(url, payload)
# 错误处理
if result.get("base_resp", {}).get("status_code") != 0:
error_msg = result.get("base_resp", {}).get("status_msg", "未知错误")
raise ValueError(f"MiniMax API 错误: {error_msg}")
# Hex 解码 (关键逻辑,从 primary.py 迁移)
hex_audio = result.get("data", {}).get("audio")
if not hex_audio:
raise ValueError("API 响应中未找到音频数据 (data.audio)")
try:
audio_bytes = bytes.fromhex(hex_audio)
except ValueError:
raise ValueError("MiniMax 返回的音频数据不是有效的 Hex 字符串")
elapsed = time.time() - start_time
logger.info(
"minimax_generate_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_bytes),
)
return audio_bytes
async def health_check(self) -> bool:
"""检查 Minimax API 是否可用。"""
try:
# 尝试生成极短文本
await self.execute("Hi")
return True
except Exception:
return False
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, url: str, payload: dict) -> dict:
"""调用 API带重试机制。"""
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"Authorization": f"Bearer {self.config.api_key}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
return response.json()
"""MiniMax 语音生成适配器 (T2A V2)。"""
import time
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters.base import AdapterConfig, BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
# MiniMax API 配置
DEFAULT_API_URL = "https://api.minimaxi.com/v1/t2a_v2"
DEFAULT_MODEL = "speech-2.6-turbo"
@AdapterRegistry.register("tts", "minimax")
class MiniMaxTTSAdapter(BaseAdapter[bytes]):
"""MiniMax 语音生成适配器。
需要配置:
- api_key: MiniMax API Key
- minimax_group_id: 可选 (取决于使用的模型/账户类型)
"""
adapter_type = "tts"
adapter_name = "minimax"
def __init__(self, config: AdapterConfig):
super().__init__(config)
self.api_url = DEFAULT_API_URL
async def execute(
self,
text: str,
voice_id: str | None = None,
model: str | None = None,
speed: float | None = None,
vol: float | None = None,
pitch: int | None = None,
emotion: str | None = None,
**kwargs,
) -> bytes:
"""生成语音。"""
# 1. 优先使用传入参数
# 2. 其次使用 Adapter 配置里的 default
# 3. 最后使用系统默认值
model = model or self.config.model or DEFAULT_MODEL
cfg = self.config.extra_config or {}
voice_id = voice_id or cfg.get("voice_id") or "male-qn-qingse"
speed = speed if speed is not None else (cfg.get("speed") or 1.0)
vol = vol if vol is not None else (cfg.get("vol") or 1.0)
pitch = pitch if pitch is not None else (cfg.get("pitch") or 0)
emotion = emotion or cfg.get("emotion")
group_id = kwargs.get("group_id") or settings.minimax_group_id
url = self.api_url
if group_id:
url = f"{self.api_url}?GroupId={group_id}"
payload = {
"model": model,
"text": text,
"stream": False,
"voice_setting": {
"voice_id": voice_id,
"speed": speed,
"vol": vol,
"pitch": pitch,
},
"audio_setting": {
"sample_rate": 32000,
"bitrate": 128000,
"format": "mp3",
"channel": 1
}
}
if emotion:
payload["voice_setting"]["emotion"] = emotion
start_time = time.time()
logger.info("minimax_generate_start", text_length=len(text), model=model)
result = await self._call_api(url, payload)
# 错误处理
if result.get("base_resp", {}).get("status_code") != 0:
error_msg = result.get("base_resp", {}).get("status_msg", "未知错误")
raise ValueError(f"MiniMax API 错误: {error_msg}")
# Hex 解码 (关键逻辑,从 primary.py 迁移)
hex_audio = result.get("data", {}).get("audio")
if not hex_audio:
raise ValueError("API 响应中未找到音频数据 (data.audio)")
try:
audio_bytes = bytes.fromhex(hex_audio)
except ValueError:
raise ValueError("MiniMax 返回的音频数据不是有效的 Hex 字符串")
elapsed = time.time() - start_time
logger.info(
"minimax_generate_success",
elapsed_seconds=round(elapsed, 2),
audio_size_bytes=len(audio_bytes),
)
return audio_bytes
async def health_check(self) -> bool:
"""检查 Minimax API 是否可用。"""
try:
# 尝试生成极短文本
await self.execute("Hi")
return True
except Exception:
return False
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
reraise=True,
)
async def _call_api(self, url: str, payload: dict) -> dict:
"""调用 API带重试机制。"""
timeout = self.config.timeout_ms / 1000
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers={
"Authorization": f"Bearer {self.config.api_key}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
return response.json()

View File

@@ -1,196 +1,196 @@
"""成本追踪服务。
记录 API 调用成本,支持预算控制。
"""
from datetime import datetime, timedelta
from decimal import Decimal
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import CostRecord, UserBudget
logger = get_logger(__name__)
class BudgetExceededError(Exception):
"""预算超限错误。"""
def __init__(self, limit_type: str, used: Decimal, limit: Decimal):
self.limit_type = limit_type
self.used = used
self.limit = limit
super().__init__(f"{limit_type} 预算已超限: {used}/{limit} USD")
class CostTracker:
"""成本追踪器。"""
async def record_cost(
self,
db: AsyncSession,
user_id: str,
provider_name: str,
capability: str,
estimated_cost: float,
provider_id: str | None = None,
) -> CostRecord:
"""记录一次 API 调用成本。"""
record = CostRecord(
user_id=user_id,
provider_id=provider_id,
provider_name=provider_name,
capability=capability,
estimated_cost=Decimal(str(estimated_cost)),
)
db.add(record)
await db.commit()
logger.debug(
"cost_recorded",
user_id=user_id,
provider=provider_name,
capability=capability,
cost=estimated_cost,
)
return record
async def get_daily_cost(self, db: AsyncSession, user_id: str) -> Decimal:
"""获取用户今日成本。"""
today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
result = await db.execute(
select(func.sum(CostRecord.estimated_cost)).where(
CostRecord.user_id == user_id,
CostRecord.timestamp >= today_start,
)
)
total = result.scalar()
return Decimal(str(total)) if total else Decimal("0")
async def get_monthly_cost(self, db: AsyncSession, user_id: str) -> Decimal:
"""获取用户本月成本。"""
now = datetime.utcnow()
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
result = await db.execute(
select(func.sum(CostRecord.estimated_cost)).where(
CostRecord.user_id == user_id,
CostRecord.timestamp >= month_start,
)
)
total = result.scalar()
return Decimal(str(total)) if total else Decimal("0")
async def get_cost_by_capability(
self,
db: AsyncSession,
user_id: str,
days: int = 30,
) -> dict[str, Decimal]:
"""按能力类型统计成本。"""
since = datetime.utcnow() - timedelta(days=days)
result = await db.execute(
select(CostRecord.capability, func.sum(CostRecord.estimated_cost))
.where(CostRecord.user_id == user_id, CostRecord.timestamp >= since)
.group_by(CostRecord.capability)
)
return {row[0]: Decimal(str(row[1])) for row in result.all()}
async def check_budget(
self,
db: AsyncSession,
user_id: str,
estimated_cost: float,
) -> bool:
"""检查预算是否允许此次调用。
Returns:
True 如果允许,否则抛出 BudgetExceededError
"""
budget = await self.get_user_budget(db, user_id)
if not budget or not budget.enabled:
return True
# 检查日预算
daily_cost = await self.get_daily_cost(db, user_id)
if daily_cost + Decimal(str(estimated_cost)) > budget.daily_limit_usd:
raise BudgetExceededError("", daily_cost, budget.daily_limit_usd)
# 检查月预算
monthly_cost = await self.get_monthly_cost(db, user_id)
if monthly_cost + Decimal(str(estimated_cost)) > budget.monthly_limit_usd:
raise BudgetExceededError("", monthly_cost, budget.monthly_limit_usd)
return True
async def get_user_budget(self, db: AsyncSession, user_id: str) -> UserBudget | None:
"""获取用户预算配置。"""
result = await db.execute(
select(UserBudget).where(UserBudget.user_id == user_id)
)
return result.scalar_one_or_none()
async def set_user_budget(
self,
db: AsyncSession,
user_id: str,
daily_limit: float | None = None,
monthly_limit: float | None = None,
alert_threshold: float | None = None,
enabled: bool | None = None,
) -> UserBudget:
"""设置用户预算。"""
budget = await self.get_user_budget(db, user_id)
if budget is None:
budget = UserBudget(user_id=user_id)
db.add(budget)
if daily_limit is not None:
budget.daily_limit_usd = Decimal(str(daily_limit))
if monthly_limit is not None:
budget.monthly_limit_usd = Decimal(str(monthly_limit))
if alert_threshold is not None:
budget.alert_threshold = Decimal(str(alert_threshold))
if enabled is not None:
budget.enabled = enabled
await db.commit()
await db.refresh(budget)
return budget
async def get_cost_summary(
self,
db: AsyncSession,
user_id: str,
) -> dict:
"""获取用户成本摘要。"""
daily = await self.get_daily_cost(db, user_id)
monthly = await self.get_monthly_cost(db, user_id)
by_capability = await self.get_cost_by_capability(db, user_id)
budget = await self.get_user_budget(db, user_id)
return {
"daily_cost_usd": float(daily),
"monthly_cost_usd": float(monthly),
"by_capability": {k: float(v) for k, v in by_capability.items()},
"budget": {
"daily_limit_usd": float(budget.daily_limit_usd) if budget else None,
"monthly_limit_usd": float(budget.monthly_limit_usd) if budget else None,
"daily_usage_percent": float(daily / budget.daily_limit_usd * 100)
if budget and budget.daily_limit_usd
else None,
"monthly_usage_percent": float(monthly / budget.monthly_limit_usd * 100)
if budget and budget.monthly_limit_usd
else None,
"enabled": budget.enabled if budget else False,
},
}
# 全局单例
cost_tracker = CostTracker()
"""成本追踪服务。
记录 API 调用成本,支持预算控制。
"""
from datetime import datetime, timedelta
from decimal import Decimal
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import CostRecord, UserBudget
logger = get_logger(__name__)
class BudgetExceededError(Exception):
"""预算超限错误。"""
def __init__(self, limit_type: str, used: Decimal, limit: Decimal):
self.limit_type = limit_type
self.used = used
self.limit = limit
super().__init__(f"{limit_type} 预算已超限: {used}/{limit} USD")
class CostTracker:
"""成本追踪器。"""
async def record_cost(
self,
db: AsyncSession,
user_id: str,
provider_name: str,
capability: str,
estimated_cost: float,
provider_id: str | None = None,
) -> CostRecord:
"""记录一次 API 调用成本。"""
record = CostRecord(
user_id=user_id,
provider_id=provider_id,
provider_name=provider_name,
capability=capability,
estimated_cost=Decimal(str(estimated_cost)),
)
db.add(record)
await db.commit()
logger.debug(
"cost_recorded",
user_id=user_id,
provider=provider_name,
capability=capability,
cost=estimated_cost,
)
return record
async def get_daily_cost(self, db: AsyncSession, user_id: str) -> Decimal:
"""获取用户今日成本。"""
today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
result = await db.execute(
select(func.sum(CostRecord.estimated_cost)).where(
CostRecord.user_id == user_id,
CostRecord.timestamp >= today_start,
)
)
total = result.scalar()
return Decimal(str(total)) if total else Decimal("0")
async def get_monthly_cost(self, db: AsyncSession, user_id: str) -> Decimal:
"""获取用户本月成本。"""
now = datetime.utcnow()
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
result = await db.execute(
select(func.sum(CostRecord.estimated_cost)).where(
CostRecord.user_id == user_id,
CostRecord.timestamp >= month_start,
)
)
total = result.scalar()
return Decimal(str(total)) if total else Decimal("0")
async def get_cost_by_capability(
self,
db: AsyncSession,
user_id: str,
days: int = 30,
) -> dict[str, Decimal]:
"""按能力类型统计成本。"""
since = datetime.utcnow() - timedelta(days=days)
result = await db.execute(
select(CostRecord.capability, func.sum(CostRecord.estimated_cost))
.where(CostRecord.user_id == user_id, CostRecord.timestamp >= since)
.group_by(CostRecord.capability)
)
return {row[0]: Decimal(str(row[1])) for row in result.all()}
async def check_budget(
self,
db: AsyncSession,
user_id: str,
estimated_cost: float,
) -> bool:
"""检查预算是否允许此次调用。
Returns:
True 如果允许,否则抛出 BudgetExceededError
"""
budget = await self.get_user_budget(db, user_id)
if not budget or not budget.enabled:
return True
# 检查日预算
daily_cost = await self.get_daily_cost(db, user_id)
if daily_cost + Decimal(str(estimated_cost)) > budget.daily_limit_usd:
raise BudgetExceededError("", daily_cost, budget.daily_limit_usd)
# 检查月预算
monthly_cost = await self.get_monthly_cost(db, user_id)
if monthly_cost + Decimal(str(estimated_cost)) > budget.monthly_limit_usd:
raise BudgetExceededError("", monthly_cost, budget.monthly_limit_usd)
return True
async def get_user_budget(self, db: AsyncSession, user_id: str) -> UserBudget | None:
"""获取用户预算配置。"""
result = await db.execute(
select(UserBudget).where(UserBudget.user_id == user_id)
)
return result.scalar_one_or_none()
async def set_user_budget(
self,
db: AsyncSession,
user_id: str,
daily_limit: float | None = None,
monthly_limit: float | None = None,
alert_threshold: float | None = None,
enabled: bool | None = None,
) -> UserBudget:
"""设置用户预算。"""
budget = await self.get_user_budget(db, user_id)
if budget is None:
budget = UserBudget(user_id=user_id)
db.add(budget)
if daily_limit is not None:
budget.daily_limit_usd = Decimal(str(daily_limit))
if monthly_limit is not None:
budget.monthly_limit_usd = Decimal(str(monthly_limit))
if alert_threshold is not None:
budget.alert_threshold = Decimal(str(alert_threshold))
if enabled is not None:
budget.enabled = enabled
await db.commit()
await db.refresh(budget)
return budget
async def get_cost_summary(
self,
db: AsyncSession,
user_id: str,
) -> dict:
"""获取用户成本摘要。"""
daily = await self.get_daily_cost(db, user_id)
monthly = await self.get_monthly_cost(db, user_id)
by_capability = await self.get_cost_by_capability(db, user_id)
budget = await self.get_user_budget(db, user_id)
return {
"daily_cost_usd": float(daily),
"monthly_cost_usd": float(monthly),
"by_capability": {k: float(v) for k, v in by_capability.items()},
"budget": {
"daily_limit_usd": float(budget.daily_limit_usd) if budget else None,
"monthly_limit_usd": float(budget.monthly_limit_usd) if budget else None,
"daily_usage_percent": float(daily / budget.daily_limit_usd * 100)
if budget and budget.daily_limit_usd
else None,
"monthly_usage_percent": float(monthly / budget.monthly_limit_usd * 100)
if budget and budget.monthly_limit_usd
else None,
"enabled": budget.enabled if budget else False,
},
}
# 全局单例
cost_tracker = CostTracker()

View File

@@ -1,471 +1,471 @@
"""Memory service handles memory retrieval, scoring, and prompt injection."""
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models import ChildProfile, MemoryItem, StoryUniverse
logger = get_logger(__name__)
class MemoryType:
"""记忆类型常量及配置。"""
# 基础类型
RECENT_STORY = "recent_story"
FAVORITE_CHARACTER = "favorite_character"
SCARY_ELEMENT = "scary_element"
VOCABULARY_GROWTH = "vocabulary_growth"
EMOTIONAL_HIGHLIGHT = "emotional_highlight"
# Phase 1 新增类型
READING_PREFERENCE = "reading_preference" # 阅读偏好
MILESTONE = "milestone" # 里程碑事件
SKILL_MASTERED = "skill_mastered" # 掌握的技能
# 类型配置: (默认权重, 默认TTL天数, 描述)
CONFIG = {
RECENT_STORY: (1.0, 30, "最近阅读的故事"),
FAVORITE_CHARACTER: (1.5, None, "喜欢的角色"), # None = 永久
SCARY_ELEMENT: (2.0, None, "回避的元素"), # 高权重,永久有效
VOCABULARY_GROWTH: (0.8, 90, "词汇积累"),
EMOTIONAL_HIGHLIGHT: (1.2, 60, "情感高光"),
READING_PREFERENCE: (1.0, None, "阅读偏好"),
MILESTONE: (1.5, None, "里程碑事件"),
SKILL_MASTERED: (1.0, 180, "掌握的技能"),
}
@classmethod
def get_default_weight(cls, memory_type: str) -> float:
"""获取类型的默认权重。"""
config = cls.CONFIG.get(memory_type)
return config[0] if config else 1.0
@classmethod
def get_default_ttl(cls, memory_type: str) -> int | None:
"""获取类型的默认 TTL 天数。"""
config = cls.CONFIG.get(memory_type)
return config[1] if config else None
def _decay_factor(days: float) -> float:
"""计算时间衰减因子。"""
if days <= 7:
return 1.0
if days <= 30:
return 0.7
if days <= 90:
return 0.4
return 0.2
async def build_enhanced_memory_context(
profile_id: str | None,
universe_id: str | None,
db: AsyncSession,
) -> str | None:
"""构建增强版记忆上下文(自然语言 Prompt"""
if not profile_id and not universe_id:
return None
context_parts: list[str] = []
# 1. 基础档案 (Identity Layer)
if profile_id:
profile = await db.scalar(select(ChildProfile).where(ChildProfile.id == profile_id))
if profile:
context_parts.append(f"【目标读者】\n姓名:{profile.name}")
if profile.age:
context_parts.append(f"年龄:{profile.age}")
if profile.interests:
context_parts.append(f"兴趣爱好:{''.join(profile.interests)}")
if profile.growth_themes:
context_parts.append(f"当前成长关注点:{''.join(profile.growth_themes)}")
context_parts.append("") # 空行
# 2. 故事宇宙 (Universe Layer)
if universe_id:
universe = await db.scalar(select(StoryUniverse).where(StoryUniverse.id == universe_id))
if universe:
context_parts.append("【故事宇宙设定】")
context_parts.append(f"世界观:{universe.name}")
# 主角
protagonist = universe.protagonist or {}
p_desc = f"{protagonist.get('name', '主角')} ({protagonist.get('personality', '')})"
context_parts.append(f"主角设定:{p_desc}")
# 常驻角色
if universe.recurring_characters:
chars = [f"{c.get('name')} ({c.get('type')})" for c in universe.recurring_characters if isinstance(c, dict)]
context_parts.append(f"已知伙伴:{''.join(chars)}")
# 成就
if universe.achievements:
badges = [str(a.get('type')) for a in universe.achievements if isinstance(a, dict)]
if badges:
context_parts.append(f"已获荣誉:{''.join(badges[:5])}")
context_parts.append("")
# 3. 动态记忆 (Working Memory)
if profile_id:
memories = await _fetch_scored_memories(profile_id, universe_id, db)
if memories:
memory_text = _format_memories_to_prompt(memories)
if memory_text:
context_parts.append("【关键记忆回忆】(请在故事中自然地融入或致敬以下元素)")
context_parts.append(memory_text)
return "\n".join(context_parts)
async def _fetch_scored_memories(
profile_id: str,
universe_id: str | None,
db: AsyncSession,
limit: int = 8
) -> list[MemoryItem]:
"""获取并评分记忆项,返回 Top N。"""
query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
if universe_id:
query = query.where(
(MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
)
# 取最近 50 条进行评分
query = query.order_by(MemoryItem.last_used_at.desc(), MemoryItem.created_at.desc()).limit(50)
result = await db.execute(query)
items = result.scalars().all()
scored: list[tuple[float, MemoryItem]] = []
now = datetime.now(timezone.utc)
for item in items:
reference = item.last_used_at or item.created_at or now
delta_days = max((now - reference).total_seconds() / 86400, 0)
if item.ttl_days and delta_days > item.ttl_days:
continue
score = (item.base_weight or 1.0) * _decay_factor(delta_days)
if score <= 0.1: # 忽略低权重
continue
scored.append((score, item))
scored.sort(key=lambda x: x[0], reverse=True)
return [item for _, item in scored[:limit]]
def _format_memories_to_prompt(memories: list[MemoryItem]) -> str:
"""将记忆项转换为自然语言指令。"""
lines = []
# 分类处理
recent_stories = []
favorites = []
scary = []
vocab = []
for m in memories:
if m.type == MemoryType.RECENT_STORY:
recent_stories.append(m)
elif m.type == MemoryType.FAVORITE_CHARACTER:
favorites.append(m)
elif m.type == MemoryType.SCARY_ELEMENT:
scary.append(m)
elif m.type == MemoryType.VOCABULARY_GROWTH:
vocab.append(m)
# 1. 喜欢的角色
if favorites:
names = []
for m in favorites:
val = m.value
if isinstance(val, dict):
names.append(f"{val.get('name')} ({val.get('description', '')})")
if names:
lines.append(f"- 孩子特别喜欢这些角色,可以让他们客串出场:{', '.join(names)}")
# 2. 避雷区
if scary:
items = []
for m in scary:
val = m.value
if isinstance(val, dict):
items.append(val.get('keyword', ''))
elif isinstance(val, str):
items.append(val)
if items:
lines.append(f"- 【注意禁止】不要出现以下让孩子害怕的元素:{', '.join(items)}")
# 3. 近期故事 (取最近 2 个)
if recent_stories:
lines.append("- 近期经历(可作为彩蛋提及):")
for m in recent_stories[:2]:
val = m.value
if isinstance(val, dict):
title = val.get('title', '未知故事')
lines.append(f" * 之前读过《{title}")
# 4. 词汇积累
if vocab:
words = []
for m in vocab:
val = m.value
if isinstance(val, dict):
words.append(val.get('word'))
if words:
lines.append(f"- 已掌握词汇(可适当复现以巩固):{', '.join([w for w in words if w])}")
return "\n".join(lines)
async def prune_expired_memories(db: AsyncSession) -> int:
"""清理过期的记忆项。
Returns:
删除的记录数量
"""
from sqlalchemy import delete
now = datetime.now(timezone.utc)
# 查找所有设置了 TTL 的项目
stmt = select(MemoryItem).where(MemoryItem.ttl_days.is_not(None))
result = await db.execute(stmt)
candidates = result.scalars().all()
to_delete_ids = []
for item in candidates:
if not item.ttl_days:
continue
reference = item.last_used_at or item.created_at or now
delta_days = (now - reference).total_seconds() / 86400
if delta_days > item.ttl_days:
to_delete_ids.append(item.id)
if not to_delete_ids:
return 0
delete_stmt = delete(MemoryItem).where(MemoryItem.id.in_(to_delete_ids))
await db.execute(delete_stmt)
await db.commit()
logger.info("memory_pruned", count=len(to_delete_ids))
return len(to_delete_ids)
async def create_memory(
db: AsyncSession,
profile_id: str,
memory_type: str,
value: dict,
universe_id: str | None = None,
weight: float | None = None,
ttl_days: int | None = None,
) -> MemoryItem:
"""创建新的记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
memory_type: 记忆类型 (使用 MemoryType 常量)
value: 记忆内容 (JSON 格式)
universe_id: 可选,关联的故事宇宙 ID
weight: 可选,权重 (默认使用类型配置)
ttl_days: 可选,过期天数 (默认使用类型配置)
Returns:
创建的 MemoryItem
"""
memory = MemoryItem(
child_profile_id=profile_id,
universe_id=universe_id,
type=memory_type,
value=value,
base_weight=weight or MemoryType.get_default_weight(memory_type),
ttl_days=ttl_days if ttl_days is not None else MemoryType.get_default_ttl(memory_type),
)
db.add(memory)
await db.commit()
await db.refresh(memory)
logger.info(
"memory_created",
memory_id=memory.id,
profile_id=profile_id,
type=memory_type,
)
return memory
async def update_memory_usage(db: AsyncSession, memory_id: str) -> None:
"""更新记忆的最后使用时间。
Args:
db: 数据库会话
memory_id: 记忆项 ID
"""
result = await db.execute(select(MemoryItem).where(MemoryItem.id == memory_id))
memory = result.scalar_one_or_none()
if memory:
memory.last_used_at = datetime.now(timezone.utc)
await db.commit()
logger.debug("memory_usage_updated", memory_id=memory_id)
async def get_profile_memories(
db: AsyncSession,
profile_id: str,
memory_type: str | None = None,
universe_id: str | None = None,
limit: int = 50,
) -> list[MemoryItem]:
"""获取档案的记忆列表。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
memory_type: 可选,按类型筛选
universe_id: 可选,按宇宙筛选
limit: 返回数量限制
Returns:
MemoryItem 列表
"""
query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
if memory_type:
query = query.where(MemoryItem.type == memory_type)
if universe_id:
query = query.where(
(MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
)
query = query.order_by(MemoryItem.created_at.desc()).limit(limit)
result = await db.execute(query)
return list(result.scalars().all())
async def create_story_memory(
db: AsyncSession,
profile_id: str,
story_id: int,
title: str,
summary: str | None = None,
keywords: list[str] | None = None,
universe_id: str | None = None,
) -> MemoryItem:
"""为故事创建记忆项。
这是一个便捷函数,专门用于在故事阅读后创建 recent_story 类型的记忆。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
story_id: 故事 ID
title: 故事标题
summary: 故事梗概
keywords: 关键词列表
universe_id: 可选,关联的故事宇宙 ID
Returns:
创建的 MemoryItem
"""
value = {
"story_id": story_id,
"title": title,
"summary": summary or "",
"keywords": keywords or [],
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.RECENT_STORY,
value=value,
universe_id=universe_id,
)
async def create_character_memory(
db: AsyncSession,
profile_id: str,
name: str,
description: str | None = None,
source_story_id: int | None = None,
affinity_score: float = 1.0,
universe_id: str | None = None,
) -> MemoryItem:
"""为喜欢的角色创建记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
name: 角色名称
description: 角色描述
source_story_id: 来源故事 ID
affinity_score: 喜爱程度 (0.0-1.0)
universe_id: 可选,关联的故事宇宙 ID
Returns:
创建的 MemoryItem
"""
value = {
"name": name,
"description": description or "",
"source_story_id": source_story_id,
"affinity_score": min(1.0, max(0.0, affinity_score)),
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.FAVORITE_CHARACTER,
value=value,
universe_id=universe_id,
)
async def create_scary_element_memory(
db: AsyncSession,
profile_id: str,
keyword: str,
category: str = "other",
source_story_id: int | None = None,
) -> MemoryItem:
"""为回避元素创建记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
keyword: 回避的关键词
category: 分类 (creature/scene/action/other)
source_story_id: 来源故事 ID
Returns:
创建的 MemoryItem
"""
value = {
"keyword": keyword,
"category": category,
"source_story_id": source_story_id,
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.SCARY_ELEMENT,
value=value,
)
"""Memory service handles memory retrieval, scoring, and prompt injection."""
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models import ChildProfile, MemoryItem, StoryUniverse
logger = get_logger(__name__)
class MemoryType:
"""记忆类型常量及配置。"""
# 基础类型
RECENT_STORY = "recent_story"
FAVORITE_CHARACTER = "favorite_character"
SCARY_ELEMENT = "scary_element"
VOCABULARY_GROWTH = "vocabulary_growth"
EMOTIONAL_HIGHLIGHT = "emotional_highlight"
# Phase 1 新增类型
READING_PREFERENCE = "reading_preference" # 阅读偏好
MILESTONE = "milestone" # 里程碑事件
SKILL_MASTERED = "skill_mastered" # 掌握的技能
# 类型配置: (默认权重, 默认TTL天数, 描述)
CONFIG = {
RECENT_STORY: (1.0, 30, "最近阅读的故事"),
FAVORITE_CHARACTER: (1.5, None, "喜欢的角色"), # None = 永久
SCARY_ELEMENT: (2.0, None, "回避的元素"), # 高权重,永久有效
VOCABULARY_GROWTH: (0.8, 90, "词汇积累"),
EMOTIONAL_HIGHLIGHT: (1.2, 60, "情感高光"),
READING_PREFERENCE: (1.0, None, "阅读偏好"),
MILESTONE: (1.5, None, "里程碑事件"),
SKILL_MASTERED: (1.0, 180, "掌握的技能"),
}
@classmethod
def get_default_weight(cls, memory_type: str) -> float:
"""获取类型的默认权重。"""
config = cls.CONFIG.get(memory_type)
return config[0] if config else 1.0
@classmethod
def get_default_ttl(cls, memory_type: str) -> int | None:
"""获取类型的默认 TTL 天数。"""
config = cls.CONFIG.get(memory_type)
return config[1] if config else None
def _decay_factor(days: float) -> float:
"""计算时间衰减因子。"""
if days <= 7:
return 1.0
if days <= 30:
return 0.7
if days <= 90:
return 0.4
return 0.2
async def build_enhanced_memory_context(
profile_id: str | None,
universe_id: str | None,
db: AsyncSession,
) -> str | None:
"""构建增强版记忆上下文(自然语言 Prompt"""
if not profile_id and not universe_id:
return None
context_parts: list[str] = []
# 1. 基础档案 (Identity Layer)
if profile_id:
profile = await db.scalar(select(ChildProfile).where(ChildProfile.id == profile_id))
if profile:
context_parts.append(f"【目标读者】\n姓名:{profile.name}")
if profile.age:
context_parts.append(f"年龄:{profile.age}")
if profile.interests:
context_parts.append(f"兴趣爱好:{''.join(profile.interests)}")
if profile.growth_themes:
context_parts.append(f"当前成长关注点:{''.join(profile.growth_themes)}")
context_parts.append("") # 空行
# 2. 故事宇宙 (Universe Layer)
if universe_id:
universe = await db.scalar(select(StoryUniverse).where(StoryUniverse.id == universe_id))
if universe:
context_parts.append("【故事宇宙设定】")
context_parts.append(f"世界观:{universe.name}")
# 主角
protagonist = universe.protagonist or {}
p_desc = f"{protagonist.get('name', '主角')} ({protagonist.get('personality', '')})"
context_parts.append(f"主角设定:{p_desc}")
# 常驻角色
if universe.recurring_characters:
chars = [f"{c.get('name')} ({c.get('type')})" for c in universe.recurring_characters if isinstance(c, dict)]
context_parts.append(f"已知伙伴:{''.join(chars)}")
# 成就
if universe.achievements:
badges = [str(a.get('type')) for a in universe.achievements if isinstance(a, dict)]
if badges:
context_parts.append(f"已获荣誉:{''.join(badges[:5])}")
context_parts.append("")
# 3. 动态记忆 (Working Memory)
if profile_id:
memories = await _fetch_scored_memories(profile_id, universe_id, db)
if memories:
memory_text = _format_memories_to_prompt(memories)
if memory_text:
context_parts.append("【关键记忆回忆】(请在故事中自然地融入或致敬以下元素)")
context_parts.append(memory_text)
return "\n".join(context_parts)
async def _fetch_scored_memories(
profile_id: str,
universe_id: str | None,
db: AsyncSession,
limit: int = 8
) -> list[MemoryItem]:
"""获取并评分记忆项,返回 Top N。"""
query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
if universe_id:
query = query.where(
(MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
)
# 取最近 50 条进行评分
query = query.order_by(MemoryItem.last_used_at.desc(), MemoryItem.created_at.desc()).limit(50)
result = await db.execute(query)
items = result.scalars().all()
scored: list[tuple[float, MemoryItem]] = []
now = datetime.now(timezone.utc)
for item in items:
reference = item.last_used_at or item.created_at or now
delta_days = max((now - reference).total_seconds() / 86400, 0)
if item.ttl_days and delta_days > item.ttl_days:
continue
score = (item.base_weight or 1.0) * _decay_factor(delta_days)
if score <= 0.1: # 忽略低权重
continue
scored.append((score, item))
scored.sort(key=lambda x: x[0], reverse=True)
return [item for _, item in scored[:limit]]
def _format_memories_to_prompt(memories: list[MemoryItem]) -> str:
"""将记忆项转换为自然语言指令。"""
lines = []
# 分类处理
recent_stories = []
favorites = []
scary = []
vocab = []
for m in memories:
if m.type == MemoryType.RECENT_STORY:
recent_stories.append(m)
elif m.type == MemoryType.FAVORITE_CHARACTER:
favorites.append(m)
elif m.type == MemoryType.SCARY_ELEMENT:
scary.append(m)
elif m.type == MemoryType.VOCABULARY_GROWTH:
vocab.append(m)
# 1. 喜欢的角色
if favorites:
names = []
for m in favorites:
val = m.value
if isinstance(val, dict):
names.append(f"{val.get('name')} ({val.get('description', '')})")
if names:
lines.append(f"- 孩子特别喜欢这些角色,可以让他们客串出场:{', '.join(names)}")
# 2. 避雷区
if scary:
items = []
for m in scary:
val = m.value
if isinstance(val, dict):
items.append(val.get('keyword', ''))
elif isinstance(val, str):
items.append(val)
if items:
lines.append(f"- 【注意禁止】不要出现以下让孩子害怕的元素:{', '.join(items)}")
# 3. 近期故事 (取最近 2 个)
if recent_stories:
lines.append("- 近期经历(可作为彩蛋提及):")
for m in recent_stories[:2]:
val = m.value
if isinstance(val, dict):
title = val.get('title', '未知故事')
lines.append(f" * 之前读过《{title}")
# 4. 词汇积累
if vocab:
words = []
for m in vocab:
val = m.value
if isinstance(val, dict):
words.append(val.get('word'))
if words:
lines.append(f"- 已掌握词汇(可适当复现以巩固):{', '.join([w for w in words if w])}")
return "\n".join(lines)
async def prune_expired_memories(db: AsyncSession) -> int:
"""清理过期的记忆项。
Returns:
删除的记录数量
"""
from sqlalchemy import delete
now = datetime.now(timezone.utc)
# 查找所有设置了 TTL 的项目
stmt = select(MemoryItem).where(MemoryItem.ttl_days.is_not(None))
result = await db.execute(stmt)
candidates = result.scalars().all()
to_delete_ids = []
for item in candidates:
if not item.ttl_days:
continue
reference = item.last_used_at or item.created_at or now
delta_days = (now - reference).total_seconds() / 86400
if delta_days > item.ttl_days:
to_delete_ids.append(item.id)
if not to_delete_ids:
return 0
delete_stmt = delete(MemoryItem).where(MemoryItem.id.in_(to_delete_ids))
await db.execute(delete_stmt)
await db.commit()
logger.info("memory_pruned", count=len(to_delete_ids))
return len(to_delete_ids)
async def create_memory(
db: AsyncSession,
profile_id: str,
memory_type: str,
value: dict,
universe_id: str | None = None,
weight: float | None = None,
ttl_days: int | None = None,
) -> MemoryItem:
"""创建新的记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
memory_type: 记忆类型 (使用 MemoryType 常量)
value: 记忆内容 (JSON 格式)
universe_id: 可选,关联的故事宇宙 ID
weight: 可选,权重 (默认使用类型配置)
ttl_days: 可选,过期天数 (默认使用类型配置)
Returns:
创建的 MemoryItem
"""
memory = MemoryItem(
child_profile_id=profile_id,
universe_id=universe_id,
type=memory_type,
value=value,
base_weight=weight or MemoryType.get_default_weight(memory_type),
ttl_days=ttl_days if ttl_days is not None else MemoryType.get_default_ttl(memory_type),
)
db.add(memory)
await db.commit()
await db.refresh(memory)
logger.info(
"memory_created",
memory_id=memory.id,
profile_id=profile_id,
type=memory_type,
)
return memory
async def update_memory_usage(db: AsyncSession, memory_id: str) -> None:
"""更新记忆的最后使用时间。
Args:
db: 数据库会话
memory_id: 记忆项 ID
"""
result = await db.execute(select(MemoryItem).where(MemoryItem.id == memory_id))
memory = result.scalar_one_or_none()
if memory:
memory.last_used_at = datetime.now(timezone.utc)
await db.commit()
logger.debug("memory_usage_updated", memory_id=memory_id)
async def get_profile_memories(
db: AsyncSession,
profile_id: str,
memory_type: str | None = None,
universe_id: str | None = None,
limit: int = 50,
) -> list[MemoryItem]:
"""获取档案的记忆列表。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
memory_type: 可选,按类型筛选
universe_id: 可选,按宇宙筛选
limit: 返回数量限制
Returns:
MemoryItem 列表
"""
query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
if memory_type:
query = query.where(MemoryItem.type == memory_type)
if universe_id:
query = query.where(
(MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None))
)
query = query.order_by(MemoryItem.created_at.desc()).limit(limit)
result = await db.execute(query)
return list(result.scalars().all())
async def create_story_memory(
db: AsyncSession,
profile_id: str,
story_id: int,
title: str,
summary: str | None = None,
keywords: list[str] | None = None,
universe_id: str | None = None,
) -> MemoryItem:
"""为故事创建记忆项。
这是一个便捷函数,专门用于在故事阅读后创建 recent_story 类型的记忆。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
story_id: 故事 ID
title: 故事标题
summary: 故事梗概
keywords: 关键词列表
universe_id: 可选,关联的故事宇宙 ID
Returns:
创建的 MemoryItem
"""
value = {
"story_id": story_id,
"title": title,
"summary": summary or "",
"keywords": keywords or [],
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.RECENT_STORY,
value=value,
universe_id=universe_id,
)
async def create_character_memory(
db: AsyncSession,
profile_id: str,
name: str,
description: str | None = None,
source_story_id: int | None = None,
affinity_score: float = 1.0,
universe_id: str | None = None,
) -> MemoryItem:
"""为喜欢的角色创建记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
name: 角色名称
description: 角色描述
source_story_id: 来源故事 ID
affinity_score: 喜爱程度 (0.0-1.0)
universe_id: 可选,关联的故事宇宙 ID
Returns:
创建的 MemoryItem
"""
value = {
"name": name,
"description": description or "",
"source_story_id": source_story_id,
"affinity_score": min(1.0, max(0.0, affinity_score)),
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.FAVORITE_CHARACTER,
value=value,
universe_id=universe_id,
)
async def create_scary_element_memory(
db: AsyncSession,
profile_id: str,
keyword: str,
category: str = "other",
source_story_id: int | None = None,
) -> MemoryItem:
"""为回避元素创建记忆项。
Args:
db: 数据库会话
profile_id: 孩子档案 ID
keyword: 回避的关键词
category: 分类 (creature/scene/action/other)
source_story_id: 来源故事 ID
Returns:
创建的 MemoryItem
"""
value = {
"keyword": keyword,
"category": category,
"source_story_id": source_story_id,
}
return await create_memory(
db=db,
profile_id=profile_id,
memory_type=MemoryType.SCARY_ELEMENT,
value=value,
)

View File

@@ -1,109 +1,109 @@
"""Redis-backed cache for providers loaded from DB."""
import json
from collections import defaultdict
from typing import Literal
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.db.admin_models import Provider
logger = get_logger(__name__)
ProviderType = Literal["text", "image", "tts", "storybook"]
class CachedProvider(BaseModel):
"""Serializable provider configuration matching DB model fields."""
id: str
name: str
type: str
adapter: str
model: str | None = None
api_base: str | None = None
api_key: str | None = None
timeout_ms: int = 60000
max_retries: int = 1
weight: int = 1
priority: int = 0
enabled: bool = True
config_json: dict | None = None
config_ref: str | None = None
# Local memory fallback (L1 cache)
_local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list)
CACHE_KEY = "dreamweaver:providers:config"
async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]:
"""Reload providers from DB and update Redis cache."""
try:
result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712
providers = result.scalars().all()
# Convert to Pydantic models
cached_list = []
for p in providers:
cached_list.append(CachedProvider(
id=p.id,
name=p.name,
type=p.type,
adapter=p.adapter,
model=p.model,
api_base=p.api_base,
api_key=p.api_key,
timeout_ms=p.timeout_ms,
max_retries=p.max_retries,
weight=p.weight,
priority=p.priority,
enabled=p.enabled,
config_json=p.config_json,
config_ref=p.config_ref
))
# Group by type
grouped: dict[str, list[CachedProvider]] = defaultdict(list)
for cp in cached_list:
grouped[cp.type].append(cp)
# Sort
for k in grouped:
grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
# Update Redis
redis = await get_redis()
# Serialize entire dict structure
# Pydantic -> dict -> json
json_data = {k: [p.model_dump() for p in v] for k, v in grouped.items()}
await redis.set(CACHE_KEY, json.dumps(json_data))
# Update local cache
_local_cache.clear()
_local_cache.update(grouped)
return grouped
except Exception as e:
logger.error("failed_to_reload_providers", error=str(e))
raise
async def get_providers(provider_type: ProviderType) -> list[CachedProvider]:
"""Get providers from Redis (preferred) or local fallback."""
try:
redis = await get_redis()
data = await redis.get(CACHE_KEY)
if data:
raw_dict = json.loads(data)
if provider_type in raw_dict:
return [CachedProvider(**item) for item in raw_dict[provider_type]]
return []
except Exception as e:
logger.warning("redis_cache_read_failed", error=str(e))
# Fallback to local memory
return _local_cache.get(provider_type, [])
"""Redis-backed cache for providers loaded from DB."""
import json
from collections import defaultdict
from typing import Literal
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.db.admin_models import Provider
logger = get_logger(__name__)
ProviderType = Literal["text", "image", "tts", "storybook"]
class CachedProvider(BaseModel):
"""Serializable provider configuration matching DB model fields."""
id: str
name: str
type: str
adapter: str
model: str | None = None
api_base: str | None = None
api_key: str | None = None
timeout_ms: int = 60000
max_retries: int = 1
weight: int = 1
priority: int = 0
enabled: bool = True
config_json: dict | None = None
config_ref: str | None = None
# Local memory fallback (L1 cache)
_local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list)
CACHE_KEY = "dreamweaver:providers:config"
async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]:
"""Reload providers from DB and update Redis cache."""
try:
result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712
providers = result.scalars().all()
# Convert to Pydantic models
cached_list = []
for p in providers:
cached_list.append(CachedProvider(
id=p.id,
name=p.name,
type=p.type,
adapter=p.adapter,
model=p.model,
api_base=p.api_base,
api_key=p.api_key,
timeout_ms=p.timeout_ms,
max_retries=p.max_retries,
weight=p.weight,
priority=p.priority,
enabled=p.enabled,
config_json=p.config_json,
config_ref=p.config_ref
))
# Group by type
grouped: dict[str, list[CachedProvider]] = defaultdict(list)
for cp in cached_list:
grouped[cp.type].append(cp)
# Sort
for k in grouped:
grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
# Update Redis
redis = await get_redis()
# Serialize entire dict structure
# Pydantic -> dict -> json
json_data = {k: [p.model_dump() for p in v] for k, v in grouped.items()}
await redis.set(CACHE_KEY, json.dumps(json_data))
# Update local cache
_local_cache.clear()
_local_cache.update(grouped)
return grouped
except Exception as e:
logger.error("failed_to_reload_providers", error=str(e))
raise
async def get_providers(provider_type: ProviderType) -> list[CachedProvider]:
"""Get providers from Redis (preferred) or local fallback."""
try:
redis = await get_redis()
data = await redis.get(CACHE_KEY)
if data:
raw_dict = json.loads(data)
if provider_type in raw_dict:
return [CachedProvider(**item) for item in raw_dict[provider_type]]
return []
except Exception as e:
logger.warning("redis_cache_read_failed", error=str(e))
# Fallback to local memory
return _local_cache.get(provider_type, [])

View File

@@ -1,248 +1,248 @@
"""供应商指标收集和健康检查服务。"""
from datetime import datetime, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import ProviderHealth, ProviderMetrics
if TYPE_CHECKING:
from app.services.adapters.base import BaseAdapter
logger = get_logger(__name__)
# 熔断阈值:连续失败次数
CIRCUIT_BREAKER_THRESHOLD = 3
# 熔断恢复时间(秒)
CIRCUIT_BREAKER_RECOVERY_SECONDS = 60
class MetricsCollector:
"""供应商调用指标收集器。"""
async def record_call(
self,
db: AsyncSession,
provider_id: str,
success: bool,
latency_ms: int | None = None,
cost_usd: float | None = None,
error_message: str | None = None,
request_id: str | None = None,
) -> None:
"""记录一次 API 调用。"""
metric = ProviderMetrics(
provider_id=provider_id,
success=success,
latency_ms=latency_ms,
cost_usd=Decimal(str(cost_usd)) if cost_usd else None,
error_message=error_message,
request_id=request_id,
)
db.add(metric)
await db.commit()
logger.debug(
"metrics_recorded",
provider_id=provider_id,
success=success,
latency_ms=latency_ms,
)
async def get_success_rate(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的成功率。"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(
func.count().filter(ProviderMetrics.success.is_(True)).label("success_count"),
func.count().label("total_count"),
).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
)
)
row = result.one()
success_count, total_count = row.success_count, row.total_count
if total_count == 0:
return 1.0 # 无数据时假设健康
return success_count / total_count
async def get_avg_latency(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的平均延迟(毫秒)。"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(func.avg(ProviderMetrics.latency_ms)).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
ProviderMetrics.latency_ms.isnot(None),
)
)
avg = result.scalar()
return float(avg) if avg else 0.0
async def get_total_cost(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的总成本USD"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(func.sum(ProviderMetrics.cost_usd)).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
)
)
total = result.scalar()
return float(total) if total else 0.0
class HealthChecker:
"""供应商健康检查器。"""
async def check_provider(
self,
db: AsyncSession,
provider_id: str,
adapter: "BaseAdapter",
) -> bool:
"""执行健康检查并更新状态。"""
try:
is_healthy = await adapter.health_check()
except Exception as e:
logger.warning("health_check_failed", provider_id=provider_id, error=str(e))
is_healthy = False
await self.update_health_status(
db,
provider_id,
is_healthy,
error=None if is_healthy else "Health check failed",
)
return is_healthy
async def update_health_status(
self,
db: AsyncSession,
provider_id: str,
is_healthy: bool,
error: str | None = None,
) -> None:
"""更新供应商健康状态(含熔断逻辑)。"""
result = await db.execute(
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
)
health = result.scalar_one_or_none()
now = datetime.utcnow()
if health is None:
health = ProviderHealth(
provider_id=provider_id,
is_healthy=is_healthy,
last_check=now,
consecutive_failures=0 if is_healthy else 1,
last_error=error,
)
db.add(health)
else:
health.last_check = now
if is_healthy:
health.is_healthy = True
health.consecutive_failures = 0
health.last_error = None
else:
health.consecutive_failures += 1
health.last_error = error
# 熔断逻辑
if health.consecutive_failures >= CIRCUIT_BREAKER_THRESHOLD:
health.is_healthy = False
logger.warning(
"circuit_breaker_triggered",
provider_id=provider_id,
consecutive_failures=health.consecutive_failures,
)
await db.commit()
async def record_call_result(
self,
db: AsyncSession,
provider_id: str,
success: bool,
error: str | None = None,
) -> None:
"""根据调用结果更新健康状态。"""
await self.update_health_status(db, provider_id, success, error)
async def get_healthy_providers(
self,
db: AsyncSession,
provider_ids: list[str],
) -> list[str]:
"""获取健康的供应商列表。"""
if not provider_ids:
return []
# 查询所有已记录的健康状态
result = await db.execute(
select(ProviderHealth.provider_id, ProviderHealth.is_healthy).where(
ProviderHealth.provider_id.in_(provider_ids),
)
)
health_map = {row[0]: row[1] for row in result.all()}
# 未记录的供应商默认健康,已记录但不健康的排除
return [
pid for pid in provider_ids
if pid not in health_map or health_map[pid]
]
async def is_healthy(
self,
db: AsyncSession,
provider_id: str,
) -> bool:
"""检查供应商是否健康。"""
result = await db.execute(
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
)
health = result.scalar_one_or_none()
if health is None:
return True # 未记录默认健康
# 检查是否可以恢复
if not health.is_healthy and health.last_check:
recovery_time = health.last_check + timedelta(seconds=CIRCUIT_BREAKER_RECOVERY_SECONDS)
if datetime.utcnow() >= recovery_time:
return True # 允许重试
return health.is_healthy
# 全局单例
metrics_collector = MetricsCollector()
health_checker = HealthChecker()
"""供应商指标收集和健康检查服务。"""
from datetime import datetime, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import ProviderHealth, ProviderMetrics
if TYPE_CHECKING:
from app.services.adapters.base import BaseAdapter
logger = get_logger(__name__)
# 熔断阈值:连续失败次数
CIRCUIT_BREAKER_THRESHOLD = 3
# 熔断恢复时间(秒)
CIRCUIT_BREAKER_RECOVERY_SECONDS = 60
class MetricsCollector:
"""供应商调用指标收集器。"""
async def record_call(
self,
db: AsyncSession,
provider_id: str,
success: bool,
latency_ms: int | None = None,
cost_usd: float | None = None,
error_message: str | None = None,
request_id: str | None = None,
) -> None:
"""记录一次 API 调用。"""
metric = ProviderMetrics(
provider_id=provider_id,
success=success,
latency_ms=latency_ms,
cost_usd=Decimal(str(cost_usd)) if cost_usd else None,
error_message=error_message,
request_id=request_id,
)
db.add(metric)
await db.commit()
logger.debug(
"metrics_recorded",
provider_id=provider_id,
success=success,
latency_ms=latency_ms,
)
async def get_success_rate(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的成功率。"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(
func.count().filter(ProviderMetrics.success.is_(True)).label("success_count"),
func.count().label("total_count"),
).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
)
)
row = result.one()
success_count, total_count = row.success_count, row.total_count
if total_count == 0:
return 1.0 # 无数据时假设健康
return success_count / total_count
async def get_avg_latency(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的平均延迟(毫秒)。"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(func.avg(ProviderMetrics.latency_ms)).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
ProviderMetrics.latency_ms.isnot(None),
)
)
avg = result.scalar()
return float(avg) if avg else 0.0
async def get_total_cost(
self,
db: AsyncSession,
provider_id: str,
window_minutes: int = 60,
) -> float:
"""获取指定时间窗口内的总成本USD"""
since = datetime.utcnow() - timedelta(minutes=window_minutes)
result = await db.execute(
select(func.sum(ProviderMetrics.cost_usd)).where(
ProviderMetrics.provider_id == provider_id,
ProviderMetrics.timestamp >= since,
)
)
total = result.scalar()
return float(total) if total else 0.0
class HealthChecker:
"""供应商健康检查器。"""
async def check_provider(
self,
db: AsyncSession,
provider_id: str,
adapter: "BaseAdapter",
) -> bool:
"""执行健康检查并更新状态。"""
try:
is_healthy = await adapter.health_check()
except Exception as e:
logger.warning("health_check_failed", provider_id=provider_id, error=str(e))
is_healthy = False
await self.update_health_status(
db,
provider_id,
is_healthy,
error=None if is_healthy else "Health check failed",
)
return is_healthy
async def update_health_status(
self,
db: AsyncSession,
provider_id: str,
is_healthy: bool,
error: str | None = None,
) -> None:
"""更新供应商健康状态(含熔断逻辑)。"""
result = await db.execute(
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
)
health = result.scalar_one_or_none()
now = datetime.utcnow()
if health is None:
health = ProviderHealth(
provider_id=provider_id,
is_healthy=is_healthy,
last_check=now,
consecutive_failures=0 if is_healthy else 1,
last_error=error,
)
db.add(health)
else:
health.last_check = now
if is_healthy:
health.is_healthy = True
health.consecutive_failures = 0
health.last_error = None
else:
health.consecutive_failures += 1
health.last_error = error
# 熔断逻辑
if health.consecutive_failures >= CIRCUIT_BREAKER_THRESHOLD:
health.is_healthy = False
logger.warning(
"circuit_breaker_triggered",
provider_id=provider_id,
consecutive_failures=health.consecutive_failures,
)
await db.commit()
async def record_call_result(
self,
db: AsyncSession,
provider_id: str,
success: bool,
error: str | None = None,
) -> None:
"""根据调用结果更新健康状态。"""
await self.update_health_status(db, provider_id, success, error)
async def get_healthy_providers(
self,
db: AsyncSession,
provider_ids: list[str],
) -> list[str]:
"""获取健康的供应商列表。"""
if not provider_ids:
return []
# 查询所有已记录的健康状态
result = await db.execute(
select(ProviderHealth.provider_id, ProviderHealth.is_healthy).where(
ProviderHealth.provider_id.in_(provider_ids),
)
)
health_map = {row[0]: row[1] for row in result.all()}
# 未记录的供应商默认健康,已记录但不健康的排除
return [
pid for pid in provider_ids
if pid not in health_map or health_map[pid]
]
async def is_healthy(
self,
db: AsyncSession,
provider_id: str,
) -> bool:
"""检查供应商是否健康。"""
result = await db.execute(
select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
)
health = result.scalar_one_or_none()
if health is None:
return True # 未记录默认健康
# 检查是否可以恢复
if not health.is_healthy and health.last_check:
recovery_time = health.last_check + timedelta(seconds=CIRCUIT_BREAKER_RECOVERY_SECONDS)
if datetime.utcnow() >= recovery_time:
return True # 允许重试
return health.is_healthy
# 全局单例
metrics_collector = MetricsCollector()
health_checker = HealthChecker()

View File

@@ -1,207 +1,207 @@
"""供应商密钥加密存储服务。
使用 Fernet 对称加密,密钥从 SECRET_KEY 派生。
"""
import base64
import hashlib
from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.admin_models import ProviderSecret
if TYPE_CHECKING:
pass
logger = get_logger(__name__)
class SecretEncryptionError(Exception):
"""密钥加密/解密错误。"""
pass
class SecretService:
"""供应商密钥加密存储服务。"""
_fernet: Fernet | None = None
@classmethod
def _get_fernet(cls) -> Fernet:
"""获取 Fernet 实例,从 SECRET_KEY 派生加密密钥。"""
if cls._fernet is None:
# 从 SECRET_KEY 派生 32 字节密钥
key_bytes = hashlib.sha256(settings.secret_key.encode()).digest()
fernet_key = base64.urlsafe_b64encode(key_bytes)
cls._fernet = Fernet(fernet_key)
return cls._fernet
@classmethod
def encrypt(cls, plaintext: str) -> str:
"""加密明文,返回 base64 编码的密文。
Args:
plaintext: 要加密的明文
Returns:
base64 编码的密文
"""
if not plaintext:
return ""
fernet = cls._get_fernet()
encrypted = fernet.encrypt(plaintext.encode())
return encrypted.decode()
@classmethod
def decrypt(cls, ciphertext: str) -> str:
"""解密密文,返回明文。
Args:
ciphertext: base64 编码的密文
Returns:
解密后的明文
Raises:
SecretEncryptionError: 解密失败
"""
if not ciphertext:
return ""
try:
fernet = cls._get_fernet()
decrypted = fernet.decrypt(ciphertext.encode())
return decrypted.decode()
except InvalidToken as e:
logger.error("secret_decrypt_failed", error=str(e))
raise SecretEncryptionError("密钥解密失败,可能是 SECRET_KEY 已更改") from e
@classmethod
async def get_secret(cls, db: AsyncSession, name: str) -> str | None:
"""从数据库获取并解密密钥。
Args:
db: 数据库会话
name: 密钥名称
Returns:
解密后的密钥值,不存在返回 None
"""
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
return None
return cls.decrypt(secret.encrypted_value)
@classmethod
async def set_secret(cls, db: AsyncSession, name: str, value: str) -> ProviderSecret:
"""存储或更新加密密钥。
Args:
db: 数据库会话
name: 密钥名称
value: 密钥明文值
Returns:
ProviderSecret 实例
"""
encrypted = cls.encrypt(value)
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
secret = ProviderSecret(name=name, encrypted_value=encrypted)
db.add(secret)
else:
secret.encrypted_value = encrypted
await db.commit()
await db.refresh(secret)
logger.info("secret_stored", name=name)
return secret
@classmethod
async def delete_secret(cls, db: AsyncSession, name: str) -> bool:
"""删除密钥。
Args:
db: 数据库会话
name: 密钥名称
Returns:
是否删除成功
"""
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
return False
await db.delete(secret)
await db.commit()
logger.info("secret_deleted", name=name)
return True
@classmethod
async def list_secrets(cls, db: AsyncSession) -> list[str]:
"""列出所有密钥名称(不返回值)。
Args:
db: 数据库会话
Returns:
密钥名称列表
"""
result = await db.execute(select(ProviderSecret.name))
return [row[0] for row in result.fetchall()]
@classmethod
async def get_api_key(
cls,
db: AsyncSession,
provider_api_key: str | None,
config_ref: str | None,
) -> str | None:
"""获取 Provider 的 API Key按优先级查找。
优先级:
1. provider.api_key (数据库明文/加密)
2. provider.config_ref 指向的 ProviderSecret
3. 环境变量 (config_ref 作为变量名)
Args:
db: 数据库会话
provider_api_key: Provider 表中的 api_key 字段
config_ref: Provider 表中的 config_ref 字段
Returns:
API Key 或 None
"""
# 1. 直接使用 provider.api_key
if provider_api_key:
# 尝试解密,如果失败则当作明文
try:
decrypted = cls.decrypt(provider_api_key)
if decrypted:
return decrypted
except SecretEncryptionError:
pass
return provider_api_key
# 2. 从 ProviderSecret 表查找
if config_ref:
secret_value = await cls.get_secret(db, config_ref)
if secret_value:
return secret_value
# 3. 从环境变量查找
env_value = getattr(settings, config_ref.lower(), None)
if env_value:
return env_value
return None
"""供应商密钥加密存储服务。
使用 Fernet 对称加密,密钥从 SECRET_KEY 派生。
"""
import base64
import hashlib
from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.admin_models import ProviderSecret
if TYPE_CHECKING:
pass
logger = get_logger(__name__)
class SecretEncryptionError(Exception):
"""密钥加密/解密错误。"""
pass
class SecretService:
"""供应商密钥加密存储服务。"""
_fernet: Fernet | None = None
@classmethod
def _get_fernet(cls) -> Fernet:
"""获取 Fernet 实例,从 SECRET_KEY 派生加密密钥。"""
if cls._fernet is None:
# 从 SECRET_KEY 派生 32 字节密钥
key_bytes = hashlib.sha256(settings.secret_key.encode()).digest()
fernet_key = base64.urlsafe_b64encode(key_bytes)
cls._fernet = Fernet(fernet_key)
return cls._fernet
@classmethod
def encrypt(cls, plaintext: str) -> str:
"""加密明文,返回 base64 编码的密文。
Args:
plaintext: 要加密的明文
Returns:
base64 编码的密文
"""
if not plaintext:
return ""
fernet = cls._get_fernet()
encrypted = fernet.encrypt(plaintext.encode())
return encrypted.decode()
@classmethod
def decrypt(cls, ciphertext: str) -> str:
"""解密密文,返回明文。
Args:
ciphertext: base64 编码的密文
Returns:
解密后的明文
Raises:
SecretEncryptionError: 解密失败
"""
if not ciphertext:
return ""
try:
fernet = cls._get_fernet()
decrypted = fernet.decrypt(ciphertext.encode())
return decrypted.decode()
except InvalidToken as e:
logger.error("secret_decrypt_failed", error=str(e))
raise SecretEncryptionError("密钥解密失败,可能是 SECRET_KEY 已更改") from e
@classmethod
async def get_secret(cls, db: AsyncSession, name: str) -> str | None:
"""从数据库获取并解密密钥。
Args:
db: 数据库会话
name: 密钥名称
Returns:
解密后的密钥值,不存在返回 None
"""
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
return None
return cls.decrypt(secret.encrypted_value)
@classmethod
async def set_secret(cls, db: AsyncSession, name: str, value: str) -> ProviderSecret:
"""存储或更新加密密钥。
Args:
db: 数据库会话
name: 密钥名称
value: 密钥明文值
Returns:
ProviderSecret 实例
"""
encrypted = cls.encrypt(value)
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
secret = ProviderSecret(name=name, encrypted_value=encrypted)
db.add(secret)
else:
secret.encrypted_value = encrypted
await db.commit()
await db.refresh(secret)
logger.info("secret_stored", name=name)
return secret
@classmethod
async def delete_secret(cls, db: AsyncSession, name: str) -> bool:
"""删除密钥。
Args:
db: 数据库会话
name: 密钥名称
Returns:
是否删除成功
"""
result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
secret = result.scalar_one_or_none()
if secret is None:
return False
await db.delete(secret)
await db.commit()
logger.info("secret_deleted", name=name)
return True
@classmethod
async def list_secrets(cls, db: AsyncSession) -> list[str]:
"""列出所有密钥名称(不返回值)。
Args:
db: 数据库会话
Returns:
密钥名称列表
"""
result = await db.execute(select(ProviderSecret.name))
return [row[0] for row in result.fetchall()]
@classmethod
async def get_api_key(
cls,
db: AsyncSession,
provider_api_key: str | None,
config_ref: str | None,
) -> str | None:
"""获取 Provider 的 API Key按优先级查找。
优先级:
1. provider.api_key (数据库明文/加密)
2. provider.config_ref 指向的 ProviderSecret
3. 环境变量 (config_ref 作为变量名)
Args:
db: 数据库会话
provider_api_key: Provider 表中的 api_key 字段
config_ref: Provider 表中的 config_ref 字段
Returns:
API Key 或 None
"""
# 1. 直接使用 provider.api_key
if provider_api_key:
# 尝试解密,如果失败则当作明文
try:
decrypted = cls.decrypt(provider_api_key)
if decrypted:
return decrypted
except SecretEncryptionError:
pass
return provider_api_key
# 2. 从 ProviderSecret 表查找
if config_ref:
secret_value = await cls.get_secret(db, config_ref)
if secret_value:
return secret_value
# 3. 从环境变量查找
env_value = getattr(settings, config_ref.lower(), None)
if env_value:
return env_value
return None