wip: snapshot full local workspace state
Some checks are pending
Build and Push Docker Images / changes (push) Waiting to run
Build and Push Docker Images / build-backend (push) Blocked by required conditions
Build and Push Docker Images / build-frontend (push) Blocked by required conditions
Build and Push Docker Images / build-admin-frontend (push) Blocked by required conditions
Some checks are pending
Build and Push Docker Images / changes (push) Waiting to run
Build and Push Docker Images / build-backend (push) Blocked by required conditions
Build and Push Docker Images / build-frontend (push) Blocked by required conditions
Build and Push Docker Images / build-admin-frontend (push) Blocked by required conditions
This commit is contained in:
@@ -1,21 +1,21 @@
|
||||
"""适配器模块 - 供应商平台化架构核心。"""
|
||||
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
|
||||
# Image adapters
|
||||
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
# Storybook adapters
|
||||
from app.services.adapters.storybook import primary as _storybook_primary # noqa: F401
|
||||
from app.services.adapters.text import gemini as _text_gemini_adapter # noqa: F401
|
||||
|
||||
# 导入所有适配器以触发注册
|
||||
# Text adapters
|
||||
from app.services.adapters.text import openai as _text_openai_adapter # noqa: F401
|
||||
|
||||
# TTS adapters
|
||||
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
|
||||
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
|
||||
|
||||
__all__ = ["AdapterConfig", "BaseAdapter", "AdapterRegistry"]
|
||||
"""适配器模块 - 供应商平台化架构核心。"""
|
||||
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
|
||||
# Image adapters
|
||||
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
# Storybook adapters
|
||||
from app.services.adapters.storybook import primary as _storybook_primary # noqa: F401
|
||||
from app.services.adapters.text import gemini as _text_gemini_adapter # noqa: F401
|
||||
|
||||
# 导入所有适配器以触发注册
|
||||
# Text adapters
|
||||
from app.services.adapters.text import openai as _text_openai_adapter # noqa: F401
|
||||
|
||||
# TTS adapters
|
||||
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
|
||||
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
|
||||
|
||||
__all__ = ["AdapterConfig", "BaseAdapter", "AdapterRegistry"]
|
||||
|
||||
@@ -1,46 +1,46 @@
|
||||
"""适配器基类定义。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AdapterConfig(BaseModel):
    """Base configuration shared by every provider adapter.

    Only ``api_key`` is required; all other fields have defaults.
    """

    # Credential handed to the provider.
    api_key: str
    # Optional endpoint override.
    api_base: str | None = None
    # Optional model name override.
    model: str | None = None
    # Request timeout, in milliseconds.
    timeout_ms: int = 60_000
    # Maximum number of retries.
    max_retries: int = 3
    # Free-form, provider-specific settings.
    extra_config: dict = {}
|
||||
|
||||
|
||||
class BaseAdapter(ABC, Generic[T]):
    """Abstract base class that every provider adapter must inherit from.

    ``T`` is the type of the value produced by :meth:`execute`.
    """

    # Subclasses must define both identifiers.
    adapter_type: str  # text / image / tts
    adapter_name: str  # text_primary / image_primary / tts_primary

    def __init__(self, config: AdapterConfig):
        """Store the adapter configuration on the instance."""
        self.config = config

    @abstractmethod
    async def execute(self, **kwargs) -> T:
        """Run the adapter logic and return its result."""
        ...

    @abstractmethod
    async def health_check(self) -> bool:
        """Return whether the provider is currently usable."""
        ...

    @property
    @abstractmethod
    def estimated_cost(self) -> float:
        """Estimated cost of a single call, in USD."""
        ...
|
||||
"""适配器基类定义。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AdapterConfig(BaseModel):
    """Base configuration shared by every provider adapter.

    Only ``api_key`` is required; all other fields have defaults.
    """

    # Credential handed to the provider.
    api_key: str
    # Optional endpoint override.
    api_base: str | None = None
    # Optional model name override.
    model: str | None = None
    # Request timeout, in milliseconds.
    timeout_ms: int = 60_000
    # Maximum number of retries.
    max_retries: int = 3
    # Free-form, provider-specific settings.
    extra_config: dict = {}
|
||||
|
||||
|
||||
class BaseAdapter(ABC, Generic[T]):
    """Abstract base class that every provider adapter must inherit from.

    ``T`` is the type of the value produced by :meth:`execute`.
    """

    # Subclasses must define both identifiers.
    adapter_type: str  # text / image / tts
    adapter_name: str  # text_primary / image_primary / tts_primary

    def __init__(self, config: AdapterConfig):
        """Store the adapter configuration on the instance."""
        self.config = config

    @abstractmethod
    async def execute(self, **kwargs) -> T:
        """Run the adapter logic and return its result."""
        ...

    @abstractmethod
    async def health_check(self) -> bool:
        """Return whether the provider is currently usable."""
        ...

    @property
    @abstractmethod
    def estimated_cost(self) -> float:
        """Estimated cost of a single call, in USD."""
        ...
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""图像生成适配器。"""# Image adapters
|
||||
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
|
||||
from app.services.adapters.image import antigravity as _image_antigravity_adapter # noqa: F401
|
||||
"""图像生成适配器。"""# Image adapters
|
||||
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
|
||||
from app.services.adapters.image import antigravity as _image_antigravity_adapter # noqa: F401
|
||||
|
||||
@@ -1,214 +1,214 @@
|
||||
"""Antigravity 图像生成适配器。
|
||||
|
||||
使用 OpenAI 兼容 API 生成图像。
|
||||
支持 gemini-3-pro-image 等模型。
|
||||
"""
|
||||
|
||||
import base64
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_API_BASE = "http://127.0.0.1:8045/v1"
|
||||
DEFAULT_MODEL = "gemini-3-pro-image"
|
||||
DEFAULT_SIZE = "1024x1024"
|
||||
|
||||
# 支持的尺寸映射
|
||||
SUPPORTED_SIZES = {
|
||||
"1024x1024": "1:1",
|
||||
"1280x720": "16:9",
|
||||
"720x1280": "9:16",
|
||||
"1216x896": "4:3",
|
||||
}
|
||||
|
||||
|
||||
@AdapterRegistry.register("image", "antigravity")
|
||||
class AntigravityImageAdapter(BaseAdapter[str]):
    """Antigravity image-generation adapter (OpenAI-compatible API).

    Features:
    - Uses the OpenAI-compatible ``chat.completions`` endpoint.
    - Passes the image size through ``extra_body.size``.
    - Supports models such as ``gemini-3-pro-image``.
    - Returns an image URL or a base64 data URI.
    """

    adapter_type = "image"
    adapter_name = "antigravity"

    # Base64 prefixes of common image magic numbers -> MIME type.
    _BASE64_MAGIC = (
        ("iVBORw0K", "image/png"),
        ("/9j/", "image/jpeg"),
        ("R0lGOD", "image/gif"),
        ("UklGR", "image/webp"),
    )

    def __init__(self, config: AdapterConfig):
        """Build the OpenAI-compatible client from the adapter configuration."""
        super().__init__(config)
        # Fall back to the module default endpoint when none is configured.
        self.api_base = config.api_base or DEFAULT_API_BASE
        self.client = AsyncOpenAI(
            base_url=self.api_base,
            api_key=config.api_key,
            timeout=config.timeout_ms / 1000,  # milliseconds -> seconds
        )

    async def execute(
        self,
        prompt: str,
        model: str | None = None,
        size: str | None = None,
        num_images: int = 1,
        **kwargs,
    ) -> str | list[str]:
        """Generate one image from a prompt and return its URL or data URI.

        Args:
            prompt: Text description of the image.
            model: Model name (gemini-3-pro-image / gemini-3-pro-image-16-9 ...).
            size: Image size (1024x1024, 1280x720, 720x1280, 1216x896).
            num_images: Requested image count; only a single image is
                produced, other values are logged and ignored.

        Returns:
            The image URL or a base64 data URI.

        Raises:
            Exception: Whatever the final retry attempt of the API call raised.
        """
        # Precedence: explicit argument > adapter configuration > module default.
        model = model or self.config.model or DEFAULT_MODEL
        cfg = self.config.extra_config or {}
        size = size or cfg.get("size") or DEFAULT_SIZE

        if num_images != 1:
            # Make the silently ignored parameter visible to operators.
            logger.warning("antigravity_num_images_ignored", requested=num_images)

        # perf_counter is monotonic, so the duration survives clock adjustments.
        started = time.perf_counter()
        logger.info(
            "antigravity_generate_start",
            prompt_length=len(prompt),
            model=model,
            size=size,
        )

        image_url = await self._generate_image(prompt, model, size)

        logger.info(
            "antigravity_generate_success",
            elapsed_seconds=round(time.perf_counter() - started, 2),
            model=model,
        )
        return image_url

    async def health_check(self) -> bool:
        """Return True when a minimal chat completion request succeeds."""
        try:
            # Connectivity probe only; the response body is irrelevant.
            await self.client.chat.completions.create(
                model=self.config.model or DEFAULT_MODEL,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=1,
            )
            return True
        except Exception as e:
            logger.warning("antigravity_health_check_failed", error=str(e))
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost per image in USD (roughly $0.02 on Gemini models)."""
        return 0.02

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((Exception,)),
        reraise=True,
    )
    async def _generate_image(
        self,
        prompt: str,
        model: str,
        size: str,
    ) -> str:
        """Call the Antigravity API and return the generated image reference.

        Returns:
            The image URL or a base64 data URI.

        Raises:
            ValueError: The response is empty or cannot be parsed as an image.
        """
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                extra_body={"size": size},
            )

            # Guard against a missing/empty choices list instead of an IndexError.
            choices = response.choices
            content = choices[0].message.content if choices else None
            if not content:
                raise ValueError("Antigravity 未返回内容")

            # The reply may be a bare URL, base64 data, or a Markdown image.
            image_url = self._extract_image_url(content)
            if image_url:
                return image_url

            raise ValueError(f"Antigravity 响应无法解析为图片: {content[:200]}")

        except Exception as e:
            # Logged on every attempt; tenacity decides whether to retry.
            logger.error(
                "antigravity_generate_error",
                error=str(e),
                model=model,
            )
            raise

    def _extract_image_url(self, content: str) -> str | None:
        """Extract an image reference from the response text.

        Recognised formats, checked in order:
        - Base64 data URI: ``data:image/...``
        - Bare URL: ``https://...`` (first line only)
        - Markdown image: ``![alt](https://...)`` or ``![alt](data:image/...)``
        - Bare base64 payload (wrapped into a data URI)

        Returns:
            The URL / data URI, or ``None`` when nothing matches.
        """
        import re

        content = content.strip()
        if not content:
            return None

        # 1. Already a data URI.
        if content.startswith("data:image/"):
            return content

        # 2. Bare URL; extra lines may follow, keep only the first one.
        if content.startswith(("http://", "https://")):
            return content.splitlines()[0].strip()

        # 3. Markdown image. The target stops at whitespace or ")" so an
        #    optional title (``![a](url "title")``) is not captured.
        md_match = re.search(
            r"!\[.*?\]\(\s*((?:https?://|data:image/)[^)\s]+)", content
        )
        if md_match:
            return md_match.group(1)

        # 4. Bare base64: drop embedded line breaks and label it with the
        #    MIME type implied by its magic prefix.
        if self._looks_like_base64(content):
            payload = "".join(content.split())
            return f"data:{self._sniff_image_mime(payload)};base64,{payload}"

        return None

    def _sniff_image_mime(self, payload: str) -> str:
        """Guess the MIME type of a base64 payload; defaults to PNG."""
        for prefix, mime in self._BASE64_MAGIC:
            if payload.startswith(prefix):
                return mime
        return "image/png"

    def _looks_like_base64(self, s: str) -> bool:
        """Return True when ``s`` looks like a (long) base64 payload.

        Whitespace is ignored, at least 100 characters are required, and
        ``=`` is accepted only as trailing padding.
        """
        import re

        compact = "".join(s.split())
        if len(compact) < 100:
            return False
        return re.fullmatch(r"[A-Za-z0-9+/]+={0,2}", compact) is not None
