refactor: separate provider capability policy

This commit is contained in:
2026-04-18 13:37:59 +08:00
parent 0444b81df6
commit 7b8e7c9944
11 changed files with 393 additions and 88 deletions

View File

@@ -147,3 +147,4 @@ See `backend/.env.example` for required variables:
| CRUD | `/api/reading-events` | Reading progress |
| CRUD | `/api/push-configs` | Push notification settings |
| GET/POST/PUT/DELETE | `/admin/providers` | Provider management (admin) |
| GET | `/admin/providers/capabilities` | Provider capability policy |

View File

@@ -131,6 +131,7 @@ npm run build
| GET | `/api/stories/{story_id}` | 故事详情 |
| DELETE | `/api/stories/{story_id}` | 删除故事 |
| GET/POST/PUT/DELETE | `/admin/providers` | Provider 管理,需开启管理后台 |
| GET | `/admin/providers/capabilities` | Provider 能力分层说明,需开启管理后台 |
## 文档入口
@@ -138,6 +139,7 @@ npm run build
- `docs/product/unified-generation-workflow-prd.md`:统一生成工作流 PRD
- `docs/planning/week-1-execution-backlog.md`:短期执行 backlog
- `docs/technical/memory-system-dev.md`:记忆系统技术说明
- `docs/technical/provider-routing.md`Provider 能力与路由策略说明
## 当前取舍

View File

@@ -8,7 +8,7 @@ from app.db.admin_models import Provider
from app.db.database import get_db
from app.services.adapters.registry import AdapterRegistry
from app.services.cost_tracker import cost_tracker
from app.services.provider_router import DEFAULT_PROVIDERS
from app.services.provider_policy import DEFAULT_PROVIDERS, list_capability_policies
from app.services.secret_service import SecretService
router = APIRouter(dependencies=[Depends(admin_guard)])
@@ -68,6 +68,12 @@ async def get_env_defaults():
return DEFAULT_PROVIDERS
@router.get("/providers/capabilities")
async def list_provider_capabilities():
"""获取 Provider 能力分层与默认路由策略。"""
return list_capability_policies()
@router.get("/providers", response_model=list[ProviderResponse])
async def list_providers(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Provider))

View File

@@ -2,7 +2,6 @@
import json
from collections import defaultdict
from typing import Literal
from pydantic import BaseModel
from sqlalchemy import select
@@ -11,11 +10,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.db.admin_models import Provider
from app.services.provider_policy import ProviderType
logger = get_logger(__name__)
ProviderType = Literal["text", "image", "tts", "storybook"]
class CachedProvider(BaseModel):
"""Serializable provider configuration matching DB model fields."""

View File

