"""Antigravity 图像生成适配器。 使用 OpenAI 兼容 API 生成图像。 支持 gemini-3-pro-image 等模型。 """ import base64 import time from typing import Any from openai import AsyncOpenAI from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, wait_exponential, ) from app.core.logging import get_logger from app.services.adapters.base import AdapterConfig, BaseAdapter from app.services.adapters.registry import AdapterRegistry logger = get_logger(__name__) # 默认配置 DEFAULT_API_BASE = "http://127.0.0.1:8045/v1" DEFAULT_MODEL = "gemini-3-pro-image" DEFAULT_SIZE = "1024x1024" # 支持的尺寸映射 SUPPORTED_SIZES = { "1024x1024": "1:1", "1280x720": "16:9", "720x1280": "9:16", "1216x896": "4:3", } @AdapterRegistry.register("image", "antigravity") class AntigravityImageAdapter(BaseAdapter[str]): """Antigravity 图像生成适配器 (OpenAI 兼容 API)。 特点: - 使用 OpenAI 兼容的 chat.completions 端点 - 通过 extra_body.size 指定图像尺寸 - 支持 gemini-3-pro-image 等模型 - 返回图片 URL 或 base64 """ adapter_type = "image" adapter_name = "antigravity" def __init__(self, config: AdapterConfig): super().__init__(config) self.api_base = config.api_base or DEFAULT_API_BASE self.client = AsyncOpenAI( base_url=self.api_base, api_key=config.api_key, timeout=config.timeout_ms / 1000, ) async def execute( self, prompt: str, model: str | None = None, size: str | None = None, num_images: int = 1, **kwargs, ) -> str | list[str]: """根据提示词生成图片,返回 URL 或 base64。 Args: prompt: 图片描述提示词 model: 模型名称 (gemini-3-pro-image / gemini-3-pro-image-16-9 等) size: 图像尺寸 (1024x1024, 1280x720, 720x1280, 1216x896) num_images: 生成图片数量 (暂只支持 1) Returns: 图片 URL 或 base64 字符串 """ # 优先使用传入参数,其次使用 Adapter 配置,最后使用默认值 model = model or self.config.model or DEFAULT_MODEL cfg = self.config.extra_config or {} size = size or cfg.get("size") or DEFAULT_SIZE start_time = time.time() logger.info( "antigravity_generate_start", prompt_length=len(prompt), model=model, size=size, ) # 调用 API image_url = await self._generate_image(prompt, model, size) elapsed = time.time() - start_time logger.info( "antigravity_generate_success", elapsed_seconds=round(elapsed, 2), model=model, ) return image_url async def health_check(self) -> bool: """检查 Antigravity API 是否可用。""" try: # 简单测试连通性 response = await self.client.chat.completions.create( model=self.config.model or DEFAULT_MODEL, messages=[{"role": "user", "content": "test"}], max_tokens=1, ) return True except Exception as e: logger.warning("antigravity_health_check_failed", error=str(e)) return False @property def estimated_cost(self) -> float: """预估每张图片成本 (USD)。 Antigravity 使用 Gemini 模型,成本约 $0.02/张。 """ return 0.02 @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10), retry=retry_if_exception_type((Exception,)), reraise=True, ) async def _generate_image( self, prompt: str, model: str, size: str, ) -> str: """调用 Antigravity API 生成图像。 Returns: 图片 URL 或 base64 data URI """ try: response = await self.client.chat.completions.create( model=model, messages=[{"role": "user", "content": prompt}], extra_body={"size": size}, ) # 解析响应 content = response.choices[0].message.content if not content: raise ValueError("Antigravity 未返回内容") # 尝试解析为图片 URL 或 base64 # 响应可能是纯 URL、base64 或 markdown 格式的图片 image_url = self._extract_image_url(content) if image_url: return image_url raise ValueError(f"Antigravity 响应无法解析为图片: {content[:200]}") except Exception as e: logger.error( "antigravity_generate_error", error=str(e), model=model, ) raise def _extract_image_url(self, content: str) -> str | None: """从响应内容中提取图片 URL。 支持多种格式: - 纯 URL: https://... - Markdown: ![...](https://...) - Base64 data URI: data:image/... - 纯 base64 字符串 """ content = content.strip() # 1. 检查是否为 data URI if content.startswith("data:image/"): return content # 2. 检查是否为纯 URL if content.startswith("http://") or content.startswith("https://"): # 可能有多行,取第一行 return content.split("\n")[0].strip() # 3. 检查 Markdown 图片格式 ![...](url) import re md_match = re.search(r"!\[.*?\]\((https?://[^\)]+)\)", content) if md_match: return md_match.group(1) # 4. 检查是否像 base64 编码的图片数据 if self._looks_like_base64(content): # 假设是 PNG return f"data:image/png;base64,{content}" return None def _looks_like_base64(self, s: str) -> bool: """判断字符串是否看起来像 base64 编码。""" # Base64 只包含 A-Z, a-z, 0-9, +, /, = # 且长度通常较长 if len(s) < 100: return False import re return bool(re.match(r"^[A-Za-z0-9+/=]+$", s.replace("\n", "")))