|
||||
"""Antigravity 图像生成适配器。
|
||||
|
||||
使用 OpenAI 兼容 API 生成图像。
|
||||
支持 gemini-3-pro-image 等模型。
|
||||
"""
|
||||
|
||||
import base64
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_API_BASE = "http://127.0.0.1:8045/v1"
|
||||
DEFAULT_MODEL = "gemini-3-pro-image"
|
||||
DEFAULT_SIZE = "1024x1024"
|
||||
|
||||
# 支持的尺寸映射
|
||||
SUPPORTED_SIZES = {
|
||||
"1024x1024": "1:1",
|
||||
"1280x720": "16:9",
|
||||
"720x1280": "9:16",
|
||||
"1216x896": "4:3",
|
||||
}
|
||||
|
||||
|
||||
@AdapterRegistry.register("image", "antigravity")
|
||||
class AntigravityImageAdapter(BaseAdapter[str]):
    """Antigravity image-generation adapter (OpenAI-compatible API).

    Features:
    - Uses the OpenAI-compatible ``chat.completions`` endpoint.
    - Passes the image size through ``extra_body.size``.
    - Supports models such as ``gemini-3-pro-image``.
    - Returns an image URL or a base64 data URI.
    """

    adapter_type = "image"
    adapter_name = "antigravity"

    # Base64 prefixes of common image magic numbers -> MIME type.
    _BASE64_MAGIC = (
        ("iVBORw0K", "image/png"),
        ("/9j/", "image/jpeg"),
        ("R0lGOD", "image/gif"),
        ("UklGR", "image/webp"),
    )

    def __init__(self, config: AdapterConfig):
        """Build the OpenAI-compatible client from the adapter configuration."""
        super().__init__(config)
        # Fall back to the module default endpoint when none is configured.
        self.api_base = config.api_base or DEFAULT_API_BASE
        self.client = AsyncOpenAI(
            base_url=self.api_base,
            api_key=config.api_key,
            timeout=config.timeout_ms / 1000,  # milliseconds -> seconds
        )

    async def execute(
        self,
        prompt: str,
        model: str | None = None,
        size: str | None = None,
        num_images: int = 1,
        **kwargs,
    ) -> str | list[str]:
        """Generate one image from a prompt and return its URL or data URI.

        Args:
            prompt: Text description of the image.
            model: Model name (gemini-3-pro-image / gemini-3-pro-image-16-9 ...).
            size: Image size (1024x1024, 1280x720, 720x1280, 1216x896).
            num_images: Requested image count; only a single image is
                produced, other values are logged and ignored.

        Returns:
            The image URL or a base64 data URI.

        Raises:
            Exception: Whatever the final retry attempt of the API call raised.
        """
        # Precedence: explicit argument > adapter configuration > module default.
        model = model or self.config.model or DEFAULT_MODEL
        cfg = self.config.extra_config or {}
        size = size or cfg.get("size") or DEFAULT_SIZE

        if num_images != 1:
            # Make the silently ignored parameter visible to operators.
            logger.warning("antigravity_num_images_ignored", requested=num_images)

        # perf_counter is monotonic, so the duration survives clock adjustments.
        started = time.perf_counter()
        logger.info(
            "antigravity_generate_start",
            prompt_length=len(prompt),
            model=model,
            size=size,
        )

        image_url = await self._generate_image(prompt, model, size)

        logger.info(
            "antigravity_generate_success",
            elapsed_seconds=round(time.perf_counter() - started, 2),
            model=model,
        )
        return image_url

    async def health_check(self) -> bool:
        """Return True when a minimal chat completion request succeeds."""
        try:
            # Connectivity probe only; the response body is irrelevant.
            await self.client.chat.completions.create(
                model=self.config.model or DEFAULT_MODEL,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=1,
            )
            return True
        except Exception as e:
            logger.warning("antigravity_health_check_failed", error=str(e))
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost per image in USD (roughly $0.02 on Gemini models)."""
        return 0.02

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((Exception,)),
        reraise=True,
    )
    async def _generate_image(
        self,
        prompt: str,
        model: str,
        size: str,
    ) -> str:
        """Call the Antigravity API and return the generated image reference.

        Returns:
            The image URL or a base64 data URI.

        Raises:
            ValueError: The response is empty or cannot be parsed as an image.
        """
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                extra_body={"size": size},
            )

            # Guard against a missing/empty choices list instead of an IndexError.
            choices = response.choices
            content = choices[0].message.content if choices else None
            if not content:
                raise ValueError("Antigravity 未返回内容")

            # The reply may be a bare URL, base64 data, or a Markdown image.
            image_url = self._extract_image_url(content)
            if image_url:
                return image_url

            raise ValueError(f"Antigravity 响应无法解析为图片: {content[:200]}")

        except Exception as e:
            # Logged on every attempt; tenacity decides whether to retry.
            logger.error(
                "antigravity_generate_error",
                error=str(e),
                model=model,
            )
            raise

    def _extract_image_url(self, content: str) -> str | None:
        """Extract an image reference from the response text.

        Recognised formats, checked in order:
        - Base64 data URI: ``data:image/...``
        - Bare URL: ``https://...`` (first line only)
        - Markdown image: ``![alt](https://...)`` or ``![alt](data:image/...)``
        - Bare base64 payload (wrapped into a data URI)

        Returns:
            The URL / data URI, or ``None`` when nothing matches.
        """
        import re

        content = content.strip()
        if not content:
            return None

        # 1. Already a data URI.
        if content.startswith("data:image/"):
            return content

        # 2. Bare URL; extra lines may follow, keep only the first one.
        if content.startswith(("http://", "https://")):
            return content.splitlines()[0].strip()

        # 3. Markdown image. The target stops at whitespace or ")" so an
        #    optional title (``![a](url "title")``) is not captured.
        md_match = re.search(
            r"!\[.*?\]\(\s*((?:https?://|data:image/)[^)\s]+)", content
        )
        if md_match:
            return md_match.group(1)

        # 4. Bare base64: drop embedded line breaks and label it with the
        #    MIME type implied by its magic prefix.
        if self._looks_like_base64(content):
            payload = "".join(content.split())
            return f"data:{self._sniff_image_mime(payload)};base64,{payload}"

        return None

    def _sniff_image_mime(self, payload: str) -> str:
        """Guess the MIME type of a base64 payload; defaults to PNG."""
        for prefix, mime in self._BASE64_MAGIC:
            if payload.startswith(prefix):
                return mime
        return "image/png"

    def _looks_like_base64(self, s: str) -> bool:
        """Return True when ``s`` looks like a (long) base64 payload.

        Whitespace is ignored, at least 100 characters are required, and
        ``=`` is accepted only as trailing padding.
        """
        import re

        compact = "".join(s.split())
        if len(compact) < 100:
            return False
        return re.fullmatch(r"[A-Za-z0-9+/]+={0,2}", compact) is not None
|
||||
|
||||
@@ -1,252 +1,252 @@
|
||||
"""CQTAI nano 图像生成适配器。
|
||||
|
||||
支持异步生成 + 轮询获取结果。
|
||||
API 文档: https://api.cqtai.com
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_API_BASE = "https://api.cqtai.com"
|
||||
DEFAULT_MODEL = "nano-banana"
|
||||
DEFAULT_RESOLUTION = "2K"
|
||||
DEFAULT_ASPECT_RATIO = "1:1"
|
||||
POLL_INTERVAL_SECONDS = 2
|
||||
MAX_POLL_ATTEMPTS = 60 # 最多轮询 2 分钟
|
||||
|
||||
|
||||
@AdapterRegistry.register("image", "cqtai")
|
||||
class CQTAIImageAdapter(BaseAdapter[str]):
    """CQTAI nano image-generation adapter; returns image URLs.

    Features:
    - Asynchronous generation: submit a task, then poll for the result.
    - Supports nano-banana (standard) and nano-banana-pro (high quality).
    - Supports several resolutions and aspect ratios.
    - Supports image-to-image via ``filesUrl``.
    """

    adapter_type = "image"
    adapter_name = "cqtai"

    def __init__(self, config: AdapterConfig):
        """Store the configuration and resolve the API base URL."""
        super().__init__(config)
        self.api_base = config.api_base or DEFAULT_API_BASE

    async def execute(
        self,
        prompt: str,
        model: str | None = None,
        resolution: str | None = None,
        aspect_ratio: str | None = None,
        num_images: int = 1,
        files_url: list[str] | None = None,
        **kwargs,
    ) -> str | list[str]:
        """Generate images from a prompt and return their URL(s).

        Args:
            prompt: Text description of the image.
            model: Model name (nano-banana / nano-banana-pro).
            resolution: Resolution (1K / 2K / 4K).
            aspect_ratio: Aspect ratio (1:1, 16:9, 9:16, 4:3, 3:4 ...).
            num_images: Number of images, clamped to 1-4.
            files_url: Input image URLs for image-to-image generation.

        Returns:
            A single URL string when exactly one image was requested and
            returned, otherwise the list of URLs.

        Raises:
            ValueError: The API reported a failure or returned no image.
            TimeoutError: The task did not finish within the polling budget.
        """
        # Precedence: explicit argument > adapter extra_config > module default.
        model = model or self.config.model or DEFAULT_MODEL
        cfg = self.config.extra_config or {}
        resolution = resolution or cfg.get("resolution") or DEFAULT_RESOLUTION
        aspect_ratio = aspect_ratio or cfg.get("aspect_ratio") or DEFAULT_ASPECT_RATIO
        num_images = min(max(num_images, 1), 4)  # clamp to 1-4

        # perf_counter is monotonic, so the duration survives clock adjustments.
        started = time.perf_counter()
        logger.info(
            "cqtai_generate_start",
            prompt_length=len(prompt),
            model=model,
            resolution=resolution,
            aspect_ratio=aspect_ratio,
            num_images=num_images,
        )

        # 1. Submit the generation task.
        task_id = await self._submit_task(
            prompt=prompt,
            model=model,
            resolution=resolution,
            aspect_ratio=aspect_ratio,
            num_images=num_images,
            files_url=files_url or [],
        )
        logger.info("cqtai_task_submitted", task_id=task_id)

        # 2. Poll until the task completes, fails, or times out.
        result = await self._poll_result(task_id)

        logger.info(
            "cqtai_generate_success",
            task_id=task_id,
            elapsed_seconds=round(time.perf_counter() - started, 2),
            image_count=len(result) if isinstance(result, list) else 1,
        )

        # A single requested image is unwrapped to a plain string.
        if num_images == 1 and isinstance(result, list) and len(result) == 1:
            return result[0]
        return result

    async def health_check(self) -> bool:
        """Return True when the CQTAI API host answers at all.

        Error statuses such as 401 or 404 still prove the service is reachable.
        """
        try:
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.get(
                    f"{self.api_base}/api/cqt/info/nano",
                    params={"id": "health_check_test"},
                    headers={"Authorization": self.config.api_key},
                )
                return response.status_code in (200, 400, 401, 403, 404)
        except Exception as e:
            # Log the cause so an unhealthy adapter can be diagnosed.
            logger.warning("cqtai_health_check_failed", error=str(e))
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost per image in USD.

        nano-banana: CNY 0.1, about $0.014
        nano-banana-pro: CNY 0.2, about $0.028
        """
        model = self.config.model or DEFAULT_MODEL
        return 0.028 if model == "nano-banana-pro" else 0.014

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        # httpx.TimeoutException is a subclass of httpx.HTTPError.
        retry=retry_if_exception_type(httpx.HTTPError),
        reraise=True,
    )
    async def _submit_task(
        self,
        prompt: str,
        model: str,
        resolution: str,
        aspect_ratio: str,
        num_images: int,
        files_url: list[str],
    ) -> str:
        """Submit an image-generation task and return its task ID.

        Raises:
            httpx.HTTPError: Transport or HTTP status error after all retries.
            ValueError: The API rejected the task or returned no task ID.
        """
        payload = {
            "prompt": prompt,
            "numImages": num_images,
            "aspectRatio": aspect_ratio,
            "filesUrl": files_url,
        }

        # Optional fields are sent only when they differ from the local default.
        # NOTE(review): this assumes the server defaults match DEFAULT_MODEL and
        # DEFAULT_RESOLUTION; confirm against the API documentation.
        if model != DEFAULT_MODEL:
            payload["model"] = model
        if resolution != DEFAULT_RESOLUTION:
            payload["resolution"] = resolution

        async with httpx.AsyncClient(timeout=self.config.timeout_ms / 1000) as client:
            response = await client.post(
                f"{self.api_base}/api/cqt/generator/nano",
                json=payload,
                headers={
                    "Authorization": self.config.api_key,
                    "Content-Type": "application/json",
                },
            )
            response.raise_for_status()
            data = response.json()

        if data.get("code") != 200:
            raise ValueError(f"CQTAI 任务提交失败: {data.get('msg', '未知错误')}")

        task_id = data.get("data")
        if not task_id:
            raise ValueError("CQTAI 未返回任务 ID")
        return task_id

    async def _poll_result(self, task_id: str) -> list[str]:
        """Poll the task until it finishes and return the image URLs.

        Returns:
            The list of image URLs.

        Raises:
            httpx.HTTPError: A poll request failed.
            ValueError: The query failed, the task failed, or no URL came back.
            TimeoutError: The task was still pending after MAX_POLL_ATTEMPTS polls.
        """
        url = f"{self.api_base}/api/cqt/info/nano"
        headers = {"Authorization": self.config.api_key}

        # One client for the whole polling session so connections are reused.
        async with httpx.AsyncClient(timeout=self.config.timeout_ms / 1000) as client:
            for attempt in range(MAX_POLL_ATTEMPTS):
                response = await client.get(
                    url, params={"id": task_id}, headers=headers
                )
                response.raise_for_status()
                data = response.json()

                if data.get("code") != 200:
                    raise ValueError(f"CQTAI 查询失败: {data.get('msg', '未知错误')}")

                # "data" may be present but null while the task is queued.
                result_data = data.get("data") or {}
                status = result_data.get("status")

                if status == "completed":
                    images = result_data.get("images") or []
                    if not images:
                        # Fall back to the single-URL response shapes.
                        image_url = result_data.get("imageUrl") or result_data.get("url")
                        if image_url:
                            images = [image_url]
                    if not images:
                        raise ValueError("CQTAI 未返回图片 URL")
                    return images

                if status == "failed":
                    error_msg = result_data.get("error", "生成失败")
                    raise ValueError(f"CQTAI 图像生成失败: {error_msg}")

                # Still pending: wait before the next poll.
                logger.debug(
                    "cqtai_poll_waiting",
                    task_id=task_id,
                    attempt=attempt + 1,
                    status=status,
                )
                # No point sleeping after the final attempt.
                if attempt + 1 < MAX_POLL_ATTEMPTS:
                    await asyncio.sleep(POLL_INTERVAL_SECONDS)

        raise TimeoutError(f"CQTAI 任务超时: {task_id}")
|
||||
"""CQTAI nano 图像生成适配器。
|
||||
|
||||
支持异步生成 + 轮询获取结果。
|
||||
API 文档: https://api.cqtai.com
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_API_BASE = "https://api.cqtai.com"
|
||||
DEFAULT_MODEL = "nano-banana"
|
||||
DEFAULT_RESOLUTION = "2K"
|
||||
DEFAULT_ASPECT_RATIO = "1:1"
|
||||
POLL_INTERVAL_SECONDS = 2
|
||||
MAX_POLL_ATTEMPTS = 60 # 最多轮询 2 分钟
|
||||
|
||||
|
||||
@AdapterRegistry.register("image", "cqtai")
|
||||
class CQTAIImageAdapter(BaseAdapter[str]):
    """CQTAI nano image-generation adapter; returns image URLs.

    Features:
    - Asynchronous generation: submit a task, then poll for the result.
    - Supports nano-banana (standard) and nano-banana-pro (high quality).
    - Supports several resolutions and aspect ratios.
    - Supports image-to-image via ``filesUrl``.
    """

    adapter_type = "image"
    adapter_name = "cqtai"

    def __init__(self, config: AdapterConfig):
        """Store the configuration and resolve the API base URL."""
        super().__init__(config)
        self.api_base = config.api_base or DEFAULT_API_BASE

    async def execute(
        self,
        prompt: str,
        model: str | None = None,
        resolution: str | None = None,
        aspect_ratio: str | None = None,
        num_images: int = 1,
        files_url: list[str] | None = None,
        **kwargs,
    ) -> str | list[str]:
        """Generate images from a prompt and return their URL(s).

        Args:
            prompt: Text description of the image.
            model: Model name (nano-banana / nano-banana-pro).
            resolution: Resolution (1K / 2K / 4K).
            aspect_ratio: Aspect ratio (1:1, 16:9, 9:16, 4:3, 3:4 ...).
            num_images: Number of images, clamped to 1-4.
            files_url: Input image URLs for image-to-image generation.

        Returns:
            A single URL string when exactly one image was requested and
            returned, otherwise the list of URLs.

        Raises:
            ValueError: The API reported a failure or returned no image.
            TimeoutError: The task did not finish within the polling budget.
        """
        # Precedence: explicit argument > adapter extra_config > module default.
        model = model or self.config.model or DEFAULT_MODEL
        cfg = self.config.extra_config or {}
        resolution = resolution or cfg.get("resolution") or DEFAULT_RESOLUTION
        aspect_ratio = aspect_ratio or cfg.get("aspect_ratio") or DEFAULT_ASPECT_RATIO
        num_images = min(max(num_images, 1), 4)  # clamp to 1-4

        # perf_counter is monotonic, so the duration survives clock adjustments.
        started = time.perf_counter()
        logger.info(
            "cqtai_generate_start",
            prompt_length=len(prompt),
            model=model,
            resolution=resolution,
            aspect_ratio=aspect_ratio,
            num_images=num_images,
        )

        # 1. Submit the generation task.
        task_id = await self._submit_task(
            prompt=prompt,
            model=model,
            resolution=resolution,
            aspect_ratio=aspect_ratio,
            num_images=num_images,
            files_url=files_url or [],
        )
        logger.info("cqtai_task_submitted", task_id=task_id)

        # 2. Poll until the task completes, fails, or times out.
        result = await self._poll_result(task_id)

        logger.info(
            "cqtai_generate_success",
            task_id=task_id,
            elapsed_seconds=round(time.perf_counter() - started, 2),
            image_count=len(result) if isinstance(result, list) else 1,
        )

        # A single requested image is unwrapped to a plain string.
        if num_images == 1 and isinstance(result, list) and len(result) == 1:
            return result[0]
        return result

    async def health_check(self) -> bool:
        """Return True when the CQTAI API host answers at all.

        Error statuses such as 401 or 404 still prove the service is reachable.
        """
        try:
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.get(
                    f"{self.api_base}/api/cqt/info/nano",
                    params={"id": "health_check_test"},
                    headers={"Authorization": self.config.api_key},
                )
                return response.status_code in (200, 400, 401, 403, 404)
        except Exception as e:
            # Log the cause so an unhealthy adapter can be diagnosed.
            logger.warning("cqtai_health_check_failed", error=str(e))
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost per image in USD.

        nano-banana: CNY 0.1, about $0.014
        nano-banana-pro: CNY 0.2, about $0.028
        """
        model = self.config.model or DEFAULT_MODEL
        return 0.028 if model == "nano-banana-pro" else 0.014

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        # httpx.TimeoutException is a subclass of httpx.HTTPError.
        retry=retry_if_exception_type(httpx.HTTPError),
        reraise=True,
    )
    async def _submit_task(
        self,
        prompt: str,
        model: str,
        resolution: str,
        aspect_ratio: str,
        num_images: int,
        files_url: list[str],
    ) -> str:
        """Submit an image-generation task and return its task ID.

        Raises:
            httpx.HTTPError: Transport or HTTP status error after all retries.
            ValueError: The API rejected the task or returned no task ID.
        """
        payload = {
            "prompt": prompt,
            "numImages": num_images,
            "aspectRatio": aspect_ratio,
            "filesUrl": files_url,
        }

        # Optional fields are sent only when they differ from the local default.
        # NOTE(review): this assumes the server defaults match DEFAULT_MODEL and
        # DEFAULT_RESOLUTION; confirm against the API documentation.
        if model != DEFAULT_MODEL:
            payload["model"] = model
        if resolution != DEFAULT_RESOLUTION:
            payload["resolution"] = resolution

        async with httpx.AsyncClient(timeout=self.config.timeout_ms / 1000) as client:
            response = await client.post(
                f"{self.api_base}/api/cqt/generator/nano",
                json=payload,
                headers={
                    "Authorization": self.config.api_key,
                    "Content-Type": "application/json",
                },
            )
            response.raise_for_status()
            data = response.json()

        if data.get("code") != 200:
            raise ValueError(f"CQTAI 任务提交失败: {data.get('msg', '未知错误')}")

        task_id = data.get("data")
        if not task_id:
            raise ValueError("CQTAI 未返回任务 ID")
        return task_id

    async def _poll_result(self, task_id: str) -> list[str]:
        """Poll the task until it finishes and return the image URLs.

        Returns:
            The list of image URLs.

        Raises:
            httpx.HTTPError: A poll request failed.
            ValueError: The query failed, the task failed, or no URL came back.
            TimeoutError: The task was still pending after MAX_POLL_ATTEMPTS polls.
        """
        url = f"{self.api_base}/api/cqt/info/nano"
        headers = {"Authorization": self.config.api_key}

        # One client for the whole polling session so connections are reused.
        async with httpx.AsyncClient(timeout=self.config.timeout_ms / 1000) as client:
            for attempt in range(MAX_POLL_ATTEMPTS):
                response = await client.get(
                    url, params={"id": task_id}, headers=headers
                )
                response.raise_for_status()
                data = response.json()

                if data.get("code") != 200:
                    raise ValueError(f"CQTAI 查询失败: {data.get('msg', '未知错误')}")

                # "data" may be present but null while the task is queued.
                result_data = data.get("data") or {}
                status = result_data.get("status")

                if status == "completed":
                    images = result_data.get("images") or []
                    if not images:
                        # Fall back to the single-URL response shapes.
                        image_url = result_data.get("imageUrl") or result_data.get("url")
                        if image_url:
                            images = [image_url]
                    if not images:
                        raise ValueError("CQTAI 未返回图片 URL")
                    return images

                if status == "failed":
                    error_msg = result_data.get("error", "生成失败")
                    raise ValueError(f"CQTAI 图像生成失败: {error_msg}")

                # Still pending: wait before the next poll.
                logger.debug(
                    "cqtai_poll_waiting",
                    task_id=task_id,
                    attempt=attempt + 1,
                    status=status,
                )
                # No point sleeping after the final attempt.
                if attempt + 1 < MAX_POLL_ATTEMPTS:
                    await asyncio.sleep(POLL_INTERVAL_SECONDS)

        raise TimeoutError(f"CQTAI 任务超时: {task_id}")
|
||||
|
||||
@@ -1,73 +1,73 @@
|
||||
"""适配器注册表 - 支持动态注册和工厂创建。"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
|
||||
|
||||
class AdapterRegistry:
|
||||
"""适配器注册表,管理所有已注册的适配器类。"""
|
||||
|
||||
_adapters: dict[str, type["BaseAdapter"]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, adapter_type: str, adapter_name: str):
|
||||
"""装饰器:注册适配器类。
|
||||
|
||||
用法:
|
||||
"""适配器注册表 - 支持动态注册和工厂创建。"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
|
||||
|
||||
class AdapterRegistry:
|
||||
"""适配器注册表,管理所有已注册的适配器类。"""
|
||||
|
||||
_adapters: dict[str, type["BaseAdapter"]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, adapter_type: str, adapter_name: str):
|
||||
"""装饰器:注册适配器类。
|
||||
|
||||
用法:
|
||||
@AdapterRegistry.register("text", "text_primary")
|
||||
class TextPrimaryAdapter(BaseAdapter[StoryOutput]):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(adapter_class: type["BaseAdapter"]):
|
||||
key = f"{adapter_type}:{adapter_name}"
|
||||
cls._adapters[key] = adapter_class
|
||||
# 自动设置类属性
|
||||
adapter_class.adapter_type = adapter_type
|
||||
adapter_class.adapter_name = adapter_name
|
||||
return adapter_class
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def get(cls, adapter_type: str, adapter_name: str) -> type["BaseAdapter"] | None:
|
||||
"""获取已注册的适配器类。"""
|
||||
key = f"{adapter_type}:{adapter_name}"
|
||||
return cls._adapters.get(key)
|
||||
|
||||
@classmethod
|
||||
def list_adapters(cls, adapter_type: str | None = None) -> list[str]:
|
||||
"""列出所有已注册的适配器。
|
||||
|
||||
Args:
|
||||
adapter_type: 可选,筛选特定类型 (text/image/tts)
|
||||
|
||||
Returns:
|
||||
适配器键列表,格式为 "type:name"
|
||||
"""
|
||||
if adapter_type:
|
||||
return [k for k in cls._adapters if k.startswith(f"{adapter_type}:")]
|
||||
return list(cls._adapters.keys())
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
adapter_type: str,
|
||||
adapter_name: str,
|
||||
config: "AdapterConfig",
|
||||
) -> "BaseAdapter":
|
||||
"""工厂方法:创建适配器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 适配器未注册
|
||||
"""
|
||||
adapter_class = cls.get(adapter_type, adapter_name)
|
||||
if not adapter_class:
|
||||
available = cls.list_adapters(adapter_type)
|
||||
raise ValueError(
|
||||
f"适配器 '{adapter_type}:{adapter_name}' 未注册。"
|
||||
f"可用: {available}"
|
||||
)
|
||||
return adapter_class(config)
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(adapter_class: type["BaseAdapter"]):
|
||||
key = f"{adapter_type}:{adapter_name}"
|
||||
cls._adapters[key] = adapter_class
|
||||
# 自动设置类属性
|
||||
adapter_class.adapter_type = adapter_type
|
||||
adapter_class.adapter_name = adapter_name
|
||||
return adapter_class
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def get(cls, adapter_type: str, adapter_name: str) -> type["BaseAdapter"] | None:
|
||||
"""获取已注册的适配器类。"""
|
||||
key = f"{adapter_type}:{adapter_name}"
|
||||
return cls._adapters.get(key)
|
||||
|
||||
@classmethod
|
||||
def list_adapters(cls, adapter_type: str | None = None) -> list[str]:
|
||||
"""列出所有已注册的适配器。
|
||||
|
||||
Args:
|
||||
adapter_type: 可选,筛选特定类型 (text/image/tts)
|
||||
|
||||
Returns:
|
||||
适配器键列表,格式为 "type:name"
|
||||
"""
|
||||
if adapter_type:
|
||||
return [k for k in cls._adapters if k.startswith(f"{adapter_type}:")]
|
||||
return list(cls._adapters.keys())
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
adapter_type: str,
|
||||
adapter_name: str,
|
||||
config: "AdapterConfig",
|
||||
) -> "BaseAdapter":
|
||||
"""工厂方法:创建适配器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 适配器未注册
|
||||
"""
|
||||
adapter_class = cls.get(adapter_type, adapter_name)
|
||||
if not adapter_class:
|
||||
available = cls.list_adapters(adapter_type)
|
||||
raise ValueError(
|
||||
f"适配器 '{adapter_type}:{adapter_name}' 未注册。"
|
||||
f"可用: {available}"
|
||||
)
|
||||
return adapter_class(config)
|
||||
|
||||
@@ -1 +1 @@
|
||||
"""Storybook 适配器模块。"""
|
||||
"""Storybook 适配器模块。"""
|
||||
|
||||
@@ -1,195 +1,195 @@
|
||||
"""Storybook 适配器 - 生成可翻页的分页故事书。"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.prompts import (
|
||||
RANDOM_ELEMENTS,
|
||||
SYSTEM_INSTRUCTION_STORYBOOK,
|
||||
USER_PROMPT_STORYBOOK,
|
||||
)
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StorybookPage:
|
||||
"""故事书单页。"""
|
||||
|
||||
page_number: int
|
||||
text: str
|
||||
image_prompt: str
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
@dataclass
class Storybook:
    """Storybook output."""

    title: str
    main_character: str
    art_style: str
    # Each instance gets its own list (no shared mutable default).
    pages: list[StorybookPage] = field(default_factory=list)
    cover_prompt: str = field(default="")
    # Cover image URL; None until one has been set.
    cover_url: str | None = field(default=None)
