# 135 lines, 3.8 KiB, Python
"""Voice transcription helpers for co-creation sessions."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from io import BytesIO
|
|
|
|
from fastapi import HTTPException
|
|
from openai import AsyncOpenAI
|
|
|
|
from app.core.config import settings
|
|
from app.core.logging import get_logger
|
|
|
|
# Module-level logger, obtained from the project's logging helper using this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
class VoiceTranscriptionResult:
    """Immutable outcome of transcribing a single uploaded voice turn.

    The same shape is returned by every transcription backend in this module.
    """

    # Transcript text produced for the upload.
    transcript_text: str
    # Confidence attached to the transcript; ``None`` when none is reported.
    confidence: float | None = None
    # Identifier of the backend that produced the transcript.
    provider: str = "demo"
|
|
|
|
|
|
def _normalize_transcript(transcript_text: str) -> str:
    """Return ``transcript_text`` with leading and trailing whitespace removed."""
    trimmed = transcript_text.strip()
    return trimmed
|
|
|
|
|
|
async def _transcribe_demo(
    *,
    audio_bytes: bytes,
    mime_type: str | None,
    transcript_hint: str | None,
) -> VoiceTranscriptionResult:
    """Build a transcript without calling any speech-to-text service.

    A non-blank ``transcript_hint`` is used as the transcript. Otherwise, when
    ``mime_type`` starts with ``text/``, the upload is decoded as UTF-8
    (undecodable bytes are dropped) and used instead.

    Returns:
        A result with ``confidence=1.0`` and ``provider="demo"``.

    Raises:
        HTTPException: 503 when neither source yields non-blank text.
    """
    candidate = _normalize_transcript(transcript_hint or "")

    # Fall back to the payload itself only when the hint was blank and the
    # upload is declared as text.
    if not candidate and mime_type and mime_type.startswith("text/"):
        decoded = audio_bytes.decode("utf-8", errors="ignore")
        candidate = _normalize_transcript(decoded)

    if candidate:
        return VoiceTranscriptionResult(
            transcript_text=candidate,
            confidence=1.0,
            provider="demo",
        )

    unavailable_detail = (
        "当前环境未配置真实语音转写,请先使用文本共创模式,"
        "或在开发模式下提供 transcript_hint。"
    )
    raise HTTPException(status_code=503, detail=unavailable_detail)
|
|
|
|
|
|
async def _transcribe_openai(
    *,
    audio_bytes: bytes,
    file_name: str,
    mime_type: str | None,
    transcript_hint: str | None,
) -> VoiceTranscriptionResult:
    """Transcribe one audio upload through the OpenAI transcription API.

    Args:
        audio_bytes: Raw bytes of the uploaded audio.
        file_name: Name attached to the in-memory upload sent to the API.
        mime_type: Accepted for signature parity with the other backend;
            not used by this function.
        transcript_hint: Optional text; when non-empty it is stripped and
            forwarded as the ``prompt`` argument of the API call.

    Returns:
        A result with ``provider="openai"`` and ``confidence=None``.

    Raises:
        HTTPException: 503 when ``settings.openai_api_key`` is not set or the
            API call fails; 502 when the response carries no transcript text.
    """
    if not settings.openai_api_key:
        raise HTTPException(
            status_code=503,
            detail="OPENAI_API_KEY 未配置,无法使用 OpenAI 语音转写。",
        )

    audio_file = BytesIO(audio_bytes)
    audio_file.name = file_name

    prompt = transcript_hint.strip() if transcript_hint else None

    # A fresh client (with its own HTTP connection pool) is created per call,
    # so close it deterministically instead of relying on garbage collection.
    async with AsyncOpenAI(api_key=settings.openai_api_key) as client:
        try:
            response = await client.audio.transcriptions.create(
                model=settings.voice_transcription_model,
                file=audio_file,
                language=settings.voice_transcription_language,
                prompt=prompt,
            )
        except Exception as exc:
            # Any provider failure is logged and reported as 503; the original
            # exception is kept as the cause.
            logger.warning("voice_transcription_openai_failed", error=str(exc))
            raise HTTPException(
                status_code=503,
                detail="语音转写服务暂时不可用,请稍后重试。",
            ) from exc

    # Tolerate a response object with a missing or ``None`` ``text`` attribute.
    transcript_text = _normalize_transcript(getattr(response, "text", "") or "")
    if not transcript_text:
        raise HTTPException(status_code=502, detail="语音转写结果为空,请重试。")

    return VoiceTranscriptionResult(
        transcript_text=transcript_text,
        confidence=None,
        provider="openai",
    )
|
|
|
|
|
|
async def transcribe_voice_audio(
    *,
    audio_bytes: bytes,
    file_name: str,
    mime_type: str | None,
    transcript_hint: str | None = None,
) -> VoiceTranscriptionResult:
    """Route one uploaded audio turn to the configured transcription backend.

    The backend is selected by ``settings.voice_transcription_mode``, compared
    after trimming and lower-casing; an unset or empty value means ``"demo"``.
    ``"openai"`` delegates to the OpenAI backend, ``"disabled"`` rejects the
    request, and every other value uses the demo backend.

    Raises:
        HTTPException: 503 when the mode is ``"disabled"``; the selected
            backend may raise its own ``HTTPException`` as well.
    """
    configured_mode = settings.voice_transcription_mode or "demo"
    backend = configured_mode.strip().lower()

    if backend == "openai":
        return await _transcribe_openai(
            audio_bytes=audio_bytes,
            file_name=file_name,
            mime_type=mime_type,
            transcript_hint=transcript_hint,
        )

    if backend != "disabled":
        # Any mode other than "openai" or "disabled" uses the demo backend.
        return await _transcribe_demo(
            audio_bytes=audio_bytes,
            mime_type=mime_type,
            transcript_hint=transcript_hint,
        )

    raise HTTPException(
        status_code=503,
        detail="当前环境已禁用语音转写,请先使用文本共创模式。",
    )
|