@@ -0,0 +1,148 @@
"""Provider capability and routing policy definitions."""
from dataclasses import dataclass
from enum import Enum
from typing import Literal, Protocol, TypeAlias
ProviderType: TypeAlias = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
"""How providers should be ordered before failover execution."""
PRIORITY = "priority"
COST = "cost"
LATENCY = "latency"
ROUND_ROBIN = "round_robin"
@dataclass(frozen=True)
class CapabilityPolicy:
"""Product-level capability policy for one provider family."""
capability: ProviderType
label: str
description: str
settings_attr: str
default_providers: tuple[str, ...]
default_strategy: RoutingStrategy = RoutingStrategy.PRIORITY
demo_provider: str | None = None
class ProviderSettings(Protocol):
"""Settings fields required by provider policy resolution."""
text_providers: list[str]
image_providers: list[str]
tts_providers: list[str]
storybook_providers: list[str]
enable_demo_providers: bool
CAPABILITY_POLICIES: dict[ProviderType, CapabilityPolicy] = {
"text": CapabilityPolicy(
capability="text",
label="文本生成",
description="生成或润色儿童故事文本。",
settings_attr="text_providers",
default_providers=("gemini", "openai"),
demo_provider="demo",
),
"image": CapabilityPolicy(
capability="image",
label="图片生成",
description="生成故事封面或绘本插图。",
settings_attr="image_providers",
default_providers=("cqtai",),
demo_provider="demo",
),
"tts": CapabilityPolicy(
capability="tts",
label="语音合成",
description="将故事文本合成为可播放音频。",
settings_attr="tts_providers",
default_providers=("minimax", "elevenlabs", "edge_tts"),
),
"storybook": CapabilityPolicy(
capability="storybook",
label="绘本结构生成",
description="生成多页绘本结构、分镜文本和插图提示词。",
settings_attr="storybook_providers",
default_providers=("storybook_primary",),
demo_provider="demo",
),
}
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
capability: list(policy.default_providers)
for capability, policy in CAPABILITY_POLICIES.items()
}
API_KEY_MAP: dict[str, str] = {
# Text
"gemini": "text_api_key",
"text_primary": "text_api_key",
"text_api_key": "text_api_key",
"openai": "openai_api_key",
"openai_api_key": "openai_api_key",
# Image
"cqtai": "cqtai_api_key",
"cqtai_api_key": "cqtai_api_key",
"antigravity": "antigravity_api_key",
"antigravity_api_key": "antigravity_api_key",
"image_primary": "image_api_key",
"image_api_key": "image_api_key",
# TTS
"minimax": "minimax_api_key",
"minimax_api_key": "minimax_api_key",
"elevenlabs": "elevenlabs_api_key",
"elevenlabs_api_key": "elevenlabs_api_key",
"edge_tts": "tts_api_key",
"tts_primary": "tts_api_key",
"tts_api_key": "tts_api_key",
}
def get_capability_policy(capability: ProviderType) -> CapabilityPolicy:
"""Return the product policy for a provider capability."""
return CAPABILITY_POLICIES[capability]
def get_provider_names_from_settings(
capability: ProviderType,
settings: ProviderSettings,
) -> list[str]:
"""Resolve provider order from settings, falling back to capability defaults."""
policy = get_capability_policy(capability)
configured = getattr(settings, policy.settings_attr, None)
names = list(configured or policy.default_providers)
if (
settings.enable_demo_providers
and policy.demo_provider
and policy.demo_provider not in names
):
names = [policy.demo_provider, *names]
return names
def list_capability_policies() -> list[dict[str, object]]:
"""Return a serializable capability policy overview for admin/docs use."""
return [
{
"capability": policy.capability,
"label": policy.label,
"description": policy.description,
"settings_attr": policy.settings_attr,
"default_providers": list(policy.default_providers),
"default_strategy": policy.default_strategy.value,
"demo_provider": policy.demo_provider,
}
for policy in CAPABILITY_POLICIES.values()
]

View File

@@ -1,7 +1,6 @@
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
import time
from enum import Enum
from typing import TYPE_CHECKING, Literal, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
@@ -13,6 +12,13 @@ from app.services.adapters.text.models import StoryOutput
from app.services.cost_tracker import cost_tracker
from app.services.provider_cache import get_providers
from app.services.provider_metrics import health_checker, metrics_collector
from app.services.provider_policy import (
API_KEY_MAP,
DEFAULT_PROVIDERS,
ProviderType,
RoutingStrategy,
get_provider_names_from_settings,
)
if TYPE_CHECKING:
from app.db.admin_models import Provider
@@ -21,50 +27,9 @@ logger = get_logger(__name__)
T = TypeVar("T")
ProviderType = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
"""路由策略枚举。"""
PRIORITY = "priority" # 按优先级排序(默认)
COST = "cost" # 按成本排序
LATENCY = "latency" # 按延迟排序
ROUND_ROBIN = "round_robin" # 轮询
# 默认配置映射(当 DB 无配置时使用)
# 这是“代码级”的默认策略,对应 .env 为空的情况
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
"text": ["gemini", "openai"],
"image": ["cqtai"],
"tts": ["minimax", "elevenlabs", "edge_tts"],
"storybook": ["storybook_primary"],
}
# API Key 映射adapter_name -> settings 属性名
API_KEY_MAP: dict[str, str] = {
# Text
"gemini": "text_api_key", # Gemini 还是复用 text_api_key 字段
"text_primary": "text_api_key", # 兼容旧别名
"openai": "openai_api_key",
# Image
"cqtai": "cqtai_api_key",
"image_primary": "image_api_key", # 兼容旧别名
# TTS
"minimax": "minimax_api_key",
"elevenlabs": "elevenlabs_api_key",
"edge_tts": "tts_api_key", # EdgeTTS 复用 tts_api_key (通常为空)
"tts_primary": "tts_api_key", # 兼容旧别名
}
# 轮询计数器
_round_robin_counters: dict[ProviderType, int] = {
"text": 0,
"image": 0,
"tts": 0,
provider_type: 0 for provider_type in DEFAULT_PROVIDERS
}
# 延迟缓存(内存中,简化实现)
@@ -115,6 +80,13 @@ def _get_default_config(adapter_name: str) -> AdapterConfig | None:
model=settings.image_model or "nano-banana-pro",
timeout_ms=120000,
)
if adapter_name == "antigravity":
return AdapterConfig(
api_key=getattr(settings, "antigravity_api_key", ""),
api_base=getattr(settings, "antigravity_api_base", ""),
model=settings.antigravity_model,
timeout_ms=120000,
)
if adapter_name == "image_primary":
# 如果还有地方在用 image_primary暂时映射到快或者其他
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
@@ -196,15 +168,7 @@ async def _get_providers_with_config(
if db_providers:
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
settings_map = {
"text": settings.text_providers,
"image": settings.image_providers,
"tts": settings.tts_providers,
"storybook": settings.storybook_providers,
}
names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]
if settings.enable_demo_providers and "demo" not in names:
names = ["demo", *names]
names = get_provider_names_from_settings(provider_type, settings)
result = []
for name in names:

