"""CQTAI nano 图像生成适配器。 支持异步生成 + 轮询获取结果。 API 文档: https://api.cqtai.com """ import asyncio import time import httpx from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, wait_exponential, ) from app.core.logging import get_logger from app.services.adapters.base import AdapterConfig, BaseAdapter from app.services.adapters.registry import AdapterRegistry logger = get_logger(__name__) # 默认配置 DEFAULT_API_BASE = "https://api.cqtai.com" DEFAULT_MODEL = "nano-banana" DEFAULT_RESOLUTION = "2K" DEFAULT_ASPECT_RATIO = "1:1" POLL_INTERVAL_SECONDS = 2 MAX_POLL_ATTEMPTS = 60 # 最多轮询 2 分钟 @AdapterRegistry.register("image", "cqtai") class CQTAIImageAdapter(BaseAdapter[str]): """CQTAI nano 图像生成适配器,返回图片 URL。 特点: - 异步生成 + 轮询获取结果 - 支持 nano-banana (标准) 和 nano-banana-pro (高画质) - 支持多种分辨率和画面比例 - 支持图生图 (filesUrl) """ adapter_type = "image" adapter_name = "cqtai" def __init__(self, config: AdapterConfig): super().__init__(config) self.api_base = config.api_base or DEFAULT_API_BASE async def execute( self, prompt: str, model: str | None = None, resolution: str | None = None, aspect_ratio: str | None = None, num_images: int = 1, files_url: list[str] | None = None, **kwargs, ) -> str | list[str]: """根据提示词生成图片,返回 URL 或 URL 列表。 Args: prompt: 图片描述提示词 model: 模型名称 (nano-banana / nano-banana-pro) resolution: 分辨率 (1K / 2K / 4K) aspect_ratio: 画面比例 (1:1, 16:9, 9:16, 4:3, 3:4 等) num_images: 生成图片数量 (1-4) files_url: 输入图片 URL 列表 (图生图) Returns: 单张图片返回 str,多张返回 list[str] """ # 1. 优先使用传入参数 # 2. 其次使用 Adapter 配置里的 default (extra_config) # 3. 最后使用系统默认值 model = model or self.config.model or DEFAULT_MODEL cfg = self.config.extra_config or {} resolution = resolution or cfg.get("resolution") or DEFAULT_RESOLUTION aspect_ratio = aspect_ratio or cfg.get("aspect_ratio") or DEFAULT_ASPECT_RATIO num_images = min(max(num_images, 1), 4) # 限制 1-4 start_time = time.time() logger.info( "cqtai_generate_start", prompt_length=len(prompt), model=model, resolution=resolution, aspect_ratio=aspect_ratio, num_images=num_images, ) # 1. 提交生成任务 task_id = await self._submit_task( prompt=prompt, model=model, resolution=resolution, aspect_ratio=aspect_ratio, num_images=num_images, files_url=files_url or [], ) logger.info("cqtai_task_submitted", task_id=task_id) # 2. 轮询获取结果 result = await self._poll_result(task_id) elapsed = time.time() - start_time logger.info( "cqtai_generate_success", task_id=task_id, elapsed_seconds=round(elapsed, 2), image_count=len(result) if isinstance(result, list) else 1, ) # 单张图片返回字符串,多张返回列表 if num_images == 1 and isinstance(result, list) and len(result) == 1: return result[0] return result async def health_check(self) -> bool: """检查 CQTAI API 是否可用。""" try: async with httpx.AsyncClient(timeout=10) as client: # 简单的连通性测试 response = await client.get( f"{self.api_base}/api/cqt/info/nano", params={"id": "health_check_test"}, headers={"Authorization": self.config.api_key}, ) # 即使返回错误也说明服务可达 return response.status_code in (200, 400, 401, 403, 404) except Exception: return False @property def estimated_cost(self) -> float: """预估每张图片成本 (USD)。 nano-banana: ¥0.1 ≈ $0.014 nano-banana-pro: ¥0.2 ≈ $0.028 """ model = self.config.model or DEFAULT_MODEL if model == "nano-banana-pro": return 0.028 return 0.014 @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10), retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)), reraise=True, ) async def _submit_task( self, prompt: str, model: str, resolution: str, aspect_ratio: str, num_images: int, files_url: list[str], ) -> str: """提交图像生成任务,返回任务 ID。""" timeout = self.config.timeout_ms / 1000 payload = { "prompt": prompt, "numImages": num_images, "aspectRatio": aspect_ratio, "filesUrl": files_url, } # 可选参数,不传则使用默认值 if model != DEFAULT_MODEL: payload["model"] = model if resolution != DEFAULT_RESOLUTION: payload["resolution"] = resolution async with httpx.AsyncClient(timeout=timeout) as client: response = await client.post( f"{self.api_base}/api/cqt/generator/nano", json=payload, headers={ "Authorization": self.config.api_key, "Content-Type": "application/json", }, ) response.raise_for_status() data = response.json() if data.get("code") != 200: raise ValueError(f"CQTAI 任务提交失败: {data.get('msg', '未知错误')}") task_id = data.get("data") if not task_id: raise ValueError("CQTAI 未返回任务 ID") return task_id async def _poll_result(self, task_id: str) -> list[str]: """轮询获取生成结果。 Returns: 图片 URL 列表 """ timeout = self.config.timeout_ms / 1000 for attempt in range(MAX_POLL_ATTEMPTS): async with httpx.AsyncClient(timeout=timeout) as client: response = await client.get( f"{self.api_base}/api/cqt/info/nano", params={"id": task_id}, headers={"Authorization": self.config.api_key}, ) response.raise_for_status() data = response.json() if data.get("code") != 200: raise ValueError(f"CQTAI 查询失败: {data.get('msg', '未知错误')}") result_data = data.get("data", {}) status = result_data.get("status") if status == "completed": # 提取图片 URL images = result_data.get("images", []) if not images: # 兼容不同返回格式 image_url = result_data.get("imageUrl") or result_data.get("url") if image_url: images = [image_url] if not images: raise ValueError("CQTAI 未返回图片 URL") return images elif status == "failed": error_msg = result_data.get("error", "生成失败") raise ValueError(f"CQTAI 图像生成失败: {error_msg}") # 继续等待 logger.debug( "cqtai_poll_waiting", task_id=task_id, attempt=attempt + 1, status=status, ) await asyncio.sleep(POLL_INTERVAL_SECONDS) raise TimeoutError(f"CQTAI 任务超时: {task_id}")