Some checks are pending
Build and Push Docker Images / changes (push) Waiting to run
Build and Push Docker Images / build-backend (push) Blocked by required conditions
Build and Push Docker Images / build-frontend (push) Blocked by required conditions
Build and Push Docker Images / build-admin-frontend (push) Blocked by required conditions
253 lines
8.2 KiB
Python
253 lines
8.2 KiB
Python
"""CQTAI nano 图像生成适配器。
|
||
|
||
支持异步生成 + 轮询获取结果。
|
||
API 文档: https://api.cqtai.com
|
||
"""
|
||
|
||
import asyncio
|
||
import time
|
||
|
||
import httpx
|
||
from tenacity import (
|
||
retry,
|
||
retry_if_exception_type,
|
||
stop_after_attempt,
|
||
wait_exponential,
|
||
)
|
||
|
||
from app.core.logging import get_logger
|
||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||
from app.services.adapters.registry import AdapterRegistry
|
||
|
||
logger = get_logger(__name__)

# --- Fallback values, used when neither the caller nor the adapter config supplies one ---
DEFAULT_API_BASE = "https://api.cqtai.com"
DEFAULT_MODEL = "nano-banana"
DEFAULT_RESOLUTION = "2K"
DEFAULT_ASPECT_RATIO = "1:1"