|
||||
|
||||
|
||||
@AdapterRegistry.register("storybook", "storybook_primary")
class StorybookPrimaryAdapter(BaseAdapter[Storybook]):
    """Default storybook generation adapter.

    Produces a paginated storybook structure with text and an image prompt
    for every page. Image generation has to be done separately through an
    image adapter.
    """

    adapter_type = "storybook"
    adapter_name = "storybook_primary"

    async def execute(
        self,
        keywords: str,
        page_count: int = 6,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> Storybook:
        """Generate a paginated storybook.

        Args:
            keywords: Story keywords inserted into the user prompt.
            page_count: Requested number of pages; clamped to 4-12.
            education_theme: Educational theme; defaults to "成长".
            memory_context: Optional memory context; sent as "" when absent.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Storybook with title, page list and cover prompt.

        Raises:
            ValueError: The API reply has no candidates, no text part, or
                the text is not a JSON object.
        """
        start_time = time.perf_counter()
        page_count = max(4, min(page_count, 12))  # clamp to 4-12 pages

        logger.info(
            "storybook_generate_start",
            keywords=keywords,
            page_count=page_count,
            has_memory=bool(memory_context),
        )

        theme = education_theme or "成长"
        random_element = random.choice(RANDOM_ELEMENTS)

        prompt = USER_PROMPT_STORYBOOK.format(
            keywords=keywords,
            education_theme=theme,
            random_element=random_element,
            page_count=page_count,
            memory_context=memory_context or "",
        )

        payload = {
            "system_instruction": {"parts": [{"text": SYSTEM_INSTRUCTION_STORYBOOK}]},
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "responseMimeType": "application/json",
                "temperature": 0.95,
                "topP": 0.9,
            },
        }

        result = await self._call_api(payload)
        parsed = self._parse_story_json(self._extract_text(result))

        # Build the Storybook object; a null "pages" is treated as empty and
        # entries that are not JSON objects are skipped.
        pages = [
            StorybookPage(
                page_number=p.get("page_number", i + 1),
                text=p.get("text", ""),
                image_prompt=p.get("image_prompt", ""),
            )
            for i, p in enumerate(parsed.get("pages") or [])
            if isinstance(p, dict)
        ]

        storybook = Storybook(
            title=parsed.get("title", "未命名故事"),
            main_character=parsed.get("main_character", ""),
            art_style=parsed.get("art_style", ""),
            pages=pages,
            cover_prompt=parsed.get("cover_prompt", ""),
        )

        elapsed = time.perf_counter() - start_time
        logger.info(
            "storybook_generate_success",
            elapsed_seconds=round(elapsed, 2),
            title=storybook.title,
            page_count=len(pages),
        )

        return storybook

    @staticmethod
    def _extract_text(result: dict) -> str:
        """Return the text of the first part of the first candidate."""
        candidates = result.get("candidates") or []
        if not candidates:
            raise ValueError("Storybook 服务未返回内容")

        # "content" may be present but null, so do not rely on the .get default.
        parts = (candidates[0].get("content") or {}).get("parts") or []
        if not parts or "text" not in parts[0]:
            raise ValueError("Storybook 服务响应缺少文本")

        return parts[0]["text"]

    @staticmethod
    def _parse_story_json(response_text: str) -> dict:
        """Parse the reply text as a JSON object, tolerating a Markdown fence."""
        cleaned = response_text.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence (with optional language tag) and the
            # closing fence, whatever whitespace surrounds them.
            cleaned = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", cleaned)

        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Storybook JSON 解析失败: {exc}") from exc

        if not isinstance(parsed, dict):
            raise ValueError("Storybook JSON 解析失败: 顶层结构不是对象")
        return parsed

    async def health_check(self) -> bool:
        """Return True when a minimal generateContent request succeeds."""
        payload = {
            "contents": [{"parts": [{"text": "Hi"}]}],
            "generationConfig": {"maxOutputTokens": 10},
        }
        try:
            await self._call_api(payload)
        except Exception as exc:  # best-effort probe: any failure means unhealthy
            # Log only the exception type; the error text may contain the
            # request URL, which carries the API key.
            logger.warning("storybook_health_check_failed", error_type=type(exc).__name__)
            return False
        return True

    @property
    def estimated_cost(self) -> float:
        """Estimated cost (text generation only, images excluded)."""
        return 0.002  # slightly pricier than a plain story: longer output

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, payload: dict) -> dict:
        """POST ``payload`` to the generateContent endpoint, with retries.

        Up to three attempts are made on httpx errors; the last error is
        re-raised unchanged.
        """
        model = self.config.model or "gemini-2.0-flash"
        # NOTE(review): the API key travels in the query string, so it can
        # show up in URLs quoted by error messages and access logs.
        url = f"{TEXT_API_BASE}/{model}:generateContent?key={self.config.api_key}"
        timeout = self.config.timeout_ms / 1000

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            return response.json()
|
||||
"""Storybook 适配器 - 生成可翻页的分页故事书。"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.prompts import (
|
||||
RANDOM_ELEMENTS,
|
||||
SYSTEM_INSTRUCTION_STORYBOOK,
|
||||
USER_PROMPT_STORYBOOK,
|
||||
)
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StorybookPage:
|
||||
"""故事书单页。"""
|
||||
|
||||
page_number: int
|
||||
text: str
|
||||
image_prompt: str
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
@dataclass
class Storybook:
    """Storybook output."""

    title: str
    main_character: str
    art_style: str
    # Each instance gets its own list (no shared mutable default).
    pages: list[StorybookPage] = field(default_factory=list)
    cover_prompt: str = field(default="")
    # Cover image URL; None until one has been set.
    cover_url: str | None = field(default=None)
|
||||
|
||||
|
||||
@AdapterRegistry.register("storybook", "storybook_primary")
class StorybookPrimaryAdapter(BaseAdapter[Storybook]):
    """Default storybook generation adapter.

    Produces a paginated storybook structure with text and an image prompt
    for every page. Image generation has to be done separately through an
    image adapter.
    """

    adapter_type = "storybook"
    adapter_name = "storybook_primary"

    async def execute(
        self,
        keywords: str,
        page_count: int = 6,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> Storybook:
        """Generate a paginated storybook.

        Args:
            keywords: Story keywords inserted into the user prompt.
            page_count: Requested number of pages; clamped to 4-12.
            education_theme: Educational theme; defaults to "成长".
            memory_context: Optional memory context; sent as "" when absent.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Storybook with title, page list and cover prompt.

        Raises:
            ValueError: The API reply has no candidates, no text part, or
                the text is not a JSON object.
        """
        start_time = time.perf_counter()
        page_count = max(4, min(page_count, 12))  # clamp to 4-12 pages

        logger.info(
            "storybook_generate_start",
            keywords=keywords,
            page_count=page_count,
            has_memory=bool(memory_context),
        )

        theme = education_theme or "成长"
        random_element = random.choice(RANDOM_ELEMENTS)

        prompt = USER_PROMPT_STORYBOOK.format(
            keywords=keywords,
            education_theme=theme,
            random_element=random_element,
            page_count=page_count,
            memory_context=memory_context or "",
        )

        payload = {
            "system_instruction": {"parts": [{"text": SYSTEM_INSTRUCTION_STORYBOOK}]},
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "responseMimeType": "application/json",
                "temperature": 0.95,
                "topP": 0.9,
            },
        }

        result = await self._call_api(payload)
        parsed = self._parse_story_json(self._extract_text(result))

        # Build the Storybook object; a null "pages" is treated as empty and
        # entries that are not JSON objects are skipped.
        pages = [
            StorybookPage(
                page_number=p.get("page_number", i + 1),
                text=p.get("text", ""),
                image_prompt=p.get("image_prompt", ""),
            )
            for i, p in enumerate(parsed.get("pages") or [])
            if isinstance(p, dict)
        ]

        storybook = Storybook(
            title=parsed.get("title", "未命名故事"),
            main_character=parsed.get("main_character", ""),
            art_style=parsed.get("art_style", ""),
            pages=pages,
            cover_prompt=parsed.get("cover_prompt", ""),
        )

        elapsed = time.perf_counter() - start_time
        logger.info(
            "storybook_generate_success",
            elapsed_seconds=round(elapsed, 2),
            title=storybook.title,
            page_count=len(pages),
        )

        return storybook

    @staticmethod
    def _extract_text(result: dict) -> str:
        """Return the text of the first part of the first candidate."""
        candidates = result.get("candidates") or []
        if not candidates:
            raise ValueError("Storybook 服务未返回内容")

        # "content" may be present but null, so do not rely on the .get default.
        parts = (candidates[0].get("content") or {}).get("parts") or []
        if not parts or "text" not in parts[0]:
            raise ValueError("Storybook 服务响应缺少文本")

        return parts[0]["text"]

    @staticmethod
    def _parse_story_json(response_text: str) -> dict:
        """Parse the reply text as a JSON object, tolerating a Markdown fence."""
        cleaned = response_text.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence (with optional language tag) and the
            # closing fence, whatever whitespace surrounds them.
            cleaned = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", cleaned)

        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Storybook JSON 解析失败: {exc}") from exc

        if not isinstance(parsed, dict):
            raise ValueError("Storybook JSON 解析失败: 顶层结构不是对象")
        return parsed

    async def health_check(self) -> bool:
        """Return True when a minimal generateContent request succeeds."""
        payload = {
            "contents": [{"parts": [{"text": "Hi"}]}],
            "generationConfig": {"maxOutputTokens": 10},
        }
        try:
            await self._call_api(payload)
        except Exception as exc:  # best-effort probe: any failure means unhealthy
            # Log only the exception type; the error text may contain the
            # request URL, which carries the API key.
            logger.warning("storybook_health_check_failed", error_type=type(exc).__name__)
            return False
        return True

    @property
    def estimated_cost(self) -> float:
        """Estimated cost (text generation only, images excluded)."""
        return 0.002  # slightly pricier than a plain story: longer output

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, payload: dict) -> dict:
        """POST ``payload`` to the generateContent endpoint, with retries.

        Up to three attempts are made on httpx errors; the last error is
        re-raised unchanged.
        """
        model = self.config.model or "gemini-2.0-flash"
        # NOTE(review): the API key travels in the query string, so it can
        # show up in URLs quoted by error messages and access logs.
        url = f"{TEXT_API_BASE}/{model}:generateContent?key={self.config.api_key}"
        timeout = self.config.timeout_ms / 1000

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            return response.json()
|
||||
|
||||
@@ -1 +1 @@
|
||||
"""文本生成适配器。"""
|
||||
"""文本生成适配器。"""
|
||||
|
||||
@@ -1,164 +1,164 @@
|
||||
"""文本生成适配器 (Google Gemini)。"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.prompts import (
|
||||
RANDOM_ELEMENTS,
|
||||
SYSTEM_INSTRUCTION_ENHANCER,
|
||||
SYSTEM_INSTRUCTION_STORYTELLER,
|
||||
USER_PROMPT_ENHANCEMENT,
|
||||
USER_PROMPT_GENERATION,
|
||||
)
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
|
||||
@AdapterRegistry.register("text", "gemini")
class GeminiTextAdapter(BaseAdapter[StoryOutput]):
    """Text generation adapter backed by Google Gemini."""

    adapter_type = "text"
    adapter_name = "gemini"

    async def execute(
        self,
        input_type: Literal["keywords", "full_story"],
        data: str,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> StoryOutput:
        """Generate a story from keywords, or enhance a full story.

        Args:
            input_type: "keywords" selects the storyteller prompt; any other
                value selects the enhancement prompt.
            data: Keywords or the full story text, depending on ``input_type``.
            education_theme: Educational theme; defaults to "成长".
            memory_context: Optional memory context; sent as "" when absent.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            StoryOutput built from the JSON object the model returned.

        Raises:
            ValueError: The reply has no candidates, no text part, is not
                valid JSON, or lacks one of the required fields.
        """
        start_time = time.perf_counter()
        logger.info("request_start", adapter="gemini", input_type=input_type, data_length=len(data))

        theme = education_theme or "成长"
        random_element = random.choice(RANDOM_ELEMENTS)

        if input_type == "keywords":
            system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
            prompt = USER_PROMPT_GENERATION.format(
                keywords=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )
        else:
            system_instruction = SYSTEM_INSTRUCTION_ENHANCER
            prompt = USER_PROMPT_ENHANCEMENT.format(
                full_story=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )

        # The Gemini payload carries the system prompt in 'system_instruction'.
        payload = {
            "system_instruction": {"parts": [{"text": system_instruction}]},
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "responseMimeType": "application/json",
                "temperature": 0.95,
                "topP": 0.9,
            },
        }

        result = await self._call_api(payload)
        parsed = self._parse_story_json(self._extract_text(result))

        elapsed = time.perf_counter() - start_time
        logger.info(
            "request_success",
            adapter="gemini",
            elapsed_seconds=round(elapsed, 2),
            title=parsed["title"],
        )

        return StoryOutput(
            mode=parsed["mode"],
            title=parsed["title"],
            story_text=parsed["story_text"],
            cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
        )

    @staticmethod
    def _extract_text(result: dict) -> str:
        """Return the text of the first part of the first candidate."""
        candidates = result.get("candidates") or []
        if not candidates:
            raise ValueError("Gemini 未返回内容")

        # "content" may be present but null, so do not rely on the .get default.
        parts = (candidates[0].get("content") or {}).get("parts") or []
        if not parts or "text" not in parts[0]:
            raise ValueError("Gemini 响应缺少文本")

        return parts[0]["text"]

    @staticmethod
    def _parse_story_json(response_text: str) -> dict:
        """Parse the reply as a JSON object holding every required story field."""
        cleaned = response_text.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence (with optional language tag) and the
            # closing fence, whatever whitespace surrounds them.
            cleaned = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", cleaned)

        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Gemini 输出 JSON 解析失败: {exc}") from exc

        required_fields = ("mode", "title", "story_text", "cover_prompt_suggestion")
        # A non-object reply (list, number, ...) cannot carry the fields either.
        if not isinstance(parsed, dict) or any(name not in parsed for name in required_fields):
            raise ValueError("Gemini 输出缺少必要字段")
        return parsed

    async def health_check(self) -> bool:
        """Return True when a minimal generateContent request succeeds."""
        payload = {
            "contents": [{"parts": [{"text": "Hi"}]}],
            "generationConfig": {"maxOutputTokens": 10},
        }
        try:
            await self._call_api(payload)
        except Exception as exc:  # best-effort probe: any failure means unhealthy
            # Log only the exception type; the error text may contain the
            # request URL, which carries the API key.
            logger.warning("gemini_health_check_failed", error_type=type(exc).__name__)
            return False
        return True

    @property
    def estimated_cost(self) -> float:
        """Estimated cost of one text generation call."""
        return 0.001

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, payload: dict) -> dict:
        """POST ``payload`` to the Gemini generateContent endpoint, with retries.

        A configured ``api_base`` that already ends in ``/models`` (v1 or
        v1beta) is used as is; any other configured base gets
        ``/v1beta/models`` appended. Trailing slashes are ignored either way.
        """
        model = self.config.model or "gemini-2.0-flash"

        base_url = TEXT_API_BASE
        if self.config.api_base:
            # Strip first so ".../models/" does not yield a double slash.
            base_url = self.config.api_base.rstrip("/")
            if not base_url.endswith("/models"):
                base_url = f"{base_url}/v1beta/models"

        # NOTE(review): the API key travels in the query string, so it can
        # show up in URLs quoted by error messages and access logs.
        url = f"{base_url}/{model}:generateContent?key={self.config.api_key}"
        timeout = self.config.timeout_ms / 1000

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            return response.json()
|
||||
"""文本生成适配器 (Google Gemini)。"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.prompts import (
|
||||
RANDOM_ELEMENTS,
|
||||
SYSTEM_INSTRUCTION_ENHANCER,
|
||||
SYSTEM_INSTRUCTION_STORYTELLER,
|
||||
USER_PROMPT_ENHANCEMENT,
|
||||
USER_PROMPT_GENERATION,
|
||||
)
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
|
||||
@AdapterRegistry.register("text", "gemini")
class GeminiTextAdapter(BaseAdapter[StoryOutput]):
    """Text generation adapter backed by Google Gemini."""

    adapter_type = "text"
    adapter_name = "gemini"

    async def execute(
        self,
        input_type: Literal["keywords", "full_story"],
        data: str,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> StoryOutput:
        """Generate a story from keywords, or enhance a full story.

        Args:
            input_type: "keywords" selects the storyteller prompt; any other
                value selects the enhancement prompt.
            data: Keywords or the full story text, depending on ``input_type``.
            education_theme: Educational theme; defaults to "成长".
            memory_context: Optional memory context; sent as "" when absent.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            StoryOutput built from the JSON object the model returned.

        Raises:
            ValueError: The reply has no candidates, no text part, is not
                valid JSON, or lacks one of the required fields.
        """
        start_time = time.perf_counter()
        logger.info("request_start", adapter="gemini", input_type=input_type, data_length=len(data))

        theme = education_theme or "成长"
        random_element = random.choice(RANDOM_ELEMENTS)

        if input_type == "keywords":
            system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
            prompt = USER_PROMPT_GENERATION.format(
                keywords=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )
        else:
            system_instruction = SYSTEM_INSTRUCTION_ENHANCER
            prompt = USER_PROMPT_ENHANCEMENT.format(
                full_story=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )

        # The Gemini payload carries the system prompt in 'system_instruction'.
        payload = {
            "system_instruction": {"parts": [{"text": system_instruction}]},
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "responseMimeType": "application/json",
                "temperature": 0.95,
                "topP": 0.9,
            },
        }

        result = await self._call_api(payload)
        parsed = self._parse_story_json(self._extract_text(result))

        elapsed = time.perf_counter() - start_time
        logger.info(
            "request_success",
            adapter="gemini",
            elapsed_seconds=round(elapsed, 2),
            title=parsed["title"],
        )

        return StoryOutput(
            mode=parsed["mode"],
            title=parsed["title"],
            story_text=parsed["story_text"],
            cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
        )

    @staticmethod
    def _extract_text(result: dict) -> str:
        """Return the text of the first part of the first candidate."""
        candidates = result.get("candidates") or []
        if not candidates:
            raise ValueError("Gemini 未返回内容")

        # "content" may be present but null, so do not rely on the .get default.
        parts = (candidates[0].get("content") or {}).get("parts") or []
        if not parts or "text" not in parts[0]:
            raise ValueError("Gemini 响应缺少文本")

        return parts[0]["text"]

    @staticmethod
    def _parse_story_json(response_text: str) -> dict:
        """Parse the reply as a JSON object holding every required story field."""
        cleaned = response_text.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence (with optional language tag) and the
            # closing fence, whatever whitespace surrounds them.
            cleaned = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", cleaned)

        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Gemini 输出 JSON 解析失败: {exc}") from exc

        required_fields = ("mode", "title", "story_text", "cover_prompt_suggestion")
        # A non-object reply (list, number, ...) cannot carry the fields either.
        if not isinstance(parsed, dict) or any(name not in parsed for name in required_fields):
            raise ValueError("Gemini 输出缺少必要字段")
        return parsed

    async def health_check(self) -> bool:
        """Return True when a minimal generateContent request succeeds."""
        payload = {
            "contents": [{"parts": [{"text": "Hi"}]}],
            "generationConfig": {"maxOutputTokens": 10},
        }
        try:
            await self._call_api(payload)
        except Exception as exc:  # best-effort probe: any failure means unhealthy
            # Log only the exception type; the error text may contain the
            # request URL, which carries the API key.
            logger.warning("gemini_health_check_failed", error_type=type(exc).__name__)
            return False
        return True

    @property
    def estimated_cost(self) -> float:
        """Estimated cost of one text generation call."""
        return 0.001

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, payload: dict) -> dict:
        """POST ``payload`` to the Gemini generateContent endpoint, with retries.

        A configured ``api_base`` that already ends in ``/models`` (v1 or
        v1beta) is used as is; any other configured base gets
        ``/v1beta/models`` appended. Trailing slashes are ignored either way.
        """
        model = self.config.model or "gemini-2.0-flash"

        base_url = TEXT_API_BASE
        if self.config.api_base:
            # Strip first so ".../models/" does not yield a double slash.
            base_url = self.config.api_base.rstrip("/")
            if not base_url.endswith("/models"):
                base_url = f"{base_url}/v1beta/models"

        # NOTE(review): the API key travels in the query string, so it can
        # show up in URLs quoted by error messages and access logs.
        url = f"{base_url}/{model}:generateContent?key={self.config.api_key}"
        timeout = self.config.timeout_ms / 1000

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            return response.json()
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@dataclass
class StoryOutput:
    """Output of a story generation request."""

    # Either "generated" or "enhanced".
    mode: Literal["generated", "enhanced"]
    # Story title.
    title: str
    # Full story text.
    story_text: str
    # Suggested prompt for the cover image.
    cover_prompt_suggestion: str
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@dataclass
class StoryOutput:
    """Output of a story generation request."""

    # Either "generated" or "enhanced".
    mode: Literal["generated", "enhanced"]
    # Story title.
    title: str
    # Full story text.
    story_text: str
    # Suggested prompt for the cover image.
    cover_prompt_suggestion: str
|
||||
|
||||
@@ -1,172 +1,172 @@
|
||||
"""OpenAI 文本生成适配器。"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.prompts import (
|
||||
RANDOM_ELEMENTS,
|
||||
SYSTEM_INSTRUCTION_ENHANCER,
|
||||
SYSTEM_INSTRUCTION_STORYTELLER,
|
||||
USER_PROMPT_ENHANCEMENT,
|
||||
USER_PROMPT_GENERATION,
|
||||
)
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
OPENAI_API_BASE = "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
|
||||
|
||||
|
||||
@AdapterRegistry.register("text", "openai")
class OpenAITextAdapter(BaseAdapter[StoryOutput]):
    """Text generation adapter backed by the OpenAI chat completions API."""

    adapter_type = "text"
    adapter_name = "openai"

    async def execute(
        self,
        input_type: Literal["keywords", "full_story"],
        data: str,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> StoryOutput:
        """Generate a story from keywords, or enhance a full story.

        Args:
            input_type: "keywords" selects the storyteller prompt; any other
                value selects the enhancement prompt.
            data: Keywords or the full story text, depending on ``input_type``.
            education_theme: Educational theme; defaults to "成长".
            memory_context: Optional memory context; sent as "" when absent.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            StoryOutput built from the JSON object the model returned.

        Raises:
            ValueError: The reply has no choices, no message text, is not
                valid JSON, or lacks one of the required fields.
        """
        start_time = time.perf_counter()
        logger.info("openai_text_request_start", input_type=input_type, data_length=len(data))

        theme = education_theme or "成长"
        random_element = random.choice(RANDOM_ELEMENTS)

        if input_type == "keywords":
            system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
            prompt = USER_PROMPT_GENERATION.format(
                keywords=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )
        else:
            system_instruction = SYSTEM_INSTRUCTION_ENHANCER
            prompt = USER_PROMPT_ENHANCEMENT.format(
                full_story=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )

        model = self.config.model or "gpt-4o-mini"
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "system",
                    "content": system_instruction,
                },
                {"role": "user", "content": prompt},
            ],
            "response_format": {"type": "json_object"},
            "temperature": 0.95,
            "top_p": 0.9,
        }

        result = await self._call_api(payload)
        parsed = self._parse_story_json(self._extract_text(result))

        elapsed = time.perf_counter() - start_time
        logger.info(
            "openai_text_request_success",
            elapsed_seconds=round(elapsed, 2),
            title=parsed["title"],
            mode=parsed["mode"],
        )

        return StoryOutput(
            mode=parsed["mode"],
            title=parsed["title"],
            story_text=parsed["story_text"],
            cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
        )

    @staticmethod
    def _extract_text(result: dict) -> str:
        """Return the message content of the first choice."""
        choices = result.get("choices") or []
        if not choices:
            raise ValueError("OpenAI 未返回内容")

        # "message" or "content" may be present but null.
        response_text = (choices[0].get("message") or {}).get("content") or ""
        if not response_text:
            raise ValueError("OpenAI 响应缺少文本")

        return response_text

    @staticmethod
    def _parse_story_json(response_text: str) -> dict:
        """Parse the reply as a JSON object holding every required story field."""
        cleaned = response_text.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence (with optional language tag) and the
            # closing fence, whatever whitespace surrounds them.
            cleaned = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", cleaned)

        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError as exc:
            raise ValueError(f"OpenAI 输出 JSON 解析失败: {exc}") from exc

        required_fields = ("mode", "title", "story_text", "cover_prompt_suggestion")
        # A non-object reply (list, number, ...) cannot carry the fields either.
        if not isinstance(parsed, dict) or any(name not in parsed for name in required_fields):
            raise ValueError("OpenAI 输出缺少必要字段")
        return parsed

    # NOTE(review): unlike the Gemini adapters this decorator has no
    # reraise=True, so exhausted retries surface as tenacity.RetryError,
    # and only HTTPStatusError is retried (not timeouts or connection
    # errors). Left unchanged because callers may depend on it; confirm.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(httpx.HTTPStatusError),
    )
    async def _call_api(self, payload: dict) -> dict:
        """POST ``payload`` to the chat completions endpoint, with retries.

        A configured ``api_base`` that already ends in ``/chat/completions``
        is used as is; any other configured base gets ``/chat/completions``
        appended. Trailing slashes are ignored either way.
        """
        url = OPENAI_API_BASE
        if self.config.api_base:
            # Strip first so ".../chat/completions/" is recognised as a full
            # path instead of getting the suffix appended a second time.
            url = self.config.api_base.rstrip("/")
            if not url.endswith("/chat/completions"):
                url = f"{url}/chat/completions"

        timeout = self.config.timeout_ms / 1000

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                url,
                json=payload,
                headers={
                    "Authorization": f"Bearer {self.config.api_key}",
                    "Content-Type": "application/json",
                },
            )
            response.raise_for_status()
            return response.json()

    async def health_check(self) -> bool:
        """Return True when a minimal chat completion request succeeds."""
        payload = {
            "model": self.config.model or "gpt-4o-mini",
            "messages": [{"role": "user", "content": "Hi"}],
            "max_tokens": 5,
        }
        try:
            await self._call_api(payload)
        except Exception as exc:  # best-effort probe: any failure means unhealthy
            logger.warning("openai_text_health_check_failed", error_type=type(exc).__name__)
            return False
        return True

    @property
    def estimated_cost(self) -> float:
        """Estimated cost of one text generation call (USD)."""
        return 0.01
