feat: add ASR provider support for voice co-creation

This commit is contained in:
2026-04-24 17:58:49 +08:00
parent 7e450aa5fc
commit 3805c18622
22 changed files with 471 additions and 126 deletions

View File

@@ -17,7 +17,7 @@ router = APIRouter(dependencies=[Depends(admin_guard)])
class ProviderCreate(BaseModel):
name: str
type: str = Field(..., pattern="^(text|image|tts|storybook)$")
type: str = Field(..., pattern="^(text|image|tts|storybook|asr)$")
adapter: str
model: str | None = None
api_base: str | None = None

View File

@@ -58,6 +58,7 @@ class Settings(BaseSettings):
image_providers: list[str] = Field(default_factory=lambda: ["cqtai"])
tts_providers: list[str] = Field(default_factory=lambda: ["minimax", "elevenlabs", "edge_tts"])
storybook_providers: list[str] = Field(default_factory=lambda: ["storybook_primary"])
asr_providers: list[str] = Field(default_factory=lambda: ["demo"])
enable_demo_providers: bool = Field(
False,
description="Enable local deterministic demo providers for portfolio demos",
@@ -71,8 +72,8 @@ class Settings(BaseSettings):
description="Directory for persisted voice co-creation session assets",
)
voice_transcription_mode: str = Field(
"demo",
description="Voice transcription mode: demo, openai, or disabled",
"provider",
description="Voice transcription mode: provider or disabled; provider order is controlled by ASR_PROVIDERS",
)
voice_transcription_model: str = Field(
"gpt-4o-mini-transcribe",

View File

@@ -19,7 +19,7 @@ class Provider(Base):
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
name: Mapped[str] = mapped_column(String(100), nullable=False)
type: Mapped[str] = mapped_column(String(50), nullable=False) # text/image/tts/storybook
type: Mapped[str] = mapped_column(String(50), nullable=False) # text/image/tts/storybook/asr
adapter: Mapped[str] = mapped_column(String(100), nullable=False)
model: Mapped[str] = mapped_column(String(200), nullable=True)
api_base: Mapped[str] = mapped_column(String(300), nullable=True)
@@ -97,7 +97,7 @@ class CostRecord(Base):
user_id: Mapped[str] = mapped_column(String(36), nullable=False, index=True)
provider_id: Mapped[str] = mapped_column(String(36), nullable=True) # 可能是环境变量配置
provider_name: Mapped[str] = mapped_column(String(100), nullable=False)
capability: Mapped[str] = mapped_column(String(50), nullable=False) # text/image/tts/storybook
capability: Mapped[str] = mapped_column(String(50), nullable=False) # text/image/tts/storybook/asr
estimated_cost: Mapped[Decimal] = mapped_column(Numeric(10, 6), nullable=False)
timestamp: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow, index=True

View File

@@ -4,7 +4,11 @@
from app.services.adapters import demo as _demo_adapters # noqa: F401
from app.services.adapters.base import AdapterConfig, BaseAdapter
# Image adapters
# ASR adapters
from app.services.adapters.asr import demo as _asr_demo_adapter # noqa: F401
from app.services.adapters.asr import openai as _asr_openai_adapter # noqa: F401
# Image adapters
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.registry import AdapterRegistry

View File

@@ -0,0 +1 @@
"""ASR adapters."""

View File

@@ -0,0 +1,57 @@
"""Demo ASR adapter for local voice co-creation smoke tests."""
from __future__ import annotations
from fastapi import HTTPException
from app.services.adapters.asr.models import TranscriptionOutput
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
@AdapterRegistry.register("asr", "demo")
class DemoASRAdapter(BaseAdapter[TranscriptionOutput]):
    """Offline ASR stand-in that echoes caller-supplied text.

    No external speech service is contacted. The transcript is taken from
    ``transcript_hint`` or, failing that, from the body of a ``text/*`` upload.
    """

    adapter_type = "asr"
    adapter_name = "demo"

    async def execute(
        self,
        audio_bytes: bytes,
        file_name: str | None = None,
        mime_type: str | None = None,
        transcript_hint: str | None = None,
        **kwargs,
    ) -> TranscriptionOutput:
        """Resolve a transcript without performing real speech recognition.

        Args:
            audio_bytes: Raw upload payload; only decoded when ``mime_type``
                starts with ``text/`` and no usable hint was given.
            file_name: Accepted for interface parity; not used here.
            mime_type: Content type of the upload.
            transcript_hint: Text to return verbatim (after stripping).
            **kwargs: Extra routing arguments; ignored.

        Returns:
            A ``TranscriptionOutput`` with confidence ``1.0``.

        Raises:
            HTTPException: 503 when neither source yields non-empty text.
        """
        # The hint wins; the upload body is only consulted as a fallback.
        transcript = (transcript_hint or "").strip()
        if not transcript and mime_type and mime_type.startswith("text/"):
            transcript = audio_bytes.decode("utf-8", errors="ignore").strip()

        if not transcript:
            raise HTTPException(
                status_code=503,
                detail=(
                    "当前环境未配置真实语音转写,请先使用文本共创模式,"
                    "或在开发模式下提供 transcript_hint。"
                ),
            )

        return TranscriptionOutput(
            transcript_text=transcript,
            confidence=1.0,
            provider=self.adapter_name,
        )

    async def health_check(self) -> bool:
        """Report healthy unconditionally; there is no remote dependency."""
        return True

    @property
    def estimated_cost(self) -> float:
        """Cost estimate per call; the demo adapter is free."""
        return 0.0

View File

@@ -0,0 +1,11 @@
"""ASR adapter result models."""
from pydantic import BaseModel
class TranscriptionOutput(BaseModel):
    """Normalized speech-to-text output from one ASR provider."""

    # Transcribed text produced by the provider.
    transcript_text: str
    # Optional confidence score; defaults to None when the provider gives none.
    confidence: float | None = None
    # Identifier of the provider/adapter that produced this transcript.
    provider: str

View File

@@ -0,0 +1,76 @@
"""OpenAI ASR adapter."""
from __future__ import annotations
from io import BytesIO
from fastapi import HTTPException
from openai import AsyncOpenAI
from app.core.logging import get_logger
from app.services.adapters.asr.models import TranscriptionOutput
from app.services.adapters.base import BaseAdapter
from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
@AdapterRegistry.register("asr", "openai_asr")
class OpenAIASRAdapter(BaseAdapter[TranscriptionOutput]):
    """Transcribe uploaded voice turn audio with OpenAI audio transcription."""

    adapter_type = "asr"
    adapter_name = "openai_asr"

    async def execute(
        self,
        audio_bytes: bytes,
        file_name: str | None = None,
        mime_type: str | None = None,
        transcript_hint: str | None = None,
        language: str | None = None,
        **kwargs,
    ) -> TranscriptionOutput:
        """Send one audio payload to the OpenAI transcription endpoint.

        Args:
            audio_bytes: Raw audio payload to transcribe.
            file_name: Name attached to the upload; defaults to
                ``"voice-turn.webm"`` when missing.
            mime_type: Accepted for interface parity; not used here.
            transcript_hint: Optional text forwarded as the ``prompt`` field.
            language: Optional language code forwarded as ``language``.
            **kwargs: Extra routing arguments; ignored.

        Returns:
            A ``TranscriptionOutput`` with ``confidence=None``.

        Raises:
            HTTPException: 503 when the API key is missing or the request
                fails; 502 when the response carries no text.
        """
        if not self.config.api_key:
            raise HTTPException(
                status_code=503,
                detail="OPENAI_API_KEY 未配置,无法使用 OpenAI 语音转写。",
            )

        # Honor the adapter's configured timeout (milliseconds) instead of
        # silently falling back to the SDK default.
        client_options = {"api_key": self.config.api_key}
        timeout_ms = getattr(self.config, "timeout_ms", None)
        if timeout_ms:
            client_options["timeout"] = timeout_ms / 1000
        client = AsyncOpenAI(**client_options)

        # The SDK derives the upload filename from the file object's ``name``.
        audio_file = BytesIO(audio_bytes)
        audio_file.name = file_name or "voice-turn.webm"

        # Only send optional fields that actually carry a value, so a missing
        # language or a whitespace-only hint is omitted rather than passed as
        # an explicit None / empty string.
        request_fields = {
            "model": self.config.model or "gpt-4o-mini-transcribe",
            "file": audio_file,
        }
        if language:
            request_fields["language"] = language
        prompt = (transcript_hint or "").strip()
        if prompt:
            request_fields["prompt"] = prompt

        try:
            response = await client.audio.transcriptions.create(**request_fields)
        except Exception as exc:
            # Adapter boundary: any SDK/network failure becomes a 503.
            logger.warning("openai_asr_failed", error=str(exc))
            raise HTTPException(
                status_code=503,
                detail="语音转写服务暂时不可用,请稍后重试。",
            ) from exc

        transcript_text = (getattr(response, "text", "") or "").strip()
        if not transcript_text:
            raise HTTPException(status_code=502, detail="语音转写结果为空,请重试。")

        return TranscriptionOutput(
            transcript_text=transcript_text,
            confidence=None,
            provider=self.adapter_name,
        )

    async def health_check(self) -> bool:
        """Report healthy when an API key is configured (no network probe)."""
        return bool(self.config.api_key)

    @property
    def estimated_cost(self) -> float:
        """Flat per-call cost estimate used by this adapter."""
        return 0.006

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
from enum import Enum
from typing import Literal, Protocol, TypeAlias
ProviderType: TypeAlias = Literal["text", "image", "tts", "storybook"]
ProviderType: TypeAlias = Literal["text", "image", "tts", "storybook", "asr"]
class RoutingStrategy(str, Enum):
@@ -36,6 +36,7 @@ class ProviderSettings(Protocol):
image_providers: list[str]
tts_providers: list[str]
storybook_providers: list[str]
asr_providers: list[str]
enable_demo_providers: bool
@@ -71,6 +72,14 @@ CAPABILITY_POLICIES: dict[ProviderType, CapabilityPolicy] = {
default_providers=("storybook_primary",),
demo_provider="demo",
),
"asr": CapabilityPolicy(
capability="asr",
label="语音识别",
description="将孩子上传的语音回合转写为文本输入。",
settings_attr="asr_providers",
default_providers=("demo",),
demo_provider="demo",
),
}
@@ -94,6 +103,8 @@ API_KEY_MAP: dict[str, str] = {
"antigravity_api_key": "antigravity_api_key",
"image_primary": "image_api_key",
"image_api_key": "image_api_key",
# ASR
"openai_asr": "openai_api_key",
# TTS
"minimax": "minimax_api_key",
"minimax_api_key": "minimax_api_key",

View File

@@ -113,6 +113,14 @@ def _get_default_config(adapter_name: str) -> AdapterConfig | None:
timeout_ms=1000,
)
# --- ASR Defaults ---
if adapter_name == "openai_asr":
return AdapterConfig(
api_key=settings.openai_api_key,
model=settings.voice_transcription_model,
timeout_ms=60000,
)
# --- Text Defaults ---
if adapter_name in ("gemini", "text_primary"):
return AdapterConfig(
@@ -289,7 +297,7 @@ async def _route_with_failover(
"""通用 provider failover 路由。
Args:
provider_type: 供应商类型 (text/image/tts/storybook)
provider_type: 供应商类型 (text/image/tts/storybook/asr)
strategy: 路由策略
db: 数据库会话(可选,用于指标收集和熔断检查)
user_id: 用户 ID可选用于成本追踪和预算检查
@@ -297,7 +305,14 @@ async def _route_with_failover(
story_id: 故事 ID可选用于关联 provider 事件)
**kwargs: 传递给适配器的参数
"""
providers = await _get_providers_with_config(provider_type)
provider_names = kwargs.pop("provider_names", None)
if provider_names:
providers = [
(name, _get_default_config(name) or AdapterConfig(api_key=""), None)
for name in provider_names
]
else:
providers = await _get_providers_with_config(provider_type)
if not providers:
raise ValueError(f"No {provider_type} providers configured.")
@@ -457,6 +472,35 @@ async def _route_with_failover(
raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
async def transcribe_audio(
    audio_bytes: bytes,
    file_name: str | None = None,
    mime_type: str | None = None,
    transcript_hint: str | None = None,
    language: str | None = None,
    provider_names: list[str] | None = None,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
):
    """Transcribe audio through the ASR provider chain with failover.

    Args:
        audio_bytes: Raw audio payload.
        file_name: Original upload name, forwarded to the adapter.
        mime_type: Upload content type, forwarded to the adapter.
        transcript_hint: Optional hint text, forwarded to the adapter.
        language: Optional language code, forwarded to the adapter.
        provider_names: Forwarded to ``_route_with_failover`` as
            ``provider_names``.
        strategy: Routing strategy passed to the router.
        db: Optional database session passed to the router.
        user_id: Optional user id passed to the router.

    Returns:
        The result produced by ``_route_with_failover`` for the ``"asr"``
        capability.
    """
    from app.services.adapters.asr.models import TranscriptionOutput

    # Everything besides strategy/db/user_id travels to the router as
    # keyword arguments.
    adapter_kwargs = {
        "audio_bytes": audio_bytes,
        "file_name": file_name,
        "mime_type": mime_type,
        "transcript_hint": transcript_hint,
        "language": language,
        "provider_names": provider_names,
    }
    transcription: TranscriptionOutput = await _route_with_failover(
        "asr",
        strategy=strategy,
        db=db,
        user_id=user_id,
        **adapter_kwargs,
    )
    return transcription
async def generate_story_content(
input_type: Literal["keywords", "full_story"],
data: str,

View File

@@ -1448,6 +1448,8 @@ async def create_voice_turn_from_upload_service(
file_name=file_name,
mime_type=mime_type,
transcript_hint=transcript_hint,
db=db,
user_id=user_id,
)
except HTTPException as exc:
session.last_error = str(exc.detail)

View File

@@ -3,15 +3,12 @@
from __future__ import annotations
from dataclasses import dataclass
from io import BytesIO
from fastapi import HTTPException
from openai import AsyncOpenAI
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
from app.services.provider_router import transcribe_audio
@dataclass(frozen=True)
@@ -23,84 +20,9 @@ class VoiceTranscriptionResult:
provider: str = "demo"
def _normalize_transcript(transcript_text: str) -> str:
return transcript_text.strip()
async def _transcribe_demo(
    *,
    audio_bytes: bytes,
    mime_type: str | None,
    transcript_hint: str | None,
) -> VoiceTranscriptionResult:
    """Produce a transcript without calling an external ASR service.

    The normalized hint is used when non-empty. Otherwise, for ``text/*``
    uploads, the decoded body is used. A 503 ``HTTPException`` is raised when
    neither yields text.
    """
    transcript = _normalize_transcript(transcript_hint or "")
    # Fall back to the upload body only for text uploads without a hint.
    if not transcript and mime_type and mime_type.startswith("text/"):
        decoded = audio_bytes.decode("utf-8", errors="ignore")
        transcript = _normalize_transcript(decoded)

    if not transcript:
        raise HTTPException(
            status_code=503,
            detail=(
                "当前环境未配置真实语音转写,请先使用文本共创模式,"
                "或在开发模式下提供 transcript_hint。"
            ),
        )

    return VoiceTranscriptionResult(
        transcript_text=transcript,
        confidence=1.0,
        provider="demo",
    )
async def _transcribe_openai(
    *,
    audio_bytes: bytes,
    file_name: str,
    mime_type: str | None,
    transcript_hint: str | None,
) -> VoiceTranscriptionResult:
    """Transcribe one audio payload with the OpenAI transcription endpoint.

    Args:
        audio_bytes: Raw audio payload.
        file_name: Name attached to the uploaded file object.
        mime_type: Accepted for interface parity; not used here.
        transcript_hint: Optional text forwarded as the ``prompt`` field.

    Returns:
        A ``VoiceTranscriptionResult`` with ``provider="openai"`` and
        ``confidence=None``.

    Raises:
        HTTPException: 503 when the API key is missing or the request fails;
            502 when the response carries no text.
    """
    if not settings.openai_api_key:
        raise HTTPException(
            status_code=503,
            detail="OPENAI_API_KEY 未配置,无法使用 OpenAI 语音转写。",
        )

    client = AsyncOpenAI(api_key=settings.openai_api_key)
    # The SDK derives the upload filename from the file object's ``name``.
    audio_file = BytesIO(audio_bytes)
    audio_file.name = file_name

    # Only send optional fields that actually carry a value, so an unset
    # language or a whitespace-only hint is omitted rather than passed as an
    # explicit None / empty string.
    request_fields = {
        "model": settings.voice_transcription_model,
        "file": audio_file,
    }
    language = settings.voice_transcription_language
    if language:
        request_fields["language"] = language
    prompt = (transcript_hint or "").strip()
    if prompt:
        request_fields["prompt"] = prompt

    try:
        response = await client.audio.transcriptions.create(**request_fields)
    except Exception as exc:
        # Service boundary: any SDK/network failure becomes a 503.
        logger.warning("voice_transcription_openai_failed", error=str(exc))
        raise HTTPException(
            status_code=503,
            detail="语音转写服务暂时不可用,请稍后重试。",
        ) from exc

    transcript_text = _normalize_transcript(getattr(response, "text", "") or "")
    if not transcript_text:
        raise HTTPException(status_code=502, detail="语音转写结果为空,请重试。")

    return VoiceTranscriptionResult(
        transcript_text=transcript_text,
        confidence=None,
        provider="openai",
    )
def _resolve_transcript_hint(transcript_hint: str | None) -> str | None:
normalized = (transcript_hint or "").strip()
return normalized or None
async def transcribe_voice_audio(
@@ -109,26 +31,35 @@ async def transcribe_voice_audio(
file_name: str,
mime_type: str | None,
transcript_hint: str | None = None,
db: AsyncSession | None = None,
user_id: str | None = None,
) -> VoiceTranscriptionResult:
"""Transcribe one uploaded audio turn according to the configured mode."""
"""Transcribe one uploaded audio turn using configured ASR providers."""
mode = (settings.voice_transcription_mode or "demo").strip().lower()
mode = (settings.voice_transcription_mode or "provider").strip().lower()
if mode == "disabled":
raise HTTPException(
status_code=503,
detail="当前环境已禁用语音转写,请先使用文本共创模式。",
)
if mode == "openai":
return await _transcribe_openai(
audio_bytes=audio_bytes,
file_name=file_name,
mime_type=mime_type,
transcript_hint=transcript_hint,
)
return await _transcribe_demo(
hint = _resolve_transcript_hint(transcript_hint)
provider_name = "openai_asr" if mode == "openai" else mode
strategy_providers = None if mode == "provider" else [provider_name]
result = await transcribe_audio(
audio_bytes=audio_bytes,
file_name=file_name,
mime_type=mime_type,
transcript_hint=transcript_hint,
transcript_hint=hint,
language=settings.voice_transcription_language,
provider_names=strategy_providers,
db=db,
user_id=user_id,
)
return VoiceTranscriptionResult(
transcript_text=result.transcript_text,
confidence=result.confidence,
provider=result.provider,
)