refactor: separate provider capability policy

This commit is contained in:
2026-04-18 13:37:59 +08:00
parent 0444b81df6
commit 7b8e7c9944
11 changed files with 393 additions and 88 deletions

View File

@@ -147,3 +147,4 @@ See `backend/.env.example` for required variables:
| CRUD | `/api/reading-events` | Reading progress |
| CRUD | `/api/push-configs` | Push notification settings |
| GET/POST/PUT/DELETE | `/admin/providers` | Provider management (admin) |
| GET | `/admin/providers/capabilities` | Provider capability policy |

View File

@@ -131,6 +131,7 @@ npm run build
| GET | `/api/stories/{story_id}` | 故事详情 |
| DELETE | `/api/stories/{story_id}` | 删除故事 |
| GET/POST/PUT/DELETE | `/admin/providers` | Provider 管理,需开启管理后台 |
| GET | `/admin/providers/capabilities` | Provider 能力分层说明,需开启管理后台 |
## 文档入口
@@ -138,6 +139,7 @@ npm run build
- `docs/product/unified-generation-workflow-prd.md`:统一生成工作流 PRD
- `docs/planning/week-1-execution-backlog.md`:短期执行 backlog
- `docs/technical/memory-system-dev.md`:记忆系统技术说明
- `docs/technical/provider-routing.md`:Provider 能力与路由策略说明
## 当前取舍

View File

@@ -8,7 +8,7 @@ from app.db.admin_models import Provider
from app.db.database import get_db
from app.services.adapters.registry import AdapterRegistry
from app.services.cost_tracker import cost_tracker
from app.services.provider_router import DEFAULT_PROVIDERS
from app.services.provider_policy import DEFAULT_PROVIDERS, list_capability_policies
from app.services.secret_service import SecretService
router = APIRouter(dependencies=[Depends(admin_guard)])
@@ -68,6 +68,12 @@ async def get_env_defaults():
return DEFAULT_PROVIDERS
@router.get("/providers/capabilities")
async def list_provider_capabilities():
    """Return the provider capability tiers and their default routing policy.

    The payload is whatever ``list_capability_policies()`` produces; this
    handler adds nothing on top of it.
    """
    capability_overview = list_capability_policies()
    return capability_overview
@router.get("/providers", response_model=list[ProviderResponse])
async def list_providers(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Provider))

View File

@@ -1,22 +1,20 @@
"""Redis-backed cache for providers loaded from DB."""
import json
from collections import defaultdict
from typing import Literal
from pydantic import BaseModel
import json
from collections import defaultdict
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.db.admin_models import Provider
logger = get_logger(__name__)
ProviderType = Literal["text", "image", "tts", "storybook"]
from app.core.redis import get_redis
from app.db.admin_models import Provider
from app.services.provider_policy import ProviderType
logger = get_logger(__name__)
class CachedProvider(BaseModel):
"""Serializable provider configuration matching DB model fields."""

View File

@@ -0,0 +1,148 @@
"""Provider capability and routing policy definitions."""
from dataclasses import dataclass
from enum import Enum
from typing import Literal, Protocol, TypeAlias
# Closed set of capability identifiers. The same four strings are used as the
# keys of CAPABILITY_POLICIES further down in this module.
ProviderType: TypeAlias = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
    """How providers should be ordered before failover execution.

    Members also subclass ``str``, so each one compares equal to its value.
    """

    PRIORITY = "priority"  # order by priority; the default used by CapabilityPolicy
    COST = "cost"  # order by cost
    LATENCY = "latency"  # order by latency
    ROUND_ROBIN = "round_robin"  # rotate through the providers in turn