|
||||
"""OpenAI 文本生成适配器。"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.prompts import (
|
||||
RANDOM_ELEMENTS,
|
||||
SYSTEM_INSTRUCTION_ENHANCER,
|
||||
SYSTEM_INSTRUCTION_STORYTELLER,
|
||||
USER_PROMPT_ENHANCEMENT,
|
||||
USER_PROMPT_GENERATION,
|
||||
)
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
OPENAI_API_BASE = "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
|
||||
|
||||
|
||||
@AdapterRegistry.register("text", "openai")
class OpenAITextAdapter(BaseAdapter[StoryOutput]):
    """Text-generation adapter backed by the OpenAI chat-completions API.

    Registered with ``AdapterRegistry`` under type ``"text"``, name ``"openai"``.
    """

    adapter_type = "text"
    adapter_name = "openai"

    # Keys the model's JSON reply must contain to build a StoryOutput.
    _REQUIRED_FIELDS = ("mode", "title", "story_text", "cover_prompt_suggestion")

    async def execute(
        self,
        input_type: Literal["keywords", "full_story"],
        data: str,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> StoryOutput:
        """Generate a story from keywords, or polish an existing story.

        Args:
            input_type: ``"keywords"`` selects the storyteller prompts; any
                other value selects the enhancer prompts.
            data: The keywords, or the full story text to polish.
            education_theme: Theme inserted into the prompt; falls back to
                "成长" when empty.
            memory_context: Extra context inserted into the prompt; an empty
                string is used when omitted.
            **kwargs: Accepted but not used by this adapter.

        Returns:
            A ``StoryOutput`` built from the model's JSON reply.

        Raises:
            ValueError: If the reply has no choices, no text, is not valid
                JSON, or lacks one of the required fields.
        """
        start_time = time.time()
        logger.info("openai_text_request_start", input_type=input_type, data_length=len(data))

        theme = education_theme or "成长"
        random_element = random.choice(RANDOM_ELEMENTS)

        # Placeholders shared by both prompt templates.
        shared = {
            "education_theme": theme,
            "random_element": random_element,
            "memory_context": memory_context or "",
        }
        if input_type == "keywords":
            system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
            prompt = USER_PROMPT_GENERATION.format(keywords=data, **shared)
        else:
            system_instruction = SYSTEM_INSTRUCTION_ENHANCER
            prompt = USER_PROMPT_ENHANCEMENT.format(full_story=data, **shared)

        payload = {
            "model": self.config.model or "gpt-4o-mini",
            "messages": [
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": prompt},
            ],
            "response_format": {"type": "json_object"},
            "temperature": 0.95,
            "top_p": 0.9,
        }

        result = await self._call_api(payload)

        choices = result.get("choices") or []
        if not choices:
            raise ValueError("OpenAI 未返回内容")

        # `message` / `content` may be absent or null; both count as "no text".
        response_text = (choices[0].get("message") or {}).get("content") or ""
        if not response_text:
            raise ValueError("OpenAI 响应缺少文本")

        parsed = self._parse_story_json(response_text)

        elapsed = time.time() - start_time
        logger.info(
            "openai_text_request_success",
            elapsed_seconds=round(elapsed, 2),
            title=parsed["title"],
            mode=parsed["mode"],
        )

        return StoryOutput(
            mode=parsed["mode"],
            title=parsed["title"],
            story_text=parsed["story_text"],
            cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
        )

    @classmethod
    def _parse_story_json(cls, response_text: str) -> dict:
        """Decode the model's reply into a dict holding every required field.

        Surrounding whitespace and an optional Markdown code fence
        (```` ```json ```` or bare ```` ``` ````) are removed before decoding.

        Raises:
            ValueError: If the text is not valid JSON, is not a JSON object,
                or is missing a required field.
        """
        clean_json = response_text.strip()
        if clean_json.startswith("```"):
            clean_json = re.sub(r"^```(?:json)?\s*|\s*```$", "", clean_json)

        try:
            parsed = json.loads(clean_json)
        except json.JSONDecodeError as exc:
            # Chain the cause so the original decode position stays visible.
            raise ValueError(f"OpenAI 输出 JSON 解析失败: {exc}") from exc

        if not isinstance(parsed, dict) or any(
            field not in parsed for field in cls._REQUIRED_FIELDS
        ):
            raise ValueError("OpenAI 输出缺少必要字段")
        return parsed

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(httpx.HTTPStatusError),
    )
    async def _call_api(self, payload: dict) -> dict:
        """POST ``payload`` to the chat-completions endpoint and return the JSON body.

        The endpoint is ``self.config.api_base`` when set, otherwise
        ``OPENAI_API_BASE``; ``/chat/completions`` is appended when missing so
        a bare base URL works too.

        Only ``httpx.HTTPStatusError`` triggers a retry (up to three attempts
        with exponential back-off). ``reraise`` is not set, so once attempts
        are exhausted tenacity raises its own ``RetryError``.
        """
        # Strip trailing slashes *before* testing the suffix; otherwise a
        # configured ".../chat/completions/" would get the path appended twice.
        url = (self.config.api_base or OPENAI_API_BASE).rstrip("/")
        if not url.endswith("/chat/completions"):
            url = f"{url}/chat/completions"

        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }
        # timeout_ms is divided by 1000 before being handed to httpx.
        timeout = self.config.timeout_ms / 1000

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload, headers=headers)
            response.raise_for_status()
            return response.json()

    async def health_check(self) -> bool:
        """Return ``True`` if a minimal five-token request succeeds, else ``False``."""
        try:
            payload = {
                "model": self.config.model or "gpt-4o-mini",
                "messages": [{"role": "user", "content": "Hi"}],
                "max_tokens": 5,
            }
            await self._call_api(payload)
            return True
        except Exception:
            # Any failure at all counts as "unavailable".
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost of one text generation, in USD (fixed value)."""
        return 0.01
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""TTS 语音合成适配器。"""
|
||||
|
||||
from app.services.adapters.tts import edge_tts as _tts_edge_tts_adapter # noqa: F401
|
||||
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
|
||||
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
|
||||
"""TTS 语音合成适配器。"""
|
||||
|
||||
from app.services.adapters.tts import edge_tts as _tts_edge_tts_adapter # noqa: F401
|
||||
from app.services.adapters.tts import elevenlabs as _tts_elevenlabs_adapter # noqa: F401
|
||||
from app.services.adapters.tts import minimax as _tts_minimax_adapter # noqa: F401
|
||||
|
||||
@@ -1,66 +1,66 @@
|
||||
"""EdgeTTS 免费语音生成适配器。"""
|
||||
|
||||
import time
|
||||
|
||||
import edge_tts
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 默认中文女声 (晓晓)
|
||||
DEFAULT_VOICE = "zh-CN-XiaoxiaoNeural"
|
||||
|
||||
|
||||
@AdapterRegistry.register("tts", "edge_tts")
class EdgeTTSAdapter(BaseAdapter[bytes]):
    """EdgeTTS speech-generation adapter (free).

    No API key is required.
    """

    adapter_type = "tts"
    adapter_name = "edge_tts"

    async def execute(self, text: str, **kwargs) -> bytes:
        """Synthesize ``text`` and return the audio as bytes.

        Args:
            text: Text to speak.
            **kwargs: ``voice`` overrides the voice for this call. Otherwise
                ``self.config.model`` is used, then ``DEFAULT_VOICE``.

        Returns:
            The concatenated ``data`` of every ``"audio"`` chunk in the stream.
        """
        voice = kwargs.get("voice") or self.config.model or DEFAULT_VOICE

        start_time = time.time()
        logger.info("edge_tts_generate_start", text_length=len(text), voice=voice)

        # Read the audio straight from the stream; no temporary file needed.
        communicate = edge_tts.Communicate(text, voice)

        # Collect chunks and join once: repeated `bytes +=` re-copies the
        # whole buffer on every chunk.
        audio_chunks = [
            chunk["data"]
            async for chunk in communicate.stream()
            if chunk["type"] == "audio"
        ]
        audio_data = b"".join(audio_chunks)

        elapsed = time.time() - start_time
        logger.info(
            "edge_tts_generate_success",
            elapsed_seconds=round(elapsed, 2),
            audio_size_bytes=len(audio_data),
        )

        return audio_data

    async def health_check(self) -> bool:
        """Return ``True`` if synthesizing a one-word sample succeeds, else ``False``."""
        try:
            await self.execute("Hi")
            return True
        except Exception:
            # Any failure at all counts as "unavailable".
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost in USD; this adapter is free."""
        return 0.0  # Free!