View File

@@ -1,11 +1,18 @@
"""Provider router 测试 - failover 和配置加载。"""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.adapters import AdapterConfig
from app.services.adapters.text.models import StoryOutput
from app.services.provider_policy import (
DEFAULT_PROVIDERS,
RoutingStrategy,
get_provider_names_from_settings,
list_capability_policies,
)
class TestProviderFailover:
@@ -189,6 +196,90 @@ class TestProviderFailover:
mock_record_cost.assert_awaited_once()
assert mock_record_cost.await_args.kwargs["provider_id"] is None
@pytest.mark.asyncio
async def test_storybook_round_robin_strategy_is_supported(self):
"""所有能力都应能使用 routing policystorybook 不能漏掉轮询计数器。"""
from app.services import provider_router
from app.services.adapters.storybook.primary import Storybook
mock_storybook = Storybook(
title="轮询绘本",
main_character="小星",
art_style="温暖水彩",
pages=[],
cover_prompt="cover",
)
class MockAdapter:
estimated_cost = 0.0
def __init__(self, config):
self.config = config
async def execute(self, **kwargs):
return mock_storybook
with patch.object(
provider_router,
"_get_providers_with_config",
new_callable=AsyncMock,
) as mock_providers:
mock_providers.return_value = [
("storybook_primary", AdapterConfig(api_key=""), None),
]
with patch.object(provider_router.AdapterRegistry, "get", return_value=MockAdapter):
result = await provider_router.generate_storybook(
keywords="测试",
strategy=RoutingStrategy.ROUND_ROBIN,
)
assert result == mock_storybook
class TestProviderPolicy:
"""Provider capability / routing policy boundary tests."""
def test_policy_lists_all_capabilities(self):
policies = list_capability_policies()
capabilities = {item["capability"] for item in policies}
assert capabilities == {"text", "image", "tts", "storybook"}
assert DEFAULT_PROVIDERS["storybook"] == ["storybook_primary"]
def test_demo_provider_only_added_to_supported_capabilities(self):
settings = SimpleNamespace(
text_providers=["gemini"],
image_providers=["cqtai"],
tts_providers=["edge_tts"],
storybook_providers=["storybook_primary"],
enable_demo_providers=True,
)
assert get_provider_names_from_settings("text", settings) == ["demo", "gemini"]
assert get_provider_names_from_settings("image", settings) == ["demo", "cqtai"]
assert get_provider_names_from_settings("storybook", settings) == [
"demo",
"storybook_primary",
]
assert get_provider_names_from_settings("tts", settings) == ["edge_tts"]
def test_policy_defaults_when_settings_lists_are_empty(self):
settings = SimpleNamespace(
text_providers=[],
image_providers=[],
tts_providers=[],
storybook_providers=[],
enable_demo_providers=False,
)
assert get_provider_names_from_settings("text", settings) == ["gemini", "openai"]
assert get_provider_names_from_settings("tts", settings) == [
"minimax",
"elevenlabs",
"edge_tts",
]
class TestProviderConfigFromDB:
"""从 DB 加载 provider 配置测试。"""
@@ -237,6 +328,31 @@ class TestProviderConfigFromDB:
assert config.api_key == "settings-api-key"
def test_build_config_uses_direct_config_ref_name(self):
"""config_ref 可以直接使用 settings 字段名,便于后台配置。"""
from app.services.provider_router import _build_config_from_provider
mock_provider = MagicMock()
mock_provider.adapter = "antigravity"
mock_provider.api_key = None
mock_provider.api_base = None
mock_provider.model = None
mock_provider.timeout_ms = None
mock_provider.max_retries = None
mock_provider.config_ref = "antigravity_api_key"
mock_provider.config_json = {}
with patch("app.services.provider_router.settings") as mock_settings:
mock_settings.antigravity_api_key = "antigravity-key"
mock_settings.antigravity_api_base = "https://antigravity.example"
mock_settings.antigravity_model = "gemini-3-pro-image"
config = _build_config_from_provider(mock_provider)
assert config.api_key == "antigravity-key"
assert config.api_base == "https://antigravity.example"
assert config.model == "gemini-3-pro-image"
class TestProviderCacheStartup:
"""Provider cache 启动加载测试。"""