@dataclass(frozen=True)
class CapabilityPolicy:
    """Product-level capability policy for one provider family.

    Instances are immutable (``frozen=True``).
    """

    # Capability this policy describes; one of the ProviderType literals.
    capability: ProviderType
    # Short human-readable name of the capability.
    label: str
    # One-sentence explanation of what the capability is used for.
    description: str
    # Name of the settings attribute holding the configured provider list.
    settings_attr: str
    # Provider order used when the settings attribute is missing or empty.
    default_providers: tuple[str, ...]
    # Ordering strategy applied unless a caller asks for another one.
    default_strategy: RoutingStrategy = RoutingStrategy.PRIORITY
    # Provider name put first when demo providers are enabled; None means the
    # capability has no demo provider.
    demo_provider: str | None = None
class ProviderSettings(Protocol):
    """Settings fields required by provider policy resolution.

    Structural type: any object exposing these attributes is accepted.
    """

    # Configured provider order per capability; an empty list means
    # "fall back to the capability's default_providers".
    text_providers: list[str]
    image_providers: list[str]
    tts_providers: list[str]
    storybook_providers: list[str]
    # When true, a capability's demo provider (if it has one) is put first.
    enable_demo_providers: bool
# Registry of capability policies, keyed by each policy's own ``capability``
# value so the key and the policy can never drift apart. Insertion order is
# text, image, tts, storybook.
CAPABILITY_POLICIES: dict[ProviderType, CapabilityPolicy] = {
    policy.capability: policy
    for policy in (
        CapabilityPolicy(
            capability="text",
            label="文本生成",
            description="生成或润色儿童故事文本。",
            settings_attr="text_providers",
            default_providers=("gemini", "openai"),
            demo_provider="demo",
        ),
        CapabilityPolicy(
            capability="image",
            label="图片生成",
            description="生成故事封面或绘本插图。",
            settings_attr="image_providers",
            default_providers=("cqtai",),
            demo_provider="demo",
        ),
        # TTS declares no demo_provider, so it keeps the dataclass default (None).
        CapabilityPolicy(
            capability="tts",
            label="语音合成",
            description="将故事文本合成为可播放音频。",
            settings_attr="tts_providers",
            default_providers=("minimax", "elevenlabs", "edge_tts"),
        ),
        CapabilityPolicy(
            capability="storybook",
            label="绘本结构生成",
            description="生成多页绘本结构、分镜文本和插图提示词。",
            settings_attr="storybook_providers",
            default_providers=("storybook_primary",),
            demo_provider="demo",
        ),
    )
}
# Default provider order per capability, derived from CAPABILITY_POLICIES.
# Each value is a fresh list copy of the policy's immutable tuple.
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
    name: [*policy.default_providers]
    for name, policy in CAPABILITY_POLICIES.items()
}
# Maps an API key reference to the name of the settings attribute that holds
# the key. A reference is either an adapter name / legacy alias or the settings
# attribute name itself; the identity entries below make the latter work.
API_KEY_MAP: dict[str, str] = {
    # Text
    "gemini": "text_api_key",  # Gemini reuses the text_api_key field
    "text_primary": "text_api_key",  # legacy alias kept for compatibility
    "text_api_key": "text_api_key",
    "openai": "openai_api_key",
    "openai_api_key": "openai_api_key",
    # Image
    "cqtai": "cqtai_api_key",
    "cqtai_api_key": "cqtai_api_key",
    "antigravity": "antigravity_api_key",
    "antigravity_api_key": "antigravity_api_key",
    "image_primary": "image_api_key",  # legacy alias kept for compatibility
    "image_api_key": "image_api_key",
    # TTS
    "minimax": "minimax_api_key",
    "minimax_api_key": "minimax_api_key",
    "elevenlabs": "elevenlabs_api_key",
    "elevenlabs_api_key": "elevenlabs_api_key",
    "edge_tts": "tts_api_key",  # Edge TTS reuses tts_api_key
    "tts_primary": "tts_api_key",  # legacy alias kept for compatibility
    "tts_api_key": "tts_api_key",
}
def get_capability_policy(capability: ProviderType) -> CapabilityPolicy:
    """Look up the product-level policy registered for ``capability``.

    Raises:
        KeyError: if ``capability`` is not a key of ``CAPABILITY_POLICIES``.
    """
    policy = CAPABILITY_POLICIES[capability]
    return policy