|
||||
"""EdgeTTS 免费语音生成适配器。"""
|
||||
|
||||
import time
|
||||
|
||||
import edge_tts
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 默认中文女声 (晓晓)
|
||||
DEFAULT_VOICE = "zh-CN-XiaoxiaoNeural"
|
||||
|
||||
|
||||
@AdapterRegistry.register("tts", "edge_tts")
class EdgeTTSAdapter(BaseAdapter[bytes]):
    """EdgeTTS speech-generation adapter (free).

    No API key is required.
    """

    adapter_type = "tts"
    adapter_name = "edge_tts"

    async def execute(self, text: str, **kwargs) -> bytes:
        """Synthesize ``text`` and return the audio as bytes.

        Args:
            text: Text to speak.
            **kwargs: ``voice`` overrides the voice for this call. Otherwise
                ``self.config.model`` is used, then ``DEFAULT_VOICE``.

        Returns:
            The concatenated ``data`` of every ``"audio"`` chunk in the stream.
        """
        voice = kwargs.get("voice") or self.config.model or DEFAULT_VOICE

        start_time = time.time()
        logger.info("edge_tts_generate_start", text_length=len(text), voice=voice)

        # Read the audio straight from the stream; no temporary file needed.
        communicate = edge_tts.Communicate(text, voice)

        # Collect chunks and join once: repeated `bytes +=` re-copies the
        # whole buffer on every chunk.
        audio_chunks = [
            chunk["data"]
            async for chunk in communicate.stream()
            if chunk["type"] == "audio"
        ]
        audio_data = b"".join(audio_chunks)

        elapsed = time.time() - start_time
        logger.info(
            "edge_tts_generate_success",
            elapsed_seconds=round(elapsed, 2),
            audio_size_bytes=len(audio_data),
        )

        return audio_data

    async def health_check(self) -> bool:
        """Return ``True`` if synthesizing a one-word sample succeeds, else ``False``."""
        try:
            await self.execute("Hi")
            return True
        except Exception:
            # Any failure at all counts as "unavailable".
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost in USD; this adapter is free."""
        return 0.0  # Free!
|
||||
|
||||
@@ -1,104 +1,104 @@
|
||||
"""ElevenLabs TTS 语音合成适配器。"""
|
||||
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
ELEVENLABS_API_BASE = "https://api.elevenlabs.io/v1"
|
||||
DEFAULT_VOICE_ID = "21m00Tcm4TlvDq8ikWAM" # Rachel
|
||||
|
||||
|
||||
@AdapterRegistry.register("tts", "elevenlabs")
class ElevenLabsTtsAdapter(BaseAdapter[bytes]):
    """ElevenLabs text-to-speech adapter; returns MP3 bytes."""

    adapter_type = "tts"
    adapter_name = "elevenlabs"

    def __init__(self, config: AdapterConfig):
        """Store the API base URL from ``config.api_base`` or the default."""
        super().__init__(config)
        # Normalise away trailing slashes so the f-string joins below never
        # produce "//text-to-speech/..." for a base configured as ".../v1/".
        self.api_base = (config.api_base or ELEVENLABS_API_BASE).rstrip("/")

    async def execute(self, text: str, **kwargs) -> bytes:
        """Convert ``text`` to speech and return the response body as bytes.

        Args:
            text: Text to speak.
            **kwargs: Optional overrides:
                ``voice_id`` (default ``DEFAULT_VOICE_ID``),
                ``model`` (then ``self.config.model``, then
                ``"eleven_multilingual_v2"``),
                ``stability`` (default 0.5),
                ``similarity_boost`` (default 0.75).

        Returns:
            The raw audio bytes returned by the API (``audio/mpeg`` is requested).
        """
        start_time = time.time()
        logger.info("elevenlabs_tts_start", text_length=len(text))

        voice_id = kwargs.get("voice_id") or DEFAULT_VOICE_ID
        model_id = kwargs.get("model") or self.config.model or "eleven_multilingual_v2"

        url = f"{self.api_base}/text-to-speech/{voice_id}"
        payload = {
            "text": text,
            "model_id": model_id,
            "voice_settings": {
                "stability": kwargs.get("stability", 0.5),
                "similarity_boost": kwargs.get("similarity_boost", 0.75),
            },
        }

        audio_bytes = await self._call_api(url, payload)

        elapsed = time.time() - start_time
        logger.info(
            "elevenlabs_tts_success",
            elapsed_seconds=round(elapsed, 2),
            audio_size_bytes=len(audio_bytes),
        )

        return audio_bytes

    async def health_check(self) -> bool:
        """Return ``True`` if ``GET {api_base}/voices`` answers with HTTP 200."""
        try:
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.get(
                    f"{self.api_base}/voices",
                    headers={"xi-api-key": self.config.api_key},
                )
                return response.status_code == 200
        except Exception:
            # Any failure at all counts as "unavailable".
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost per thousand characters, in USD (fixed value)."""
        return 0.03

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        # httpx.TimeoutException is a subclass of httpx.HTTPError, so naming
        # the base class alone covers timeouts as well.
        retry=retry_if_exception_type(httpx.HTTPError),
        reraise=True,
    )
    async def _call_api(self, url: str, payload: dict) -> bytes:
        """POST ``payload`` to ``url`` and return the response content.

        Retried up to three times on any ``httpx.HTTPError``; the last error
        is re-raised unchanged (``reraise=True``).
        """
        headers = {
            "xi-api-key": self.config.api_key,
            "Content-Type": "application/json",
            "Accept": "audio/mpeg",
        }
        # timeout_ms is divided by 1000 before being handed to httpx.
        timeout = self.config.timeout_ms / 1000

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload, headers=headers)
            response.raise_for_status()
            return response.content
