"""In-memory cache for providers loaded from DB.""" from collections import defaultdict from typing import Literal from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.db.admin_models import Provider ProviderType = Literal["text", "image", "tts", "storybook"] _cache: dict[ProviderType, list[Provider]] = defaultdict(list) async def reload_providers(db: AsyncSession): result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712 providers = result.scalars().all() grouped: dict[ProviderType, list[Provider]] = defaultdict(list) for p in providers: grouped[p.type].append(p) # sort by priority desc, then weight desc for k in grouped: grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True) _cache.clear() _cache.update(grouped) return _cache def get_providers(provider_type: ProviderType) -> list[Provider]: return _cache.get(provider_type, [])