View File

@@ -20,6 +20,9 @@
- `technical/memory-system-dev.md`
记忆系统技术说明。用于后续继续做孩子档案、故事宇宙和个性化生成。
- `technical/provider-routing.md`
Provider Routing 技术说明。用于解释 Capability / Provider / Adapter / Routing Policy 的职责边界。
## 维护规则
- 新 PRD 放到 `docs/product/`

View File

@@ -53,7 +53,7 @@
- 已新增数据库迁移:
- `0009_add_story_generation_statuses.py`
- `0010_add_story_audio_cache_path.py`
- 已完成一轮后端回归验证:`backend/``pytest -q` 结果为 `66 passed`
- 已完成一轮后端回归验证:`backend/``pytest -q` 结果为 `71 passed`
- 已完成全量后端 lint 清理:`ruff check app tests` 可通过
- 已修复 admin-frontend 构建阻塞,主前端与管理端前端均可生产构建
- 已落地首版统一资产重试入口:`POST /api/stories/{story_id}/assets/retry`
@@ -70,6 +70,11 @@
- 绘本缺失插图补全
- 故事音频缓存读取与 TTS 生成
- 已引入首版服务层 `AssetCompletionResult`,用于统一表达资产补全结果
- 已完成 Provider 分层首版落地:
- 新增 `provider_policy.py`,定义 Capability / Routing Policy / 默认 Provider 顺序
- Provider Router 专注运行时 failover、熔断、成本和指标记录
- 新增 `/admin/providers/capabilities` 展示能力分层
- 新增 `docs/technical/provider-routing.md` 作为术语表和分层说明
### What Is In Progress
@@ -79,7 +84,6 @@
### What Is Still Pending
- Provider 的 Capability / Provider / Routing Policy 边界整理
- Week 2 可直接执行的开发任务表
- 演示 checklist 与最终收尾策略
@@ -178,7 +182,7 @@
- [ ] 明确 admin-frontend 的处理方案
- [x] 明确 Storybook 恢复方案
- [ ] 明确 Provider 重构边界
- [x] 明确 Provider 重构边界
---
@@ -191,7 +195,7 @@
| W1-03 | Product / System | 盘点现有生成路径:普通故事、完整生成、绘本生成 | 现状流程图或对照表 | P0 | 0.5d | Done |
| W1-04 | Product / System | 定义统一 Generation Workflow 状态模型 | 状态流转说明 | P0 | 1.0d | Done |
| W1-05 | Product / Backend | 定义统一工作流下的 API / 数据结构影响 | 接口与模型变更清单 | P0 | 0.5d | In Progress |
| W1-06 | Product / Backend | 梳理 Provider 概念层Capability / Provider / Routing Policy | 分层图与术语表 | P1 | 0.5d | Pending |
| W1-06 | Product / Backend | 梳理 Provider 概念层Capability / Provider / Routing Policy | 分层图与术语表 | P1 | 0.5d | Done |
| W1-07 | Product / Frontend | 梳理 Storybook 当前问题与恢复方案 | 恢复方案说明 | P0 | 0.5d | Done |
| W1-08 | Product / Frontend | 确认 admin 前端是修复、裁剪还是暂时降级 | 决策记录 | P0 | 0.5d | Done |
| W1-09 | Planning | 产出 Week 2 开发任务清单 | 下周 backlog | P1 | 0.5d | In Progress |

View File