|
||||
"""ElevenLabs TTS 语音合成适配器。"""
|
||||
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
ELEVENLABS_API_BASE = "https://api.elevenlabs.io/v1"
|
||||
DEFAULT_VOICE_ID = "21m00Tcm4TlvDq8ikWAM" # Rachel
|
||||
|
||||
|
||||
@AdapterRegistry.register("tts", "elevenlabs")
class ElevenLabsTtsAdapter(BaseAdapter[bytes]):
    """ElevenLabs text-to-speech adapter; returns MP3 bytes."""

    adapter_type = "tts"
    adapter_name = "elevenlabs"

    def __init__(self, config: AdapterConfig):
        """Store the API base URL from ``config.api_base`` or the default."""
        super().__init__(config)
        # Normalise away trailing slashes so the f-string joins below never
        # produce "//text-to-speech/..." for a base configured as ".../v1/".
        self.api_base = (config.api_base or ELEVENLABS_API_BASE).rstrip("/")

    async def execute(self, text: str, **kwargs) -> bytes:
        """Convert ``text`` to speech and return the response body as bytes.

        Args:
            text: Text to speak.
            **kwargs: Optional overrides:
                ``voice_id`` (default ``DEFAULT_VOICE_ID``),
                ``model`` (then ``self.config.model``, then
                ``"eleven_multilingual_v2"``),
                ``stability`` (default 0.5),
                ``similarity_boost`` (default 0.75).

        Returns:
            The raw audio bytes returned by the API (``audio/mpeg`` is requested).
        """
        start_time = time.time()
        logger.info("elevenlabs_tts_start", text_length=len(text))

        voice_id = kwargs.get("voice_id") or DEFAULT_VOICE_ID
        model_id = kwargs.get("model") or self.config.model or "eleven_multilingual_v2"

        url = f"{self.api_base}/text-to-speech/{voice_id}"
        payload = {
            "text": text,
            "model_id": model_id,
            "voice_settings": {
                "stability": kwargs.get("stability", 0.5),
                "similarity_boost": kwargs.get("similarity_boost", 0.75),
            },
        }

        audio_bytes = await self._call_api(url, payload)

        elapsed = time.time() - start_time
        logger.info(
            "elevenlabs_tts_success",
            elapsed_seconds=round(elapsed, 2),
            audio_size_bytes=len(audio_bytes),
        )

        return audio_bytes

    async def health_check(self) -> bool:
        """Return ``True`` if ``GET {api_base}/voices`` answers with HTTP 200."""
        try:
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.get(
                    f"{self.api_base}/voices",
                    headers={"xi-api-key": self.config.api_key},
                )
                return response.status_code == 200
        except Exception:
            # Any failure at all counts as "unavailable".
            return False

    @property
    def estimated_cost(self) -> float:
        """Estimated cost per thousand characters, in USD (fixed value)."""
        return 0.03

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        # httpx.TimeoutException is a subclass of httpx.HTTPError, so naming
        # the base class alone covers timeouts as well.
        retry=retry_if_exception_type(httpx.HTTPError),
        reraise=True,
    )
    async def _call_api(self, url: str, payload: dict) -> bytes:
        """POST ``payload`` to ``url`` and return the response content.

        Retried up to three times on any ``httpx.HTTPError``; the last error
        is re-raised unchanged (``reraise=True``).
        """
        headers = {
            "xi-api-key": self.config.api_key,
            "Content-Type": "application/json",
            "Accept": "audio/mpeg",
        }
        # timeout_ms is divided by 1000 before being handed to httpx.
        timeout = self.config.timeout_ms / 1000

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload, headers=headers)
            response.raise_for_status()
            return response.content