# --- Result polling ---
POLL_INTERVAL_SECONDS = 2
# Poll for at most ~2 minutes of waiting: 120 s / 2 s per attempt = 60 attempts.
MAX_POLL_ATTEMPTS = 120 // POLL_INTERVAL_SECONDS
|
||
|
||
|
||
@AdapterRegistry.register("image", "cqtai")
class CQTAIImageAdapter(BaseAdapter[str]):
    """CQTAI nano image-generation adapter; returns image URL(s).

    Features:
        - Asynchronous generation: a task is submitted, then its result is polled.
        - Supports ``nano-banana`` (standard) and ``nano-banana-pro`` (higher quality).
        - Supports several resolutions and aspect ratios.
        - Supports image-to-image generation via input image URLs (``filesUrl``).
    """

    adapter_type = "image"
    adapter_name = "cqtai"

    def __init__(self, config: AdapterConfig):
        super().__init__(config)
        # Fall back to the public endpoint when no custom base URL is configured.
        self.api_base = config.api_base or DEFAULT_API_BASE

    def _headers(self) -> dict:
        """Return the authorization header sent with every CQTAI request."""
        return {"Authorization": self.config.api_key}

    async def execute(
        self,
        prompt: str,
        model: str | None = None,
        resolution: str | None = None,
        aspect_ratio: str | None = None,
        num_images: int = 1,
        files_url: list[str] | None = None,
        **kwargs,
    ) -> str | list[str]:
        """Generate image(s) from a prompt and return the resulting URL(s).

        Option precedence for model / resolution / aspect_ratio:
            1. the explicit argument,
            2. the adapter configuration (``config.model`` for the model,
               ``config.extra_config`` for resolution and aspect ratio),
            3. the module-level ``DEFAULT_*`` constants.

        Args:
            prompt: Text description of the desired image.
            model: Model name (``nano-banana`` / ``nano-banana-pro``).
            resolution: Output resolution (``1K`` / ``2K`` / ``4K``).
            aspect_ratio: Aspect ratio (``1:1``, ``16:9``, ``9:16``, ``4:3``, ``3:4`` ...).
            num_images: Number of images to request; clamped to the range 1-4.
            files_url: Input image URLs for image-to-image generation.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A single URL string when one image was requested and exactly one
            was returned; otherwise the list of URLs.

        Raises:
            httpx.HTTPError: If submitting or polling fails at the HTTP level.
            ValueError: If the API reports an error or returns no task ID / images.
            TimeoutError: If the task does not finish within the polling budget.
        """
        cfg = self.config.extra_config or {}
        model = model or self.config.model or DEFAULT_MODEL
        resolution = resolution or cfg.get("resolution") or DEFAULT_RESOLUTION
        aspect_ratio = aspect_ratio or cfg.get("aspect_ratio") or DEFAULT_ASPECT_RATIO
        num_images = min(max(num_images, 1), 4)  # API accepts 1-4 images

        # monotonic() is unaffected by wall-clock adjustments, so the elapsed
        # time reported below is always meaningful.
        start_time = time.monotonic()
        logger.info(
            "cqtai_generate_start",
            prompt_length=len(prompt),
            model=model,
            resolution=resolution,
            aspect_ratio=aspect_ratio,
            num_images=num_images,
        )

        # 1. Submit the generation task.
        task_id = await self._submit_task(
            prompt=prompt,
            model=model,
            resolution=resolution,
            aspect_ratio=aspect_ratio,
            num_images=num_images,
            files_url=files_url or [],
        )
        logger.info("cqtai_task_submitted", task_id=task_id)

        # 2. Poll until the task completes (always yields a non-empty list).
        images = await self._poll_result(task_id)

        elapsed = time.monotonic() - start_time
        logger.info(
            "cqtai_generate_success",
            task_id=task_id,
            elapsed_seconds=round(elapsed, 2),
            image_count=len(images),
        )

        # A single requested image is returned as a plain string.
        if num_images == 1 and len(images) == 1:
            return images[0]
        return images

    async def health_check(self) -> bool:
        """Return True if the CQTAI API is reachable.

        Queries the task-info endpoint with a dummy task ID. Any of the listed
        status codes, including auth and not-found errors, proves the service
        answered. Network failures and other statuses count as unhealthy.
        """
        try:
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.get(
                    f"{self.api_base}/api/cqt/info/nano",
                    params={"id": "health_check_test"},
                    headers=self._headers(),
                )
        except Exception as exc:  # best-effort probe: never raise to the caller
            logger.warning("cqtai_health_check_failed", error=str(exc))
            return False
        # Even an error response means the service itself is reachable.
        return response.status_code in (200, 400, 401, 403, 404)

    @property
    def estimated_cost(self) -> float:
        """Estimated cost per image in USD, based on the configured model.

        nano-banana: ¥0.1 ≈ $0.014
        nano-banana-pro: ¥0.2 ≈ $0.028

        Note: this reads ``config.model`` only. A model override passed to
        ``execute`` is not reflected here.
        """
        model = self.config.model or DEFAULT_MODEL
        if model == "nano-banana-pro":
            return 0.028
        return 0.014

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        # httpx.HTTPError already covers httpx.TimeoutException (its subclass).
        # NOTE(review): retrying a POST after a read timeout may submit a
        # duplicate (billed) task if the server already accepted the first one.
        retry=retry_if_exception_type(httpx.HTTPError),
        reraise=True,
    )
    async def _submit_task(
        self,
        prompt: str,
        model: str,
        resolution: str,
        aspect_ratio: str,
        num_images: int,
        files_url: list[str],
    ) -> str:
        """Submit an image-generation task and return its task ID.

        Retried up to 3 times on HTTP-level errors. API-level errors (a
        ``code`` other than 200) raise ValueError and are not retried.

        Raises:
            httpx.HTTPError: If all attempts fail at the HTTP level.
            ValueError: If the API rejects the task or returns no task ID.
        """
        payload: dict = {
            "prompt": prompt,
            "numImages": num_images,
            "aspectRatio": aspect_ratio,
            "filesUrl": files_url,
        }
        # Optional parameters are omitted when they equal the local defaults.
        # The service then applies its own defaults, which presumably match
        # DEFAULT_MODEL / DEFAULT_RESOLUTION (TODO confirm against the API docs).
        if model != DEFAULT_MODEL:
            payload["model"] = model
        if resolution != DEFAULT_RESOLUTION:
            payload["resolution"] = resolution

        async with httpx.AsyncClient(timeout=self.config.timeout_ms / 1000) as client:
            response = await client.post(
                f"{self.api_base}/api/cqt/generator/nano",
                json=payload,
                headers={**self._headers(), "Content-Type": "application/json"},
            )
            response.raise_for_status()
            data = response.json()

        if data.get("code") != 200:
            raise ValueError(f"CQTAI 任务提交失败: {data.get('msg', '未知错误')}")

        task_id = data.get("data")
        if not task_id:
            raise ValueError("CQTAI 未返回任务 ID")

        # Normalize to str to honour the annotated return type.
        return str(task_id)

    async def _poll_result(self, task_id: str) -> list[str]:
        """Poll the task-info endpoint until the task completes or fails.

        Makes up to MAX_POLL_ATTEMPTS requests, POLL_INTERVAL_SECONDS apart,
        over a single reused HTTP client.

        Returns:
            The list of image URLs produced by the task (never empty).

        Raises:
            httpx.HTTPError: If a poll request fails (polls are not retried).
            ValueError: If the API reports an error, the task fails, or a
                completed task carries no image URL.
            TimeoutError: If the task is still pending after the last attempt.
        """
        url = f"{self.api_base}/api/cqt/info/nano"
        headers = self._headers()

        # One client for the whole polling session, rather than a new
        # connection pool (and TLS handshake) on every iteration.
        async with httpx.AsyncClient(timeout=self.config.timeout_ms / 1000) as client:
            for attempt in range(1, MAX_POLL_ATTEMPTS + 1):
                response = await client.get(url, params={"id": task_id}, headers=headers)
                response.raise_for_status()
                data = response.json()

                if data.get("code") != 200:
                    raise ValueError(f"CQTAI 查询失败: {data.get('msg', '未知错误')}")

                # `or {}` also guards against an explicit JSON null in "data".
                result_data = data.get("data") or {}
                status = result_data.get("status")

                if status == "completed":
                    return self._extract_images(result_data)
                if status == "failed":
                    error_msg = result_data.get("error", "生成失败")
                    raise ValueError(f"CQTAI 图像生成失败: {error_msg}")

                # Still running: wait and try again.
                logger.debug(
                    "cqtai_poll_waiting",
                    task_id=task_id,
                    attempt=attempt,
                    status=status,
                )
                # Don't sleep after the final attempt; we are about to give up.
                if attempt < MAX_POLL_ATTEMPTS:
                    await asyncio.sleep(POLL_INTERVAL_SECONDS)

        raise TimeoutError(f"CQTAI 任务超时: {task_id}")

    @staticmethod
    def _extract_images(result_data: dict) -> list[str]:
        """Return the image URLs from a completed task payload.

        Reads ``images`` first. For compatibility with other response formats,
        it falls back to a single ``imageUrl`` or ``url`` field.

        Raises:
            ValueError: If no image URL can be found.
        """
        images = result_data.get("images") or []
        if isinstance(images, str):
            # Tolerate a bare URL string where a list was expected.
            images = [images]
        if not images:
            image_url = result_data.get("imageUrl") or result_data.get("url")
            if image_url:
                images = [image_url]
        if not images:
            raise ValueError("CQTAI 未返回图片 URL")
        return list(images)
|