"""Redis-backed cache for providers loaded from DB.""" import json from collections import defaultdict from typing import Literal from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.logging import get_logger from app.core.redis import get_redis from app.db.admin_models import Provider logger = get_logger(__name__) ProviderType = Literal["text", "image", "tts", "storybook"] class CachedProvider(BaseModel): """Serializable provider configuration matching DB model fields.""" id: str name: str type: str adapter: str model: str | None = None api_base: str | None = None api_key: str | None = None timeout_ms: int = 60000 max_retries: int = 1 weight: int = 1 priority: int = 0 enabled: bool = True config_json: dict | None = None config_ref: str | None = None # Local memory fallback (L1 cache) _local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list) CACHE_KEY = "dreamweaver:providers:config" async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]: """Reload providers from DB and update Redis cache.""" try: result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712 providers = result.scalars().all() # Convert to Pydantic models cached_list = [] for p in providers: cached_list.append(CachedProvider( id=p.id, name=p.name, type=p.type, adapter=p.adapter, model=p.model, api_base=p.api_base, api_key=p.api_key, timeout_ms=p.timeout_ms, max_retries=p.max_retries, weight=p.weight, priority=p.priority, enabled=p.enabled, config_json=p.config_json, config_ref=p.config_ref )) # Group by type grouped: dict[str, list[CachedProvider]] = defaultdict(list) for cp in cached_list: grouped[cp.type].append(cp) # Sort for k in grouped: grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True) # Update Redis redis = await get_redis() # Serialize entire dict structure # Pydantic -> dict -> json json_data = {k: [p.model_dump() for p in v] for k, v in grouped.items()} await redis.set(CACHE_KEY, json.dumps(json_data)) # Update local cache _local_cache.clear() _local_cache.update(grouped) return grouped except Exception as e: logger.error("failed_to_reload_providers", error=str(e)) raise async def get_providers(provider_type: ProviderType) -> list[CachedProvider]: """Get providers from Redis (preferred) or local fallback.""" try: redis = await get_redis() data = await redis.get(CACHE_KEY) if data: raw_dict = json.loads(data) if provider_type in raw_dict: return [CachedProvider(**item) for item in raw_dict[provider_type]] return [] except Exception as e: logger.warning("redis_cache_read_failed", error=str(e)) # Fallback to local memory return _local_cache.get(provider_type, [])