|
||||
|
||||
@@ -1,149 +1,149 @@
|
||||
"""MiniMax 语音生成适配器 (T2A V2)。"""
|
||||
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# MiniMax API 配置
|
||||
DEFAULT_API_URL = "https://api.minimaxi.com/v1/t2a_v2"
|
||||
DEFAULT_MODEL = "speech-2.6-turbo"
|
||||
|
||||
@AdapterRegistry.register("tts", "minimax")
class MiniMaxTTSAdapter(BaseAdapter[bytes]):
    """MiniMax speech-generation adapter (T2A V2).

    Configuration:
    - api_key: MiniMax API key.
    - minimax_group_id: optional (depends on the model / account type in use).
    """

    adapter_type = "tts"
    adapter_name = "minimax"

    def __init__(self, config: AdapterConfig):
        """Initialise the adapter; the endpoint is always ``DEFAULT_API_URL``."""
        super().__init__(config)
        self.api_url = DEFAULT_API_URL

    async def execute(
        self,
        text: str,
        voice_id: str | None = None,
        model: str | None = None,
        speed: float | None = None,
        vol: float | None = None,
        pitch: int | None = None,
        emotion: str | None = None,
        **kwargs,
    ) -> bytes:
        """Synthesize ``text`` and return the decoded audio bytes.

        Each voice setting is resolved in this order: the explicit argument,
        then ``self.config.extra_config``, then a built-in default
        (``voice_id`` "male-qn-qingse", ``speed`` 1.0, ``vol`` 1.0,
        ``pitch`` 0). ``model`` falls back to ``self.config.model`` and then
        ``DEFAULT_MODEL``. ``emotion`` is only sent when a value is found.

        Args:
            text: Text to speak.
            voice_id: Voice identifier.
            model: Model name.
            speed: Speech speed.
            vol: Volume.
            pitch: Pitch.
            emotion: Emotion label.
            **kwargs: ``group_id`` overrides ``settings.minimax_group_id``;
                when a group id is available it is appended to the URL as
                ``?GroupId=...``.

        Returns:
            Audio bytes decoded from the hex string in ``data.audio``
            (MP3 is requested in ``audio_setting``).

        Raises:
            ValueError: If ``base_resp.status_code`` is not 0, if no audio is
                present, or if the audio is not valid hex.
        """
        model = model or self.config.model or DEFAULT_MODEL

        cfg = self.config.extra_config or {}

        voice_id = voice_id or cfg.get("voice_id") or "male-qn-qingse"
        speed = speed if speed is not None else (cfg.get("speed") or 1.0)
        vol = vol if vol is not None else (cfg.get("vol") or 1.0)
        pitch = pitch if pitch is not None else (cfg.get("pitch") or 0)
        emotion = emotion or cfg.get("emotion")
        group_id = kwargs.get("group_id") or settings.minimax_group_id

        url = self.api_url
        if group_id:
            url = f"{self.api_url}?GroupId={group_id}"

        payload = {
            "model": model,
            "text": text,
            "stream": False,
            "voice_setting": {
                "voice_id": voice_id,
                "speed": speed,
                "vol": vol,
                "pitch": pitch,
            },
            "audio_setting": {
                "sample_rate": 32000,
                "bitrate": 128000,
                "format": "mp3",
                "channel": 1,
            },
        }

        if emotion:
            payload["voice_setting"]["emotion"] = emotion

        start_time = time.time()
        logger.info("minimax_generate_start", text_length=len(text), model=model)

        result = await self._call_api(url, payload)

        # Error handling. `or {}` also covers an explicit JSON null, which
        # would otherwise surface as AttributeError instead of ValueError.
        base_resp = result.get("base_resp") or {}
        if base_resp.get("status_code") != 0:
            error_msg = base_resp.get("status_msg", "未知错误")
            raise ValueError(f"MiniMax API 错误: {error_msg}")

        # Hex decoding (key logic, migrated from primary.py).
        hex_audio = (result.get("data") or {}).get("audio")
        if not hex_audio:
            raise ValueError("API 响应中未找到音频数据 (data.audio)")

        try:
            audio_bytes = bytes.fromhex(hex_audio)
        except ValueError as exc:
            # Chain the cause so the position of the bad hex digit is kept.
            raise ValueError("MiniMax 返回的音频数据不是有效的 Hex 字符串") from exc

        elapsed = time.time() - start_time
        logger.info(
            "minimax_generate_success",
            elapsed_seconds=round(elapsed, 2),
            audio_size_bytes=len(audio_bytes),
        )

        return audio_bytes

    async def health_check(self) -> bool:
        """Return ``True`` if synthesizing a very short text succeeds, else ``False``."""
        try:
            await self.execute("Hi")
            return True
        except Exception:
            # Any failure at all counts as "unavailable".
            return False

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, url: str, payload: dict) -> dict:
        """POST ``payload`` to ``url`` and return the decoded JSON body.

        Retried up to three times on httpx errors; the last error is
        re-raised unchanged (``reraise=True``).
        """
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }
        # timeout_ms is divided by 1000 before being handed to httpx.
        timeout = self.config.timeout_ms / 1000

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload, headers=headers)
            response.raise_for_status()
            return response.json()