def get_provider_names_from_settings(
    capability: ProviderType,
    settings: ProviderSettings,
) -> list[str]:
    """Work out the provider order for ``capability`` from ``settings``.

    The list named by the policy's ``settings_attr`` is used when it is
    present and non-empty; otherwise the policy's ``default_providers`` apply.
    When ``settings.enable_demo_providers`` is truthy and the capability has a
    demo provider that is not already listed, that provider is placed first.

    Returns:
        A new list of provider names; the settings object is not modified.
    """
    policy = get_capability_policy(capability)
    configured = getattr(settings, policy.settings_attr, None)
    ordered = list(configured) if configured else list(policy.default_providers)

    demo = policy.demo_provider
    # Nothing to prepend: demo mode is off, the capability has no demo
    # provider, or the demo provider is already part of the list.
    if not settings.enable_demo_providers or not demo or demo in ordered:
        return ordered
    return [demo, *ordered]
def _policy_to_dict(policy: CapabilityPolicy) -> dict[str, object]:
    """Flatten one policy into plain, serializable values."""
    return {
        "capability": policy.capability,
        "label": policy.label,
        "description": policy.description,
        "settings_attr": policy.settings_attr,
        # Tuple and enum are converted to a list and a string respectively.
        "default_providers": list(policy.default_providers),
        "default_strategy": policy.default_strategy.value,
        "demo_provider": policy.demo_provider,
    }


def list_capability_policies() -> list[dict[str, object]]:
    """Build a serializable overview of every capability policy.

    Returns:
        One dict per entry of ``CAPABILITY_POLICIES``, in registry order.
    """
    return list(map(_policy_to_dict, CAPABILITY_POLICIES.values()))

View File

@@ -1,7 +1,6 @@
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
import time
from enum import Enum
from typing import TYPE_CHECKING, Literal, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
@@ -13,6 +12,13 @@ from app.services.adapters.text.models import StoryOutput
from app.services.cost_tracker import cost_tracker
from app.services.provider_cache import get_providers
from app.services.provider_metrics import health_checker, metrics_collector
from app.services.provider_policy import (
API_KEY_MAP,
DEFAULT_PROVIDERS,
ProviderType,
RoutingStrategy,
get_provider_names_from_settings,
)
if TYPE_CHECKING:
from app.db.admin_models import Provider
@@ -21,50 +27,9 @@ logger = get_logger(__name__)
T = TypeVar("T")
ProviderType = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
"""路由策略枚举。"""
PRIORITY = "priority" # 按优先级排序(默认)
COST = "cost" # 按成本排序
LATENCY = "latency" # 按延迟排序
ROUND_ROBIN = "round_robin" # 轮询
# 默认配置映射(当 DB 无配置时使用)
# 这是“代码级”的默认策略,对应 .env 为空的情况
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
"text": ["gemini", "openai"],
"image": ["cqtai"],
"tts": ["minimax", "elevenlabs", "edge_tts"],
"storybook": ["storybook_primary"],
}
# API Key 映射adapter_name -> settings 属性名
API_KEY_MAP: dict[str, str] = {
# Text
"gemini": "text_api_key", # Gemini 还是复用 text_api_key 字段
"text_primary": "text_api_key", # 兼容旧别名
"openai": "openai_api_key",
# Image
"cqtai": "cqtai_api_key",
"image_primary": "image_api_key", # 兼容旧别名
# TTS
"minimax": "minimax_api_key",
"elevenlabs": "elevenlabs_api_key",
"edge_tts": "tts_api_key", # EdgeTTS 复用 tts_api_key (通常为空)
"tts_primary": "tts_api_key", # 兼容旧别名
}
# 轮询计数器
_round_robin_counters: dict[ProviderType, int] = {
"text": 0,
"image": 0,
"tts": 0,
provider_type: 0 for provider_type in DEFAULT_PROVIDERS
}
# 延迟缓存(内存中,简化实现)
@@ -115,6 +80,13 @@ def _get_default_config(adapter_name: str) -> AdapterConfig | None:
model=settings.image_model or "nano-banana-pro",
timeout_ms=120000,
)
if adapter_name == "antigravity":
return AdapterConfig(
api_key=getattr(settings, "antigravity_api_key", ""),
api_base=getattr(settings, "antigravity_api_base", ""),
model=settings.antigravity_model,
timeout_ms=120000,
)
if adapter_name == "image_primary":
# 如果还有地方在用 image_primary暂时映射到快或者其他
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
@@ -196,15 +168,7 @@ async def _get_providers_with_config(
if db_providers:
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
settings_map = {
"text": settings.text_providers,
"image": settings.image_providers,
"tts": settings.tts_providers,
"storybook": settings.storybook_providers,
}
names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]
if settings.enable_demo_providers and "demo" not in names:
names = ["demo", *names]
names = get_provider_names_from_settings(provider_type, settings)
result = []
for name in names:

