Files
dreamweaver/backend/app/services/adapters/text/openai.py
torin b8d3cb4644
Some checks are pending
Build and Push Docker Images / changes (push) Waiting to run
Build and Push Docker Images / build-backend (push) Blocked by required conditions
Build and Push Docker Images / build-frontend (push) Blocked by required conditions
Build and Push Docker Images / build-admin-frontend (push) Blocked by required conditions
wip: snapshot full local workspace state
2026-04-17 18:58:11 +08:00

173 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""OpenAI 文本生成适配器。"""
import json
import random
import re
import time
from typing import Literal
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.logging import get_logger
from app.core.prompts import (
RANDOM_ELEMENTS,
SYSTEM_INSTRUCTION_ENHANCER,
SYSTEM_INSTRUCTION_STORYTELLER,
USER_PROMPT_ENHANCEMENT,
USER_PROMPT_GENERATION,
)
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
from app.services.adapters.text.models import StoryOutput
# Module-level logger; calls in this file pass structured fields as keyword arguments.
logger = get_logger(__name__)
# Default request URL. Despite the "BASE" name this is the complete
# chat-completions endpoint; it is used when the adapter config sets no api_base.
OPENAI_API_BASE = "https://api.openai.com/v1/chat/completions"
@AdapterRegistry.register("text", "openai")
class OpenAITextAdapter(BaseAdapter[StoryOutput]):
    """Text-generation adapter that talks to an OpenAI-style chat-completions API.

    The adapter either writes a new story from keywords or polishes an
    existing story, asks the model for a JSON object, validates it and
    returns it as a ``StoryOutput``.
    """

    adapter_type = "text"
    adapter_name = "openai"

    # Model used when the adapter config does not name one.
    _DEFAULT_MODEL = "gpt-4o-mini"
    # Keys the model's JSON reply must contain to build a StoryOutput.
    _REQUIRED_FIELDS = ("mode", "title", "story_text", "cover_prompt_suggestion")
    # Path appended when the configured api_base is only a base URL.
    _CHAT_COMPLETIONS_PATH = "/chat/completions"

    async def execute(
        self,
        input_type: Literal["keywords", "full_story"],
        data: str,
        education_theme: str | None = None,
        memory_context: str | None = None,
        **kwargs,
    ) -> StoryOutput:
        """Generate a story from keywords, or polish an existing story.

        Args:
            input_type: ``"keywords"`` selects the storyteller prompts; any
                other value selects the enhancer prompts.
            data: The keywords, or the full story text to polish.
            education_theme: Theme injected into the prompt; falls back to
                "成长" when empty or ``None``.
            memory_context: Optional context injected into the prompt; an
                empty string is used when absent.
            **kwargs: Accepted for interface compatibility and ignored here.

        Returns:
            The validated story as a ``StoryOutput``.

        Raises:
            ValueError: If the response has no choices, no text, is not
                valid JSON, or lacks a required field.
        """
        # perf_counter is monotonic, so the logged duration cannot go
        # negative or jump if the wall clock is adjusted mid-request.
        started = time.perf_counter()
        logger.info("openai_text_request_start", input_type=input_type, data_length=len(data))

        system_instruction, prompt = self._build_prompt(
            input_type, data, education_theme, memory_context
        )
        payload = {
            "model": self.config.model or self._DEFAULT_MODEL,
            "messages": [
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": prompt},
            ],
            "response_format": {"type": "json_object"},
            "temperature": 0.95,
            "top_p": 0.9,
        }

        result = await self._call_api(payload)
        parsed = self._parse_story(result)

        elapsed = time.perf_counter() - started
        logger.info(
            "openai_text_request_success",
            elapsed_seconds=round(elapsed, 2),
            title=parsed["title"],
            mode=parsed["mode"],
        )
        return StoryOutput(
            mode=parsed["mode"],
            title=parsed["title"],
            story_text=parsed["story_text"],
            cover_prompt_suggestion=parsed["cover_prompt_suggestion"],
        )

    @staticmethod
    def _build_prompt(
        input_type: str,
        data: str,
        education_theme: str | None,
        memory_context: str | None,
    ) -> tuple[str, str]:
        """Return the ``(system_instruction, user_prompt)`` pair for a request."""
        shared = {
            "education_theme": education_theme or "成长",
            # A random element is mixed into every prompt.
            "random_element": random.choice(RANDOM_ELEMENTS),
            "memory_context": memory_context or "",
        }
        if input_type == "keywords":
            return (
                SYSTEM_INSTRUCTION_STORYTELLER,
                USER_PROMPT_GENERATION.format(keywords=data, **shared),
            )
        return (
            SYSTEM_INSTRUCTION_ENHANCER,
            USER_PROMPT_ENHANCEMENT.format(full_story=data, **shared),
        )

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (```json or bare ```), if any."""
        cleaned = text.strip()
        if not cleaned.startswith("```"):
            return cleaned
        cleaned = cleaned[3:]
        # Drop an optional language tag, tolerating "json" in any letter case.
        if cleaned[:4].lower() == "json":
            cleaned = cleaned[4:]
        # The closing fence may be missing; removesuffix is then a no-op.
        return cleaned.removesuffix("```").strip()

    @classmethod
    def _parse_story(cls, result: dict) -> dict:
        """Extract and validate the story JSON object from an API response.

        Raises:
            ValueError: If the response has no choices, no text, is not
                valid JSON, or lacks a required field.
        """
        choices = result.get("choices") or []
        if not choices:
            raise ValueError("OpenAI 未返回内容")

        # "message" or "content" may be present but null; treat both as empty.
        message = choices[0].get("message") or {}
        response_text = message.get("content") or ""
        if not response_text:
            raise ValueError("OpenAI 响应缺少文本")

        try:
            parsed = json.loads(cls._strip_code_fence(response_text))
        except json.JSONDecodeError as exc:
            raise ValueError(f"OpenAI 输出 JSON 解析失败: {exc}") from exc

        # A JSON array or scalar cannot carry the required keys; reject it
        # with the same error instead of failing later with a TypeError.
        if not isinstance(parsed, dict) or any(
            field not in parsed for field in cls._REQUIRED_FIELDS
        ):
            raise ValueError("OpenAI 输出缺少必要字段")
        return parsed

    def _resolve_url(self) -> str:
        """Return the chat-completions URL to call.

        Uses the module default when no ``api_base`` is configured. When the
        configured value is only a base URL, the chat-completions path is
        appended; a trailing slash never causes the path to be duplicated.
        """
        api_base = self.config.api_base
        if not api_base:
            return OPENAI_API_BASE
        base = api_base.rstrip("/")
        if base.endswith(self._CHAT_COMPLETIONS_PATH):
            return base
        return f"{base}{self._CHAT_COMPLETIONS_PATH}"

    # NOTE(review): only HTTP status errors are retried, so transport errors
    # (timeouts, connection failures) are not, while non-transient 4xx
    # responses are. The policy is left unchanged here because callers may
    # depend on which exception type surfaces — confirm before tightening it.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(httpx.HTTPStatusError),
    )
    async def _call_api(self, payload: dict) -> dict:
        """POST ``payload`` to the chat-completions endpoint and return the JSON body.

        Up to three attempts are made, with exponential backoff, when the
        server answers with an error status.
        """
        # The config stores the timeout in milliseconds; httpx takes seconds.
        timeout = self.config.timeout_ms / 1000
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                self._resolve_url(),
                json=payload,
                headers={
                    "Authorization": f"Bearer {self.config.api_key}",
                    "Content-Type": "application/json",
                },
            )
            response.raise_for_status()
            return response.json()

    async def health_check(self) -> bool:
        """Return ``True`` if a minimal chat request succeeds, ``False`` otherwise.

        The probe goes through ``_call_api`` and therefore inherits its retry
        policy.
        """
        probe = {
            "model": self.config.model or self._DEFAULT_MODEL,
            "messages": [{"role": "user", "content": "Hi"}],
            "max_tokens": 5,
        }
        try:
            await self._call_api(probe)
        except Exception as exc:
            # Best-effort probe: any failure means "unhealthy", but record
            # the cause instead of discarding it.
            logger.warning("openai_health_check_failed", error=str(exc))
            return False
        return True

    @property
    def estimated_cost(self) -> float:
        """Estimated cost of one text generation, in USD (a fixed constant)."""
        return 0.01