"""Storybook adapter: generates a paginated, page-turnable storybook.

The adapter asks the Gemini ``generateContent`` endpoint for a JSON document
describing a story (title, main character, art style, cover prompt, and
per-page text plus image prompt) and maps it onto the ``Storybook`` /
``StorybookPage`` dataclasses below. Images are NOT generated here; image
generation requires a separate call to an image adapter.
"""

import json
import random
import re
import time
from dataclasses import dataclass, field

import httpx
from tenacity import (
    retry,
    retry_if_exception,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from app.core.logging import get_logger
from app.core.prompts import (
    RANDOM_ELEMENTS,
    SYSTEM_INSTRUCTION_STORYBOOK,
    USER_PROMPT_STORYBOOK,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry

logger = get_logger(__name__)

TEXT_API_BASE = "https://generativelanguage.googleapis.com/v1beta/models"

# Bounds applied to the requested page count before prompting the model.
_MIN_PAGES = 4
_MAX_PAGES = 12

# Matches a whole response wrapped in a Markdown code fence such as
# "```json\n{...}\n```" or "```\n{...}\n```" and captures the inner body.
_CODE_FENCE_RE = re.compile(r"^```[\w-]*\s*(.*?)\s*```$", re.DOTALL)


@dataclass
class StorybookPage:
    """A single page of a storybook."""

    page_number: int
    text: str
    image_prompt: str
    image_url: str | None = None


@dataclass
class Storybook:
    """Storybook output: metadata, cover prompt and ordered pages."""

    title: str
    main_character: str
    art_style: str
    pages: list[StorybookPage] = field(default_factory=list)
    cover_prompt: str = ""
    cover_url: str | None = None


def _strip_code_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence.

    Models sometimes wrap JSON in ```json ... ``` even when a JSON MIME type
    is requested. Surrounding whitespace is removed too; text without a fence
    is returned stripped but otherwise unchanged.
    """
    stripped = text.strip()
    match = _CODE_FENCE_RE.match(stripped)
    return match.group(1) if match else stripped


def _is_retryable_status(exc: BaseException) -> bool:
    """Return True for HTTP status errors worth retrying (429 or 5xx).

    Other 4xx responses (bad request, invalid key, unknown model) cannot
    succeed on retry, so they are surfaced immediately.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    status = exc.response.status_code
    return status == 429 or status >= 500


def _extract_response_text(result: dict) -> str:
    """Concatenate the non-thought text parts of the first candidate.

    Raises:
        ValueError: if there is no candidate or it carries no text.
    """
    candidates = result.get("candidates") or []
    if not candidates:
        raise ValueError("Storybook 服务未返回内容")
    # "content" may be missing (e.g. blocked responses), so default to {}.
    content = candidates[0].get("content") or {}
    parts = content.get("parts") or []
    texts = [
        part["text"]
        for part in parts
        if isinstance(part, dict)
        and isinstance(part.get("text"), str)
        # Skip model "thought" parts; only the answer should be parsed.
        and not part.get("thought")
    ]
    if not texts:
        raise ValueError("Storybook 服务响应缺少文本")
    return "".join(texts)


def _coerce_page_number(value: object, fallback: int) -> int:
    """Return *value* as an int page number, or *fallback* if unusable."""
    if isinstance(value, bool):  # bool is an int subclass; reject it
        return fallback
    if isinstance(value, int):
        return value
    if isinstance(value, str) and value.strip().isdecimal():
        return int(value.strip())
    return fallback


def _parse_storybook(data: dict) -> Storybook:
    """Build a ``Storybook`` from the model's decoded JSON object.

    Missing or null fields fall back to defaults. Non-dict page entries are
    skipped, and a page without a usable ``page_number`` gets its 1-based
    position in the ``pages`` list.
    """
    raw_pages = data.get("pages")
    if not isinstance(raw_pages, list):
        raw_pages = []
    pages = [
        StorybookPage(
            page_number=_coerce_page_number(raw.get("page_number"), index + 1),
            text=raw.get("text") or "",
            image_prompt=raw.get("image_prompt") or "",
        )
        for index, raw in enumerate(raw_pages)
        if isinstance(raw, dict)
    ]
    return Storybook(
        title=data.get("title") or "未命名故事",
        main_character=data.get("main_character") or "",
        art_style=data.get("art_style") or "",
        pages=pages,
        cover_prompt=data.get("cover_prompt") or "",
    )


@AdapterRegistry.register("storybook", "storybook_primary")
class StorybookPrimaryAdapter(BaseAdapter[Storybook]):
    """Storybook generation adapter (default).

    Produces a paginated storybook structure containing each page's text and
    image prompt. Image generation requires a separate image adapter call.
    """

    adapter_type = "storybook"
    adapter_name = "storybook_primary"

    async def execute(
        self,
        keywords: str,
        page_count: int = 6,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> Storybook:
        """Generate a paginated storybook.

        Args:
            keywords: Story keywords.
            page_count: Requested number of pages, clamped to 4-12.
            education_theme: Educational theme; defaults to "成长".
            memory_context: Optional memory context injected into the prompt.
            **kwargs: Accepted for adapter-interface compatibility; unused.

        Returns:
            A ``Storybook`` with title, pages and cover prompt.

        Raises:
            ValueError: if the service returns no text or unparseable JSON.
            httpx.HTTPError: if the API call fails after retries.
        """
        start_time = time.monotonic()
        page_count = max(_MIN_PAGES, min(page_count, _MAX_PAGES))

        logger.info(
            "storybook_generate_start",
            keywords=keywords,
            page_count=page_count,
            has_memory=bool(memory_context),
        )

        prompt = USER_PROMPT_STORYBOOK.format(
            keywords=keywords,
            education_theme=education_theme or "成长",
            # A random element keeps repeated requests for the same
            # keywords from producing near-identical stories.
            random_element=random.choice(RANDOM_ELEMENTS),
            page_count=page_count,
            memory_context=memory_context or "",
        )

        payload = {
            "system_instruction": {"parts": [{"text": SYSTEM_INSTRUCTION_STORYBOOK}]},
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "responseMimeType": "application/json",
                "temperature": 0.95,
                "topP": 0.9,
            },
        }

        result = await self._call_api(payload)
        response_text = _extract_response_text(result)

        try:
            parsed = json.loads(_strip_code_fence(response_text))
        except json.JSONDecodeError as exc:
            raise ValueError(f"Storybook JSON 解析失败: {exc}") from exc
        if not isinstance(parsed, dict):
            raise ValueError(
                f"Storybook JSON 解析失败: 顶层应为对象, 实际为 {type(parsed).__name__}"
            )

        storybook = _parse_storybook(parsed)
        if len(storybook.pages) != page_count:
            # Not fatal, but worth surfacing: the model ignored the length.
            logger.warning(
                "storybook_page_count_mismatch",
                requested=page_count,
                received=len(storybook.pages),
            )

        elapsed = time.monotonic() - start_time
        logger.info(
            "storybook_generate_success",
            elapsed_seconds=round(elapsed, 2),
            title=storybook.title,
            page_count=len(storybook.pages),
        )
        return storybook

    async def health_check(self) -> bool:
        """Return True if a minimal generateContent call succeeds.

        This is a best-effort probe: any failure is logged and reported as
        False rather than raised.
        """
        payload = {
            "contents": [{"parts": [{"text": "Hi"}]}],
            "generationConfig": {"maxOutputTokens": 10},
        }
        try:
            await self._call_api(payload)
        except Exception as exc:  # deliberate catch-all: probe must not raise
            logger.warning(
                "storybook_health_check_failed",
                error_type=type(exc).__name__,
                error=str(exc),
            )
            return False
        return True

    @property
    def estimated_cost(self) -> float:
        """Estimated cost of text generation only (excludes images)."""
        return 0.002  # slightly above a plain story: the output is longer

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        # Retry only transient failures: transport errors (timeouts,
        # connection resets) and 429/5xx responses. Other 4xx errors are
        # deterministic and are raised on the first attempt.
        retry=(
            retry_if_exception_type(httpx.TransportError)
            | retry_if_exception(_is_retryable_status)
        ),
        reraise=True,
    )
    async def _call_api(self, payload: dict) -> dict:
        """POST *payload* to generateContent and return the decoded JSON.

        The API key is sent in the ``x-goog-api-key`` header, not the query
        string, so it does not appear in URLs echoed by httpx error messages
        or logs.
        """
        model = self.config.model or "gemini-2.0-flash"
        url = f"{TEXT_API_BASE}/{model}:generateContent"
        api_key = self.config.api_key
        headers = {"x-goog-api-key": api_key} if api_key else {}
        timeout = self.config.timeout_ms / 1000

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, json=payload, headers=headers)
            response.raise_for_status()
            return response.json()