View File

@@ -1,14 +1,21 @@
"""Provider router 测试 - failover 和配置加载。"""
from unittest.mock import AsyncMock, MagicMock, patch
"""Provider router 测试 - failover 和配置加载。"""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.adapters import AdapterConfig
from app.services.adapters.text.models import StoryOutput
from app.services.adapters.text.models import StoryOutput
from app.services.provider_policy import (
DEFAULT_PROVIDERS,
RoutingStrategy,
get_provider_names_from_settings,
list_capability_policies,
)
class TestProviderFailover:
class TestProviderFailover:
"""Provider failover 测试。"""
@pytest.mark.asyncio
@@ -126,7 +133,7 @@ class TestProviderFailover:
)
@pytest.mark.asyncio
async def test_default_provider_skips_fk_backed_metrics(self):
async def test_default_provider_skips_fk_backed_metrics(self):
"""环境变量/default provider 没有 providers 表记录,不写带外键的指标表。"""
from app.services import provider_router
@@ -187,7 +194,91 @@ class TestProviderFailover:
mock_record_call.assert_not_called()
mock_record_call_result.assert_not_called()
mock_record_cost.assert_awaited_once()
assert mock_record_cost.await_args.kwargs["provider_id"] is None
assert mock_record_cost.await_args.kwargs["provider_id"] is None
@pytest.mark.asyncio
async def test_storybook_round_robin_strategy_is_supported(self):
    """Every capability must accept a routing policy.

    Storybook in particular must not be missing its round-robin counter.
    """
    from app.services import provider_router
    from app.services.adapters.storybook.primary import Storybook

    expected_storybook = Storybook(
        title="轮询绘本",
        main_character="小星",
        art_style="温暖水彩",
        pages=[],
        cover_prompt="cover",
    )

    class MockAdapter:
        estimated_cost = 0.0

        def __init__(self, config):
            self.config = config

        async def execute(self, **kwargs):
            return expected_storybook

    # A single env-style provider: the third tuple element (DB record) is None.
    single_provider = [
        ("storybook_primary", AdapterConfig(api_key=""), None),
    ]

    with (
        patch.object(
            provider_router,
            "_get_providers_with_config",
            new_callable=AsyncMock,
            return_value=single_provider,
        ),
        patch.object(
            provider_router.AdapterRegistry,
            "get",
            return_value=MockAdapter,
        ),
    ):
        result = await provider_router.generate_storybook(
            keywords="测试",
            strategy=RoutingStrategy.ROUND_ROBIN,
        )

    assert result == expected_storybook