|
||||
"""MiniMax 语音生成适配器 (T2A V2)。"""
|
||||
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# MiniMax API 配置
|
||||
DEFAULT_API_URL = "https://api.minimaxi.com/v1/t2a_v2"
|
||||
DEFAULT_MODEL = "speech-2.6-turbo"
|
||||
|
||||
@AdapterRegistry.register("tts", "minimax")
class MiniMaxTTSAdapter(BaseAdapter[bytes]):
    """MiniMax speech-generation adapter (T2A V2).

    Configuration:
    - api_key: MiniMax API key.
    - minimax_group_id: optional (depends on the model / account type in use).
    """

    adapter_type = "tts"
    adapter_name = "minimax"

    def __init__(self, config: AdapterConfig):
        """Initialise the adapter; the endpoint is always ``DEFAULT_API_URL``."""
        super().__init__(config)
        self.api_url = DEFAULT_API_URL

    async def execute(
        self,
        text: str,
        voice_id: str | None = None,
        model: str | None = None,
        speed: float | None = None,
        vol: float | None = None,
        pitch: int | None = None,
        emotion: str | None = None,
        **kwargs,
    ) -> bytes:
        """Synthesize ``text`` and return the decoded audio bytes.

        Each voice setting is resolved in this order: the explicit argument,
        then ``self.config.extra_config``, then a built-in default
        (``voice_id`` "male-qn-qingse", ``speed`` 1.0, ``vol`` 1.0,
        ``pitch`` 0). ``model`` falls back to ``self.config.model`` and then
        ``DEFAULT_MODEL``. ``emotion`` is only sent when a value is found.

        Args:
            text: Text to speak.
            voice_id: Voice identifier.
            model: Model name.
            speed: Speech speed.
            vol: Volume.
            pitch: Pitch.
            emotion: Emotion label.
            **kwargs: ``group_id`` overrides ``settings.minimax_group_id``;
                when a group id is available it is appended to the URL as
                ``?GroupId=...``.

        Returns:
            Audio bytes decoded from the hex string in ``data.audio``
            (MP3 is requested in ``audio_setting``).

        Raises:
            ValueError: If ``base_resp.status_code`` is not 0, if no audio is
                present, or if the audio is not valid hex.
        """
        model = model or self.config.model or DEFAULT_MODEL

        cfg = self.config.extra_config or {}

        voice_id = voice_id or cfg.get("voice_id") or "male-qn-qingse"
        speed = speed if speed is not None else (cfg.get("speed") or 1.0)
        vol = vol if vol is not None else (cfg.get("vol") or 1.0)
        pitch = pitch if pitch is not None else (cfg.get("pitch") or 0)
        emotion = emotion or cfg.get("emotion")
        group_id = kwargs.get("group_id") or settings.minimax_group_id

        url = self.api_url
        if group_id:
            url = f"{self.api_url}?GroupId={group_id}"

        payload = {
            "model": model,
            "text": text,
            "stream": False,
            "voice_setting": {
                "voice_id": voice_id,
                "speed": speed,
                "vol": vol,
                "pitch": pitch,
            },
            "audio_setting": {
                "sample_rate": 32000,
                "bitrate": 128000,
                "format": "mp3",
                "channel": 1,
            },
        }

        if emotion:
            payload["voice_setting"]["emotion"] = emotion

        start_time = time.time()
        logger.info("minimax_generate_start", text_length=len(text), model=model)

        result = await self._call_api(url, payload)

        # Error handling. `or {}` also covers an explicit JSON null, which
        # would otherwise surface as AttributeError instead of ValueError.
        base_resp = result.get("base_resp") or {}
        if base_resp.get("status_code") != 0:
            error_msg = base_resp.get("status_msg", "未知错误")
            raise ValueError(f"MiniMax API 错误: {error_msg}")

        # Hex decoding (key logic, migrated from primary.py).
        hex_audio = (result.get("data") or {}).get("audio")
        if not hex_audio:
            raise ValueError("API 响应中未找到音频数据 (data.audio)")

        try:
            audio_bytes = bytes.fromhex(hex_audio)
        except ValueError as exc:
            # Chain the cause so the position of the bad hex digit is kept.
            raise ValueError("MiniMax 返回的音频数据不是有效的 Hex 字符串") from exc

        elapsed = time.time() - start_time
        logger.info(
            "minimax_generate_success",
            elapsed_seconds=round(elapsed, 2),
            audio_size_bytes=len(audio_bytes),
        )

        return audio_bytes

    async def health_check(self) -> bool:
        """Return ``True`` if synthesizing a very short text succeeds, else ``False``."""
        try:
            await self.execute("Hi")
            return True
        except Exception:
            # Any failure at all counts as "unavailable".
            return False

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
        reraise=True,
    )
    async def _call_api(self, url: str, payload: dict) -> dict:
        """POST ``payload`` to ``url`` and return the decoded JSON body.

        Retried up to three times on httpx errors; the last error is
        re-raised unchanged (``reraise=True``).
        """
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }
        # timeout_ms is divided by 1000 before being handed to httpx.
        timeout = self.config.timeout_ms / 1000

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload, headers=headers)
            response.raise_for_status()
            return response.json()
|
||||
|
||||
Reference in New Issue
Block a user