@@ -182,10 +182,10 @@ DreamWeaver 是一款面向 3-8 岁亲子场景的个性化 AI 绘本与陪伴
**Acceptance Criteria**
- [ ] Provider 配置以能力、供应商、模型配置的方式组织
- [ ] 路由策略与凭证管理职责分离
- [ ] 系统能清楚展示失败降级逻辑
- [ ] 管理端或配置文档能说明当前有效供应链路
- [x] Provider 配置以能力、供应商、模型配置的方式组织
- [x] 路由策略与凭证管理职责分离
- [x] 系统能清楚展示失败降级逻辑
- [x] 管理端或配置文档能说明当前有效供应链路
---
@@ -443,7 +443,7 @@ DreamWeaver 是一款面向 3-8 岁亲子场景的个性化 AI 绘本与陪伴
### Glossary
- **Generation Workflow**: 从用户输入到文本、图片、语音完成的一整套生成流程。
- **Capability**: 底层 AI 能力分类,如文本、图片、语音。
- **Capability**: 底层 AI 能力分类,如文本、图片、语音、绘本结构
- **Provider**: 具体供应商,如 Gemini、OpenAI、MiniMax。
- **Routing Policy**: 供应商选择与降级策略。
- **Degraded Completion**: 资产部分失败但主结果可用的完成状态。

View File

@@ -0,0 +1,63 @@
# Provider Routing 技术说明
本说明用于支撑求职版 DreamWeaver 的 Provider 分层表达。当前目标不是做复杂平台化,而是把 AI 能力供应链讲清楚、跑稳定、便于后续演进。
## 核心概念
| 概念 | 含义 | 当前代码位置 |
| --- | --- | --- |
| Capability | 产品需要的 AI 能力类型,例如文本、图片、语音、绘本结构 | `backend/app/services/provider_policy.py` |
| Provider | 某个能力下的可调用供应商配置,例如 Gemini、OpenAI、CQTAI、MiniMax | `providers` 表与 `provider_cache.py` |
| Adapter | 供应商调用实现,负责把统一入参翻译成具体 API 调用 | `backend/app/services/adapters/` |
| Routing Policy | 调用前如何排序与选择 Provider例如优先级、成本、延迟、轮询 | `provider_policy.py` + `provider_router.py` |
| Failover | 当前 Provider 调用失败后自动尝试下一 Provider | `provider_router.py` |
## 当前 Capability
| Capability | 用途 | 默认 Provider | Demo Provider |
| --- | --- | --- | --- |
| `text` | 生成/润色儿童故事文本 | `gemini`, `openai` | `demo` |
| `image` | 生成封面和绘本插图 | `cqtai` | `demo` |
| `tts` | 故事语音合成 | `minimax`, `elevenlabs`, `edge_tts` | 无 |
| `storybook` | 生成多页绘本结构和插图提示词 | `storybook_primary` | `demo` |
`ENABLE_DEMO_PROVIDERS=true` 时,只会给具备 demo adapter 的能力前置 `demo` provider。TTS 暂无 demo adapter因此不会插入不存在的 `tts:demo`
## 代码边界
`provider_policy.py` 负责定义“产品级策略”:
- Capability 清单
- 默认 Provider 顺序
- `.env` 中对应的 provider 列表字段
- 默认 routing strategy
- API key ref 到 settings 字段的映射
- 哪些能力支持本地 demo provider
`provider_router.py` 负责执行“运行时路由”:
- 从 DB cache 或 `.env` 读取 Provider 配置
- 构建 `AdapterConfig`
- 按 Routing Policy 排序
- 熔断过滤
- 调用 adapter
- 记录 metrics、health、cost
- failover 并聚合错误
`adapters/` 负责具体供应商 API
- 不决定业务工作流
- 不读取用户故事上下文
- 不负责 Provider 排序或熔断
## 演进原则
- 新增 AI 能力时,先在 `provider_policy.py` 增加 Capability再注册 adapter。
- 新增供应商时,先实现 adapter再把默认顺序或 DB 配置接入对应 Capability。
- 路由策略只影响调用顺序,不应该改变故事/绘本/音频的产品工作流。
- 本轮求职版不做多租户供应商市场,也不做复杂负载均衡;优先保证能力分层清楚、失败可恢复、演示稳定。
- 后台 `config_ref` 可以使用 adapter 别名,也可以直接使用 settings 字段名,例如 `text_api_key``antigravity_api_key`
## 面试表达口径
DreamWeaver 的 Provider 体系不是把供应商暴露给用户,而是把多模型能力整理成稳定的产品能力。用户看到的是“生成故事、生成封面、播放语音”,系统内部才把它映射到 `text``image``tts``storybook` 这些 Capability再通过 Routing Policy 选择具体 Provider 和 Adapter。