class TestProviderPolicy:
    """Boundary tests for provider capability and routing policy resolution."""

    def test_policy_lists_all_capabilities(self):
        """The overview covers all four capabilities, storybook default included."""
        listed = {entry["capability"] for entry in list_capability_policies()}

        assert listed == {"text", "image", "tts", "storybook"}
        assert DEFAULT_PROVIDERS["storybook"] == ["storybook_primary"]

    def test_demo_provider_only_added_to_supported_capabilities(self):
        """Demo is prepended for text/image/storybook but never for tts."""
        demo_settings = SimpleNamespace(
            text_providers=["gemini"],
            image_providers=["cqtai"],
            tts_providers=["edge_tts"],
            storybook_providers=["storybook_primary"],
            enable_demo_providers=True,
        )

        def resolve(capability):
            return get_provider_names_from_settings(capability, demo_settings)

        assert resolve("text") == ["demo", "gemini"]
        assert resolve("image") == ["demo", "cqtai"]
        assert resolve("storybook") == ["demo", "storybook_primary"]
        assert resolve("tts") == ["edge_tts"]

    def test_policy_defaults_when_settings_lists_are_empty(self):
        """Empty settings lists fall back to the capability defaults."""
        empty_settings = SimpleNamespace(
            text_providers=[],
            image_providers=[],
            tts_providers=[],
            storybook_providers=[],
            enable_demo_providers=False,
        )

        text_names = get_provider_names_from_settings("text", empty_settings)
        assert text_names == ["gemini", "openai"]

        tts_names = get_provider_names_from_settings("tts", empty_settings)
        assert tts_names == ["minimax", "elevenlabs", "edge_tts"]
class TestProviderConfigFromDB:
@@ -215,10 +306,10 @@ class TestProviderConfigFromDB:
assert config.timeout_ms == 30000
assert config.max_retries == 5
def test_build_config_fallback_to_settings(self):
"""Provider 无 api_key 时回退到 settings。"""
from app.services.provider_router import _build_config_from_provider
def test_build_config_fallback_to_settings(self):
"""Provider 无 api_key 时回退到 settings。"""
from app.services.provider_router import _build_config_from_provider
mock_provider = MagicMock()
mock_provider.adapter = "text_primary"
mock_provider.api_key = None
@@ -234,8 +325,33 @@ class TestProviderConfigFromDB:
mock_settings.text_model = "gemini-2.0-flash"
config = _build_config_from_provider(mock_provider)
assert config.api_key == "settings-api-key"
assert config.api_key == "settings-api-key"
def test_build_config_uses_direct_config_ref_name(self):
    """config_ref may name a settings field directly, which eases admin setup."""
    from app.services.provider_router import _build_config_from_provider

    # Provider row with no inline credentials: everything must come from settings.
    provider_row = MagicMock(
        adapter="antigravity",
        api_key=None,
        api_base=None,
        model=None,
        timeout_ms=None,
        max_retries=None,
        config_ref="antigravity_api_key",
        config_json={},
    )

    with patch("app.services.provider_router.settings") as patched_settings:
        patched_settings.configure_mock(
            antigravity_api_key="antigravity-key",
            antigravity_api_base="https://antigravity.example",
            antigravity_model="gemini-3-pro-image",
        )
        built = _build_config_from_provider(provider_row)

    assert built.api_key == "antigravity-key"
    assert built.api_base == "https://antigravity.example"
    assert built.model == "gemini-3-pro-image"
class TestProviderCacheStartup:

View File

@@ -20,6 +20,9 @@
- `technical/memory-system-dev.md`
记忆系统技术说明。用于后续继续做孩子档案、故事宇宙和个性化生成。
- `technical/provider-routing.md`
Provider Routing 技术说明。用于解释 Capability / Provider / Adapter / Routing Policy 的职责边界。
## 维护规则
- 新 PRD 放到 `docs/product/`

View File

