# Source: dreamweaver/backend/app/services/provider_cache.py
# (file-viewer header: 111 lines, 3.4 KiB, Python)

"""Redis-backed cache for providers loaded from DB."""
import json
from collections import defaultdict
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.db.admin_models import Provider
from app.services.provider_policy import ProviderType
# Module-level logger; calls in this module pass an event name plus keyword fields.
logger = get_logger(__name__)
class CachedProvider(BaseModel):
    """Serializable provider configuration matching DB model fields.

    Instances are built field-by-field from ``Provider`` rows in
    ``reload_providers``, dumped to JSON for the Redis cache, and rebuilt
    from that JSON in ``get_providers``.
    """

    id: str
    name: str
    # Grouping key: reload_providers buckets providers by this value.
    type: str
    adapter: str
    model: str | None = None
    api_base: str | None = None
    # NOTE(review): this value is included as-is in the JSON payload written
    # to Redis -- confirm that storing it there is acceptable.
    api_key: str | None = None
    # Presumably milliseconds, going by the field name -- TODO confirm.
    timeout_ms: int = 60000
    max_retries: int = 1
    # weight and priority are the sort keys: each per-type list is ordered by
    # (priority, weight), highest first.
    weight: int = 1
    priority: int = 0
    enabled: bool = True
    config_json: dict | None = None
    config_ref: str | None = None
# Local memory fallback (L1 cache), keyed by provider type. It is cleared and
# repopulated by reload_providers and read by get_providers when the Redis
# lookup does not produce a result.
_local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list)
# Single Redis key holding the JSON-encoded {type: [provider, ...]} mapping.
CACHE_KEY = "dreamweaver:providers:config"
async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]:
    """Reload enabled providers from the DB and refresh both cache layers.

    Enabled ``Provider`` rows are converted to ``CachedProvider`` models,
    grouped by ``type`` and ordered by ``(priority, weight)`` descending.
    The in-process cache is refreshed first; the grouped mapping is then
    written to Redis as one JSON document under ``CACHE_KEY``.

    Args:
        db: Async SQLAlchemy session used to query the ``Provider`` table.

    Returns:
        Mapping of provider type to its ordered list of providers.

    Raises:
        Exception: Any error from the DB query, model construction or the
            Redis write is logged and re-raised unchanged. If only the Redis
            write fails, the in-process cache has already been refreshed.
    """
    try:
        result = await db.execute(select(Provider).where(Provider.enabled == True))  # noqa: E712
        rows = result.scalars().all()

        # Convert each row to a Pydantic model and bucket it by type in one pass.
        grouped: dict[str, list[CachedProvider]] = defaultdict(list)
        for row in rows:
            cached = CachedProvider(
                id=row.id,
                name=row.name,
                type=row.type,
                adapter=row.adapter,
                model=row.model,
                api_base=row.api_base,
                api_key=row.api_key,
                timeout_ms=row.timeout_ms,
                max_retries=row.max_retries,
                weight=row.weight,
                priority=row.priority,
                enabled=row.enabled,
                config_json=row.config_json,
                config_ref=row.config_ref,
            )
            grouped[cached.type].append(cached)

        # Highest priority first; weight breaks ties (sort is stable).
        for bucket in grouped.values():
            bucket.sort(key=lambda cp: (cp.priority, cp.weight), reverse=True)

        # Refresh the local fallback BEFORE touching Redis: the fallback only
        # matters when Redis is unavailable, so it must not depend on the
        # Redis write succeeding. Lists are copied so that callers mutating
        # the returned mapping cannot alter the cache. There is no await
        # between clear() and update(), so readers never see a partial state.
        _local_cache.clear()
        _local_cache.update({ptype: list(items) for ptype, items in grouped.items()})

        # Serialize the whole mapping: Pydantic -> JSON-safe dict -> JSON text.
        payload = {
            ptype: [cp.model_dump(mode="json") for cp in items]
            for ptype, items in grouped.items()
        }
        redis = await get_redis()
        await redis.set(CACHE_KEY, json.dumps(payload))

        return grouped
    except Exception as e:
        logger.error("failed_to_reload_providers", error=str(e))
        raise
async def get_providers(provider_type: ProviderType) -> list[CachedProvider]:
    """Get providers from Redis (preferred) or the local in-process fallback.

    Args:
        provider_type: Provider type whose configured providers are wanted.

    Returns:
        The providers cached for ``provider_type``. When Redis holds a
        payload, the answer comes from it (an empty list if the type is not
        present). When Redis has no payload, or reading/decoding it fails,
        a copy of the local fallback list is returned (empty if the type is
        unknown there too).
    """
    try:
        redis = await get_redis()
        data = await redis.get(CACHE_KEY)
        if data:
            raw_dict = json.loads(data)
            # A payload exists, so Redis is authoritative: a type missing
            # from it simply has no providers.
            return [CachedProvider(**item) for item in raw_dict.get(provider_type, [])]
    except Exception as e:
        # Best effort: any Redis/decoding/validation failure degrades to the
        # local cache instead of propagating to the caller.
        logger.warning("redis_cache_read_failed", error=str(e))

    # Fallback to local memory. Return a copy so callers cannot mutate the
    # cached list in place; .get() avoids creating an entry in the defaultdict.
    return list(_local_cache.get(provider_type, []))