feat: add ASR provider support for voice co-creation
This commit is contained in:
@@ -4,7 +4,11 @@
|
||||
from app.services.adapters import demo as _demo_adapters # noqa: F401
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
|
||||
# Image adapters
|
||||
# ASR adapters
|
||||
from app.services.adapters.asr import demo as _asr_demo_adapter # noqa: F401
|
||||
from app.services.adapters.asr import openai as _asr_openai_adapter # noqa: F401
|
||||
|
||||
# Image adapters
|
||||
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
|
||||
1
backend/app/services/adapters/asr/__init__.py
Normal file
1
backend/app/services/adapters/asr/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""ASR adapters."""
|
||||
57
backend/app/services/adapters/asr/demo.py
Normal file
57
backend/app/services/adapters/asr/demo.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Demo ASR adapter for local voice co-creation smoke tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.services.adapters.asr.models import TranscriptionOutput
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
|
||||
@AdapterRegistry.register("asr", "demo")
|
||||
class DemoASRAdapter(BaseAdapter[TranscriptionOutput]):
|
||||
"""Return transcript hints or text uploads without external ASR services."""
|
||||
|
||||
adapter_type = "asr"
|
||||
adapter_name = "demo"
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
audio_bytes: bytes,
|
||||
file_name: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
transcript_hint: str | None = None,
|
||||
**kwargs,
|
||||
) -> TranscriptionOutput:
|
||||
hint = (transcript_hint or "").strip()
|
||||
if hint:
|
||||
return TranscriptionOutput(
|
||||
transcript_text=hint,
|
||||
confidence=1.0,
|
||||
provider=self.adapter_name,
|
||||
)
|
||||
|
||||
if mime_type and mime_type.startswith("text/"):
|
||||
text = audio_bytes.decode("utf-8", errors="ignore").strip()
|
||||
if text:
|
||||
return TranscriptionOutput(
|
||||
transcript_text=text,
|
||||
confidence=1.0,
|
||||
provider=self.adapter_name,
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"当前环境未配置真实语音转写,请先使用文本共创模式,"
|
||||
"或在开发模式下提供 transcript_hint。"
|
||||
),
|
||||
)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def estimated_cost(self) -> float:
|
||||
return 0.0
|
||||
11
backend/app/services/adapters/asr/models.py
Normal file
11
backend/app/services/adapters/asr/models.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""ASR adapter result models."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TranscriptionOutput(BaseModel):
|
||||
"""Normalized speech-to-text output from one ASR provider."""
|
||||
|
||||
transcript_text: str
|
||||
confidence: float | None = None
|
||||
provider: str
|
||||
76
backend/app/services/adapters/asr/openai.py
Normal file
76
backend/app/services/adapters/asr/openai.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""OpenAI ASR adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.asr.models import TranscriptionOutput
|
||||
from app.services.adapters.base import BaseAdapter
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@AdapterRegistry.register("asr", "openai_asr")
|
||||
class OpenAIASRAdapter(BaseAdapter[TranscriptionOutput]):
|
||||
"""Transcribe uploaded voice turn audio with OpenAI audio transcription."""
|
||||
|
||||
adapter_type = "asr"
|
||||
adapter_name = "openai_asr"
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
audio_bytes: bytes,
|
||||
file_name: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
transcript_hint: str | None = None,
|
||||
language: str | None = None,
|
||||
**kwargs,
|
||||
) -> TranscriptionOutput:
|
||||
if not self.config.api_key:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="OPENAI_API_KEY 未配置,无法使用 OpenAI 语音转写。",
|
||||
)
|
||||
|
||||
client = AsyncOpenAI(api_key=self.config.api_key)
|
||||
audio_file = BytesIO(audio_bytes)
|
||||
audio_file.name = file_name or "voice-turn.webm"
|
||||
|
||||
prompt = transcript_hint.strip() if transcript_hint else None
|
||||
model = self.config.model or "gpt-4o-mini-transcribe"
|
||||
|
||||
try:
|
||||
response = await client.audio.transcriptions.create(
|
||||
model=model,
|
||||
file=audio_file,
|
||||
language=language,
|
||||
prompt=prompt,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("openai_asr_failed", error=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="语音转写服务暂时不可用,请稍后重试。",
|
||||
) from exc
|
||||
|
||||
transcript_text = (getattr(response, "text", "") or "").strip()
|
||||
if not transcript_text:
|
||||
raise HTTPException(status_code=502, detail="语音转写结果为空,请重试。")
|
||||
|
||||
return TranscriptionOutput(
|
||||
transcript_text=transcript_text,
|
||||
confidence=None,
|
||||
provider=self.adapter_name,
|
||||
)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
return bool(self.config.api_key)
|
||||
|
||||
@property
|
||||
def estimated_cost(self) -> float:
|
||||
return 0.006
|
||||
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Literal, Protocol, TypeAlias
|
||||
|
||||
ProviderType: TypeAlias = Literal["text", "image", "tts", "storybook"]
|
||||
ProviderType: TypeAlias = Literal["text", "image", "tts", "storybook", "asr"]
|
||||
|
||||
|
||||
class RoutingStrategy(str, Enum):
|
||||
@@ -36,6 +36,7 @@ class ProviderSettings(Protocol):
|
||||
image_providers: list[str]
|
||||
tts_providers: list[str]
|
||||
storybook_providers: list[str]
|
||||
asr_providers: list[str]
|
||||
enable_demo_providers: bool
|
||||
|
||||
|
||||
@@ -71,6 +72,14 @@ CAPABILITY_POLICIES: dict[ProviderType, CapabilityPolicy] = {
|
||||
default_providers=("storybook_primary",),
|
||||
demo_provider="demo",
|
||||
),
|
||||
"asr": CapabilityPolicy(
|
||||
capability="asr",
|
||||
label="语音识别",
|
||||
description="将孩子上传的语音回合转写为文本输入。",
|
||||
settings_attr="asr_providers",
|
||||
default_providers=("demo",),
|
||||
demo_provider="demo",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -94,6 +103,8 @@ API_KEY_MAP: dict[str, str] = {
|
||||
"antigravity_api_key": "antigravity_api_key",
|
||||
"image_primary": "image_api_key",
|
||||
"image_api_key": "image_api_key",
|
||||
# ASR
|
||||
"openai_asr": "openai_api_key",
|
||||
# TTS
|
||||
"minimax": "minimax_api_key",
|
||||
"minimax_api_key": "minimax_api_key",
|
||||
|
||||
@@ -113,6 +113,14 @@ def _get_default_config(adapter_name: str) -> AdapterConfig | None:
|
||||
timeout_ms=1000,
|
||||
)
|
||||
|
||||
# --- ASR Defaults ---
|
||||
if adapter_name == "openai_asr":
|
||||
return AdapterConfig(
|
||||
api_key=settings.openai_api_key,
|
||||
model=settings.voice_transcription_model,
|
||||
timeout_ms=60000,
|
||||
)
|
||||
|
||||
# --- Text Defaults ---
|
||||
if adapter_name in ("gemini", "text_primary"):
|
||||
return AdapterConfig(
|
||||
@@ -289,7 +297,7 @@ async def _route_with_failover(
|
||||
"""通用 provider failover 路由。
|
||||
|
||||
Args:
|
||||
provider_type: 供应商类型 (text/image/tts/storybook)
|
||||
provider_type: 供应商类型 (text/image/tts/storybook/asr)
|
||||
strategy: 路由策略
|
||||
db: 数据库会话(可选,用于指标收集和熔断检查)
|
||||
user_id: 用户 ID(可选,用于成本追踪和预算检查)
|
||||
@@ -297,7 +305,14 @@ async def _route_with_failover(
|
||||
story_id: 故事 ID(可选,用于关联 provider 事件)
|
||||
**kwargs: 传递给适配器的参数
|
||||
"""
|
||||
providers = await _get_providers_with_config(provider_type)
|
||||
provider_names = kwargs.pop("provider_names", None)
|
||||
if provider_names:
|
||||
providers = [
|
||||
(name, _get_default_config(name) or AdapterConfig(api_key=""), None)
|
||||
for name in provider_names
|
||||
]
|
||||
else:
|
||||
providers = await _get_providers_with_config(provider_type)
|
||||
|
||||
if not providers:
|
||||
raise ValueError(f"No {provider_type} providers configured.")
|
||||
@@ -457,6 +472,35 @@ async def _route_with_failover(
|
||||
raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
|
||||
|
||||
|
||||
async def transcribe_audio(
|
||||
audio_bytes: bytes,
|
||||
file_name: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
transcript_hint: str | None = None,
|
||||
language: str | None = None,
|
||||
provider_names: list[str] | None = None,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""语音转写,支持 provider failover。"""
|
||||
from app.services.adapters.asr.models import TranscriptionOutput
|
||||
|
||||
result: TranscriptionOutput = await _route_with_failover(
|
||||
"asr",
|
||||
strategy=strategy,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
audio_bytes=audio_bytes,
|
||||
file_name=file_name,
|
||||
mime_type=mime_type,
|
||||
transcript_hint=transcript_hint,
|
||||
language=language,
|
||||
provider_names=provider_names,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def generate_story_content(
|
||||
input_type: Literal["keywords", "full_story"],
|
||||
data: str,
|
||||
|
||||
@@ -1448,6 +1448,8 @@ async def create_voice_turn_from_upload_service(
|
||||
file_name=file_name,
|
||||
mime_type=mime_type,
|
||||
transcript_hint=transcript_hint,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
)
|
||||
except HTTPException as exc:
|
||||
session.last_error = str(exc.detail)
|
||||
|
||||
@@ -3,15 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
from app.services.provider_router import transcribe_audio
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -23,84 +20,9 @@ class VoiceTranscriptionResult:
|
||||
provider: str = "demo"
|
||||
|
||||
|
||||
def _normalize_transcript(transcript_text: str) -> str:
|
||||
return transcript_text.strip()
|
||||
|
||||
|
||||
async def _transcribe_demo(
|
||||
*,
|
||||
audio_bytes: bytes,
|
||||
mime_type: str | None,
|
||||
transcript_hint: str | None,
|
||||
) -> VoiceTranscriptionResult:
|
||||
hint = _normalize_transcript(transcript_hint or "")
|
||||
if hint:
|
||||
return VoiceTranscriptionResult(
|
||||
transcript_text=hint,
|
||||
confidence=1.0,
|
||||
provider="demo",
|
||||
)
|
||||
|
||||
if mime_type and mime_type.startswith("text/"):
|
||||
text = _normalize_transcript(audio_bytes.decode("utf-8", errors="ignore"))
|
||||
if text:
|
||||
return VoiceTranscriptionResult(
|
||||
transcript_text=text,
|
||||
confidence=1.0,
|
||||
provider="demo",
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"当前环境未配置真实语音转写,请先使用文本共创模式,"
|
||||
"或在开发模式下提供 transcript_hint。"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def _transcribe_openai(
|
||||
*,
|
||||
audio_bytes: bytes,
|
||||
file_name: str,
|
||||
mime_type: str | None,
|
||||
transcript_hint: str | None,
|
||||
) -> VoiceTranscriptionResult:
|
||||
if not settings.openai_api_key:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="OPENAI_API_KEY 未配置,无法使用 OpenAI 语音转写。",
|
||||
)
|
||||
|
||||
client = AsyncOpenAI(api_key=settings.openai_api_key)
|
||||
audio_file = BytesIO(audio_bytes)
|
||||
audio_file.name = file_name
|
||||
|
||||
prompt = transcript_hint.strip() if transcript_hint else None
|
||||
|
||||
try:
|
||||
response = await client.audio.transcriptions.create(
|
||||
model=settings.voice_transcription_model,
|
||||
file=audio_file,
|
||||
language=settings.voice_transcription_language,
|
||||
prompt=prompt,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("voice_transcription_openai_failed", error=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="语音转写服务暂时不可用,请稍后重试。",
|
||||
) from exc
|
||||
|
||||
transcript_text = _normalize_transcript(getattr(response, "text", "") or "")
|
||||
if not transcript_text:
|
||||
raise HTTPException(status_code=502, detail="语音转写结果为空,请重试。")
|
||||
|
||||
return VoiceTranscriptionResult(
|
||||
transcript_text=transcript_text,
|
||||
confidence=None,
|
||||
provider="openai",
|
||||
)
|
||||
def _resolve_transcript_hint(transcript_hint: str | None) -> str | None:
|
||||
normalized = (transcript_hint or "").strip()
|
||||
return normalized or None
|
||||
|
||||
|
||||
async def transcribe_voice_audio(
|
||||
@@ -109,26 +31,35 @@ async def transcribe_voice_audio(
|
||||
file_name: str,
|
||||
mime_type: str | None,
|
||||
transcript_hint: str | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> VoiceTranscriptionResult:
|
||||
"""Transcribe one uploaded audio turn according to the configured mode."""
|
||||
"""Transcribe one uploaded audio turn using configured ASR providers."""
|
||||
|
||||
mode = (settings.voice_transcription_mode or "demo").strip().lower()
|
||||
mode = (settings.voice_transcription_mode or "provider").strip().lower()
|
||||
|
||||
if mode == "disabled":
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="当前环境已禁用语音转写,请先使用文本共创模式。",
|
||||
)
|
||||
if mode == "openai":
|
||||
return await _transcribe_openai(
|
||||
audio_bytes=audio_bytes,
|
||||
file_name=file_name,
|
||||
mime_type=mime_type,
|
||||
transcript_hint=transcript_hint,
|
||||
)
|
||||
|
||||
return await _transcribe_demo(
|
||||
hint = _resolve_transcript_hint(transcript_hint)
|
||||
provider_name = "openai_asr" if mode == "openai" else mode
|
||||
strategy_providers = None if mode == "provider" else [provider_name]
|
||||
result = await transcribe_audio(
|
||||
audio_bytes=audio_bytes,
|
||||
file_name=file_name,
|
||||
mime_type=mime_type,
|
||||
transcript_hint=transcript_hint,
|
||||
transcript_hint=hint,
|
||||
language=settings.voice_transcription_language,
|
||||
provider_names=strategy_providers,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return VoiceTranscriptionResult(
|
||||
transcript_text=result.transcript_text,
|
||||
confidence=result.confidence,
|
||||
provider=result.provider,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user