@@ -53,7 +53,7 @@
- 已新增数据库迁移:
- `0009_add_story_generation_statuses.py`
- `0010_add_story_audio_cache_path.py`
- 已完成一轮后端回归验证:`backend/` 下 `pytest -q` 结果为 `66 passed`
- 已完成一轮后端回归验证:`backend/` 下 `pytest -q` 结果为 `71 passed`
- 已完成全量后端 lint 清理:`ruff check app tests` 可通过
- 已修复 admin-frontend 构建阻塞,主前端与管理端前端均可生产构建
- 已落地首版统一资产重试入口:`POST /api/stories/{story_id}/assets/retry`
@@ -70,6 +70,11 @@
- 绘本缺失插图补全
- 故事音频缓存读取与 TTS 生成
- 已引入首版服务层 `AssetCompletionResult`,用于统一表达资产补全结果
- 已完成 Provider 分层首版落地:
- 新增 `provider_policy.py`,定义 Capability / Routing Policy / 默认 Provider 顺序
- Provider Router 专注运行时 failover、熔断、成本和指标记录
- 新增 `/admin/providers/capabilities` 展示能力分层
- 新增 `docs/technical/provider-routing.md` 作为术语表和分层说明
### What Is In Progress
@@ -79,7 +84,6 @@
### What Is Still Pending
- Provider 的 Capability / Provider / Routing Policy 边界整理
- Week 2 可直接执行的开发任务表
- 演示 checklist 与最终收尾策略
@@ -178,7 +182,7 @@
- [ ] 明确 admin-frontend 的处理方案
- [x] 明确 Storybook 恢复方案
- [ ] 明确 Provider 重构边界
- [x] 明确 Provider 重构边界
---
@@ -191,7 +195,7 @@
| W1-03 | Product / System | 盘点现有生成路径:普通故事、完整生成、绘本生成 | 现状流程图或对照表 | P0 | 0.5d | Done |
| W1-04 | Product / System | 定义统一 Generation Workflow 状态模型 | 状态流转说明 | P0 | 1.0d | Done |
| W1-05 | Product / Backend | 定义统一工作流下的 API / 数据结构影响 | 接口与模型变更清单 | P0 | 0.5d | In Progress |
| W1-06 | Product / Backend | 梳理 Provider 概念层:Capability / Provider / Routing Policy | 分层图与术语表 | P1 | 0.5d | Pending |
| W1-06 | Product / Backend | 梳理 Provider 概念层:Capability / Provider / Routing Policy | 分层图与术语表 | P1 | 0.5d | Done |
| W1-07 | Product / Frontend | 梳理 Storybook 当前问题与恢复方案 | 恢复方案说明 | P0 | 0.5d | Done |
| W1-08 | Product / Frontend | 确认 admin 前端是修复、裁剪还是暂时降级 | 决策记录 | P0 | 0.5d | Done |
| W1-09 | Planning | 产出 Week 2 开发任务清单 | 下周 backlog | P1 | 0.5d | In Progress |

View File

@@ -182,10 +182,10 @@ DreamWeaver 是一款面向 3-8 岁亲子场景的个性化 AI 绘本与陪伴
**Acceptance Criteria**
- [ ] Provider 配置以能力、供应商、模型配置的方式组织
- [ ] 路由策略与凭证管理职责分离
- [ ] 系统能清楚展示失败降级逻辑
- [ ] 管理端或配置文档能说明当前有效供应链路
- [x] Provider 配置以能力、供应商、模型配置的方式组织
- [x] 路由策略与凭证管理职责分离
- [x] 系统能清楚展示失败降级逻辑
- [x] 管理端或配置文档能说明当前有效供应链路
---
@@ -443,7 +443,7 @@ DreamWeaver 是一款面向 3-8 岁亲子场景的个性化 AI 绘本与陪伴
### Glossary
- **Generation Workflow**: 从用户输入到文本、图片、语音完成的一整套生成流程。
- **Capability**: 底层 AI 能力分类,如文本、图片、语音。
- **Capability**: 底层 AI 能力分类,如文本、图片、语音、绘本结构。
- **Provider**: 具体供应商,如 Gemini、OpenAI、MiniMax。
- **Routing Policy**: 供应商选择与降级策略。
- **Degraded Completion**: 资产部分失败但主结果可用的完成状态。

