"""Redis-backed cache for providers loaded from DB.""" import json from collections import defaultdict from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.logging import get_logger from app.core.redis import get_redis from app.db.admin_models import Provider from app.services.provider_policy import ProviderType logger = get_logger(__name__) class CachedProvider(BaseModel): """Serializable provider configuration matching DB model fields.""" id: str name: str type: str adapter: str model: str | None = None api_base: str | None = None api_key: str | None = None timeout_ms: int = 60000 max_retries: int = 1 weight: int = 1 priority: int = 0 enabled: bool = True config_json: dict | None = None config_ref: str | None = None # Local memory fallback (L1 cache) _local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list) CACHE_KEY = "dreamweaver:providers:config" async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]: """Reload providers from DB and update Redis cache.""" try: result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712 providers = result.scalars().all() # Convert to Pydantic models cached_list = [] for p in providers: cached_list.append( CachedProvider( id=p.id, name=p.name, type=p.type, adapter=p.adapter, model=p.model, api_base=p.api_base, api_key=p.api_key, timeout_ms=p.timeout_ms, max_retries=p.max_retries, weight=p.weight, priority=p.priority, enabled=p.enabled, config_json=p.config_json, config_ref=p.config_ref, ) ) # Group by type grouped: dict[str, list[CachedProvider]] = defaultdict(list) for cp in cached_list: grouped[cp.type].append(cp) # Sort for k in grouped: grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True) # Update Redis redis = await get_redis() # Serialize entire dict structure # Pydantic -> dict -> json json_data = {k: [p.model_dump() for p in v] for k, v in grouped.items()} await redis.set(CACHE_KEY, json.dumps(json_data)) # Update local cache _local_cache.clear() _local_cache.update(grouped) return grouped except Exception as e: logger.error("failed_to_reload_providers", error=str(e)) raise async def get_providers(provider_type: ProviderType) -> list[CachedProvider]: """Get providers from Redis (preferred) or local fallback.""" try: redis = await get_redis() data = await redis.get(CACHE_KEY) if data: raw_dict = json.loads(data) if provider_type in raw_dict: return [CachedProvider(**item) for item in raw_dict[provider_type]] return [] except Exception as e: logger.warning("redis_cache_read_failed", error=str(e)) # Fallback to local memory return _local_cache.get(provider_type, [])