"""Text generation adapter backed by the Google Gemini ``generateContent`` REST API."""

import json
import random
import re
import time
from typing import Literal

import httpx
from tenacity import (
    retry,
    retry_if_exception,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from app.core.logging import get_logger
from app.core.prompts import (
    RANDOM_ELEMENTS,
    SYSTEM_INSTRUCTION_ENHANCER,
    SYSTEM_INSTRUCTION_STORYTELLER,
    USER_PROMPT_ENHANCEMENT,
    USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput

logger = get_logger(__name__)

TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"

# Keys the model's JSON object must contain before a StoryOutput is built from it.
_REQUIRED_FIELDS = ("mode", "title", "story_text", "cover_prompt_suggestion")

# HTTP statuses worth retrying: request timeout, rate limiting and transient server
# errors. Other 4xx responses (bad request, invalid key, ...) fail the same way on
# every attempt, so retrying them only adds latency.
_RETRYABLE_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504})

# Optional Markdown code fence around the model output: ```json ... ``` or ``` ... ```.
# The closing fence is optional so a truncated fenced reply is still unwrapped.
_CODE_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)\s*(?:```)?", re.DOTALL | re.IGNORECASE)


def _strip_code_fence(text: str) -> str:
    """Return *text* stripped of whitespace and of a surrounding code fence, if any."""
    stripped = text.strip()
    match = _CODE_FENCE_RE.fullmatch(stripped)
    return match.group(1) if match else stripped


def _is_retryable_status(exc: BaseException) -> bool:
    """Return True for HTTP status errors that are plausibly transient."""
    return (
        isinstance(exc, httpx.HTTPStatusError)
        and exc.response.status_code in _RETRYABLE_STATUS_CODES
    )


@AdapterRegistry.register("text", "gemini")
class GeminiTextAdapter(BaseAdapter[StoryOutput]):
    """Google Gemini text generation adapter.

    Generates a story from keywords, or polishes a complete story, by asking Gemini
    for a JSON object that is then validated and turned into a ``StoryOutput``.
    """

    adapter_type = "text"
    adapter_name = "gemini"

    async def execute(
        self,
        input_type: Literal["keywords", "full_story"],
        data: str,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> StoryOutput:
        """Generate a story from keywords, or enhance a complete story.

        Args:
            input_type: ``"keywords"`` builds a new story from ``data``; any other
                value treats ``data`` as a full story to enhance.
            data: The keywords or the full story text.
            education_theme: Theme inserted into the prompt; defaults to ``"成长"``.
            memory_context: Optional extra context inserted into the prompt.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The story built from the model's JSON output.

        Raises:
            ValueError: If the response has no text, is not a valid JSON object,
                or lacks a required field.
            httpx.HTTPError: If the API request fails (transient failures are
                retried first).
        """
        start_time = time.perf_counter()
        logger.info("request_start", adapter="gemini", input_type=input_type, data_length=len(data))

        payload = self._build_payload(input_type, data, education_theme, memory_context)
        result = await self._call_api(payload)
        parsed = self._parse_story(result)

        logger.info(
            "request_success",
            adapter="gemini",
            elapsed_seconds=round(time.perf_counter() - start_time, 2),
            title=parsed["title"],
        )
        return StoryOutput(
            mode=parsed["mode"],
            title=parsed["title"],
            story_text=parsed["story_text"],
            cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
        )

    @staticmethod
    def _build_payload(
        input_type: str,
        data: str,
        education_theme: str | None,
        memory_context: str | None,
    ) -> dict:
        """Build the ``generateContent`` request body for a generation or enhancement call."""
        theme = education_theme or "成长"
        # A randomly chosen element goes into every prompt, presumably to vary the
        # stories produced for identical inputs.
        random_element = random.choice(RANDOM_ELEMENTS)

        if input_type == "keywords":
            system_instruction = SYSTEM_INSTRUCTION_STORYTELLER
            prompt = USER_PROMPT_GENERATION.format(
                keywords=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )
        else:
            system_instruction = SYSTEM_INSTRUCTION_ENHANCER
            prompt = USER_PROMPT_ENHANCEMENT.format(
                full_story=data,
                education_theme=theme,
                random_element=random_element,
                memory_context=memory_context or "",
            )

        return {
            # The Gemini REST API accepts a top-level ``system_instruction``.
            "system_instruction": {"parts": [{"text": system_instruction}]},
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "responseMimeType": "application/json",
                "temperature": 0.95,
                "topP": 0.9,
            },
        }

    @staticmethod
    def _parse_story(result: dict) -> dict:
        """Extract and validate the story JSON object from a ``generateContent`` response.

        Raises:
            ValueError: On missing candidates/text, invalid JSON, a non-object JSON
                value, or missing required fields.
        """
        candidates = result.get("candidates") or []
        if not candidates:
            raise ValueError("Gemini 未返回内容")

        parts = (candidates[0].get("content") or {}).get("parts") or []
        # Join every text part (skipping thought summaries) rather than trusting
        # that the whole answer sits in parts[0].
        texts = [part["text"] for part in parts if "text" in part and not part.get("thought")]
        if not texts:
            raise ValueError("Gemini 响应缺少文本")

        try:
            parsed = json.loads(_strip_code_fence("".join(texts)))
        except json.JSONDecodeError as exc:
            raise ValueError(f"Gemini 输出 JSON 解析失败: {exc}") from exc

        # A JSON array or scalar cannot carry the required keys.
        if not isinstance(parsed, dict):
            raise ValueError("Gemini 输出缺少必要字段")
        missing = [field for field in _REQUIRED_FIELDS if field not in parsed]
        if missing:
            logger.warning("response_missing_fields", adapter="gemini", missing=missing)
            raise ValueError("Gemini 输出缺少必要字段")
        return parsed

    async def health_check(self) -> bool:
        """Return True if a minimal ``generateContent`` request succeeds, else False.

        The probe is best-effort: failures are logged and reported as False
        instead of being raised.
        """
        payload = {
            "contents": [{"parts": [{"text": "Hi"}]}],
            "generationConfig": {"maxOutputTokens": 10},
        }
        try:
            await self._call_api(payload)
        except Exception as exc:
            logger.warning(
                "health_check_failed",
                adapter="gemini",
                error_type=type(exc).__name__,
                error=str(exc),
            )
            return False
        return True

    @property
    def estimated_cost(self) -> float:
        """Fixed per-call cost estimate (hard-coded; units are not defined in this module)."""
        return 0.001

    def _resolve_base_url(self) -> str:
        """Return the ``.../models`` collection URL that model names are appended to.

        - No ``api_base`` configured: the public Gemini endpoint (v1beta).
        - ``api_base`` already ending in ``/models``: used as given (v1 or v1beta).
        - ``api_base`` ending in an API version (``/v1``, ``/v1beta``, ``/v1alpha``):
          ``/models`` is appended.
        - Anything else (a bare host or a proxy prefix): ``/v1beta/models`` is appended.

        Trailing slashes are dropped so the final URL never contains ``//``.
        """
        custom_base = self.config.api_base
        if not custom_base:
            return TEXT_API_BASE
        base = custom_base.rstrip("/")
        if base.endswith("/models"):
            return base
        if base.endswith(("/v1", "/v1beta", "/v1alpha")):
            return f"{base}/models"
        return f"{base}/v1beta/models"

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        # Retry network-level failures (incl. timeouts) and transient statuses only.
        retry=(
            retry_if_exception_type(httpx.TransportError)
            | retry_if_exception(_is_retryable_status)
        ),
        reraise=True,
    )
    async def _call_api(self, payload: dict) -> dict:
        """POST *payload* to ``{base}/{model}:generateContent`` and return the decoded JSON.

        Raises:
            httpx.HTTPStatusError: For non-2xx responses (408/429/5xx are retried first).
            httpx.TransportError: For network errors or timeouts that persist across retries.
        """
        model = self.config.model or "gemini-2.0-flash"
        url = f"{self._resolve_base_url()}/{model}:generateContent"
        # Send the key in a header instead of a ``?key=`` query parameter so it never
        # appears in the URL that httpx embeds in exception messages (and thus logs).
        headers = {"x-goog-api-key": self.config.api_key} if self.config.api_key else {}
        timeout = self.config.timeout_ms / 1000

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload, headers=headers)
            response.raise_for_status()
            return response.json()