View File

@@ -0,0 +1,63 @@
# Provider Routing 技术说明
本说明用于支撑求职版 DreamWeaver 的 Provider 分层表达。当前目标不是做复杂平台化,而是把 AI 能力供应链讲清楚、跑稳定、便于后续演进。
## 核心概念
| 概念 | 含义 | 当前代码位置 |
| --- | --- | --- |
| Capability | 产品需要的 AI 能力类型,例如文本、图片、语音、绘本结构 | `backend/app/services/provider_policy.py` |
| Provider | 某个能力下的可调用供应商配置,例如 Gemini、OpenAI、CQTAI、MiniMax | `providers` 表与 `provider_cache.py` |
| Adapter | 供应商调用实现,负责把统一入参翻译成具体 API 调用 | `backend/app/services/adapters/` |
| Routing Policy | 调用前如何排序与选择 Provider,例如优先级、成本、延迟、轮询 | `provider_policy.py` + `provider_router.py` |
| Failover | 当前 Provider 调用失败后自动尝试下一 Provider | `provider_router.py` |
## 当前 Capability
| Capability | 用途 | 默认 Provider | Demo Provider |
| --- | --- | --- | --- |
| `text` | 生成/润色儿童故事文本 | `gemini`, `openai` | `demo` |
| `image` | 生成封面和绘本插图 | `cqtai` | `demo` |
| `tts` | 故事语音合成 | `minimax`, `elevenlabs`, `edge_tts` | 无 |
| `storybook` | 生成多页绘本结构和插图提示词 | `storybook_primary` | `demo` |
`ENABLE_DEMO_PROVIDERS=true` 时,只会给具备 demo adapter 的能力前置 `demo` provider。TTS 暂无 demo adapter,因此不会插入不存在的 `tts:demo`。
## 代码边界
`provider_policy.py` 负责定义“产品级策略”:
- Capability 清单
- 默认 Provider 顺序
- `.env` 中对应的 provider 列表字段
- 默认 routing strategy
- API key ref 到 settings 字段的映射
- 哪些能力支持本地 demo provider
`provider_router.py` 负责执行“运行时路由”:
- 从 DB cache 或 `.env` 读取 Provider 配置
- 构建 `AdapterConfig`
- 按 Routing Policy 排序
- 熔断过滤
- 调用 adapter
- 记录 metrics、health、cost
- failover 并聚合错误
`adapters/` 负责具体供应商 API:
- 不决定业务工作流
- 不读取用户故事上下文
- 不负责 Provider 排序或熔断
## 演进原则
- 新增 AI 能力时,先在 `provider_policy.py` 增加 Capability,再注册 adapter。
- 新增供应商时,先实现 adapter,再把默认顺序或 DB 配置接入对应 Capability。
- 路由策略只影响调用顺序,不应该改变故事/绘本/音频的产品工作流。
- 本轮求职版不做多租户供应商市场,也不做复杂负载均衡;优先保证能力分层清楚、失败可恢复、演示稳定。
- 后台 `config_ref` 可以使用 adapter 别名,也可以直接使用 settings 字段名,例如 `text_api_key`、`antigravity_api_key`。
## 面试表达口径
DreamWeaver 的 Provider 体系不是把供应商暴露给用户,而是把多模型能力整理成稳定的产品能力。用户看到的是“生成故事、生成封面、播放语音”,系统内部才把它映射到 `text`、`image`、`tts`、`storybook` 这些 Capability,再通过 Routing Policy 选择具体 Provider 和 Adapter。