108 lines
3.6 KiB
Python
108 lines
3.6 KiB
Python
"""OpenAI ASR adapter."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import re
|
||
from io import BytesIO
|
||
|
||
from fastapi import HTTPException
|
||
from openai import APIConnectionError, APIStatusError, APITimeoutError, AsyncOpenAI
|
||
|
||
from app.core.logging import get_logger
|
||
from app.services.adapters.asr.models import TranscriptionOutput
|
||
from app.services.adapters.base import BaseAdapter
|
||
from app.services.adapters.registry import AdapterRegistry
|
||
|
||
# Module-level logger for this adapter. Calls below pass an event name plus
# keyword fields (e.g. ``error=...``); presumably a structured logger from
# app.core.logging — confirm against get_logger.
logger = get_logger(__name__)
|
||
|
||
|
||
def _mask_openai_error(message: str) -> str:
|
||
"""Avoid leaking bearer tokens while keeping ASR smoke failures actionable."""
|
||
|
||
sanitized = message.replace("\n", " ").strip()
|
||
sanitized = re.sub(r"Bearer\s+[A-Za-z0-9._-]+", "Bearer ***", sanitized)
|
||
return re.sub(r"sk-[A-Za-z0-9_-]+", "sk-***", sanitized)
|
||
|
||
|
||
# Upload name used when the caller supplies neither a file name nor a
# recognised MIME type (the historical default of this adapter).
_DEFAULT_AUDIO_FILE_NAME = "voice-turn.webm"
# Base MIME type -> extension for formats the OpenAI transcription endpoint
# lists as supported. The endpoint appears to detect the container from the
# upload's extension — NOTE(review): confirm against the API reference.
_AUDIO_MIME_EXTENSIONS = {
    "audio/webm": ".webm",
    "video/webm": ".webm",
    "audio/ogg": ".ogg",
    "audio/mpeg": ".mp3",
    "audio/mp3": ".mp3",
    "audio/mp4": ".mp4",
    "audio/m4a": ".m4a",
    "audio/x-m4a": ".m4a",
    "audio/wav": ".wav",
    "audio/wave": ".wav",
    "audio/x-wav": ".wav",
    "audio/flac": ".flac",
    "audio/x-flac": ".flac",
}


def _default_audio_file_name(mime_type: str | None) -> str:
    """Pick an upload name whose extension matches *mime_type*.

    Parameters after ``;`` (e.g. ``audio/webm;codecs=opus``) are ignored.
    Missing or unknown MIME types fall back to ``voice-turn.webm``.
    """
    if not mime_type:
        return _DEFAULT_AUDIO_FILE_NAME
    base_type = mime_type.split(";", 1)[0].strip().lower()
    extension = _AUDIO_MIME_EXTENSIONS.get(base_type)
    return f"voice-turn{extension}" if extension else _DEFAULT_AUDIO_FILE_NAME


@AdapterRegistry.register("asr", "openai_asr")
class OpenAIASRAdapter(BaseAdapter[TranscriptionOutput]):
    """Transcribe uploaded voice-turn audio with the OpenAI transcription API.

    Reads these settings from ``self.config``:

    - ``api_key``: required.
    - ``api_base``: optional base-URL override.
    - ``timeout_ms``: request timeout in milliseconds.
    - ``model``: defaults to ``gpt-4o-mini-transcribe``.
    """

    adapter_type = "asr"
    adapter_name = "openai_asr"

    async def execute(
        self,
        audio_bytes: bytes,
        file_name: str | None = None,
        mime_type: str | None = None,
        transcript_hint: str | None = None,
        language: str | None = None,
        **kwargs,
    ) -> TranscriptionOutput:
        """Transcribe one audio clip and return the trimmed transcript.

        Args:
            audio_bytes: Encoded audio to upload.
            file_name: Upload name. When omitted, a name is derived from
                *mime_type*, falling back to ``voice-turn.webm``.
            mime_type: MIME type of *audio_bytes*. It is used only to choose a
                default file name.
            transcript_hint: Optional text sent as the transcription
                ``prompt`` after stripping. Blank hints are not sent.
            language: Optional language hint, sent only when non-empty.
            **kwargs: Accepted and ignored (presumably to keep a shared
                adapter call signature).

        Returns:
            A ``TranscriptionOutput`` with the stripped transcript,
            ``confidence=None`` and ``provider="openai_asr"``.

        Raises:
            HTTPException: 503 when no API key is configured or the OpenAI
                call fails (HTTP error status, timeout or connection error,
                or any other exception). 502 when the transcript is empty.
        """
        if not self.config.api_key:
            raise HTTPException(
                status_code=503,
                detail="OPENAI_API_KEY 未配置,无法使用 OpenAI 语音转写。",
            )

        audio_file = BytesIO(audio_bytes)
        # The SDK uploads this named buffer; the name's extension labels the format.
        audio_file.name = file_name or _default_audio_file_name(mime_type)

        request = {
            "model": self.config.model or "gpt-4o-mini-transcribe",
            "file": audio_file,
        }
        # Optional fields are left out rather than sent as None or "".
        # The SDK models them as "not given" when absent.
        if language:
            request["language"] = language
        prompt = transcript_hint.strip() if transcript_hint else ""
        if prompt:
            request["prompt"] = prompt

        # ``async with`` closes the client's HTTP connection pool once this
        # single request finishes, instead of leaking one pool per call.
        async with AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.api_base or None,
            timeout=self.config.timeout_ms / 1000,
        ) as client:
            try:
                response = await client.audio.transcriptions.create(**request)
            except APIStatusError as exc:
                detail = _mask_openai_error(getattr(exc, "message", str(exc)))
                logger.warning(
                    "openai_asr_failed",
                    status_code=exc.status_code,
                    error=detail,
                )
                raise HTTPException(
                    status_code=503,
                    detail=f"OpenAI ASR 调用失败(HTTP {exc.status_code}):{detail}",
                ) from exc
            except (APITimeoutError, APIConnectionError) as exc:
                detail = _mask_openai_error(str(exc))
                logger.warning("openai_asr_failed", error=detail)
                raise HTTPException(
                    status_code=503,
                    detail=f"OpenAI ASR 网络连接失败:{detail}",
                ) from exc
            except Exception as exc:
                # Mask before logging as well: arbitrary exception text may
                # echo request headers or keys.
                detail = _mask_openai_error(str(exc))
                logger.warning("openai_asr_failed", error=detail)
                raise HTTPException(
                    status_code=503,
                    detail=f"OpenAI ASR 调用异常:{detail}",
                ) from exc

        transcript_text = (getattr(response, "text", "") or "").strip()
        if not transcript_text:
            raise HTTPException(status_code=502, detail="语音转写结果为空,请重试。")

        return TranscriptionOutput(
            transcript_text=transcript_text,
            confidence=None,
            provider=self.adapter_name,
        )

    async def health_check(self) -> bool:
        """Report healthy when an API key is configured (no network call is made)."""
        return bool(self.config.api_key)

    @property
    def estimated_cost(self) -> float:
        """Fixed cost estimate for this adapter.

        The unit (per call or per minute, in USD) is not visible here; TODO
        confirm against the adapter framework.
        """
        return 0.006
|