Some checks are pending
Build and Push Docker Images / changes (push) Waiting to run
Build and Push Docker Images / build-backend (push) Blocked by required conditions
Build and Push Docker Images / build-frontend (push) Blocked by required conditions
Build and Push Docker Images / build-admin-frontend (push) Blocked by required conditions
110 lines
3.5 KiB
Python
110 lines
3.5 KiB
Python
"""Redis-backed cache for providers loaded from DB."""
|
|
|
|
import json
|
|
from collections import defaultdict
|
|
from typing import Literal
|
|
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
from app.core.redis import get_redis
|
|
from app.db.admin_models import Provider
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
ProviderType = Literal["text", "image", "tts", "storybook"]
|
|
|
|
|
|
class CachedProvider(BaseModel):
    """Serializable provider configuration matching DB model fields.

    Mirrors the columns of ``app.db.admin_models.Provider`` so rows can be
    round-tripped through JSON (Redis) without holding ORM objects.
    """

    id: str  # DB primary key
    name: str  # human-readable provider name
    type: str  # provider category; expected to be one of ProviderType values
    adapter: str  # adapter implementation identifier used to dispatch calls
    model: str | None = None  # upstream model name, if the adapter needs one
    api_base: str | None = None  # override base URL for the upstream API
    api_key: str | None = None  # credential; serialized into Redis as-is
    timeout_ms: int = 60000  # per-request timeout in milliseconds
    max_retries: int = 1  # retry budget for failed upstream calls
    weight: int = 1  # tie-breaker within the same priority (higher first)
    priority: int = 0  # primary sort key (higher first)
    enabled: bool = True  # only enabled rows are loaded, so normally True here
    config_json: dict | None = None  # free-form adapter-specific settings
    config_ref: str | None = None  # external reference to a config blob
|
|
|
# Local memory fallback (L1 cache): process-local copy of the last successful
# reload, consulted when Redis is unreachable.
_local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list)

# Redis key holding the full provider configuration, grouped by provider type.
CACHE_KEY = "dreamweaver:providers:config"
|
|
|
|
|
|
async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]:
    """Reload enabled providers from the DB and refresh both cache layers.

    Loads all enabled ``Provider`` rows, converts them to ``CachedProvider``
    models, groups them by provider type and sorts each group by
    ``(priority, weight)`` descending. The local in-memory L1 cache is
    updated *before* the Redis write so the in-process fallback reflects
    fresh DB data even when Redis is unavailable.

    Args:
        db: Async SQLAlchemy session used to query the ``Provider`` table.

    Returns:
        Mapping of provider type to its sorted list of providers.

    Raises:
        Exception: re-raised (after logging) if the DB query, serialization,
            or the Redis write fails.
    """
    try:
        result = await db.execute(select(Provider).where(Provider.enabled == True))  # noqa: E712
        providers = result.scalars().all()

        # Convert ORM rows to serializable Pydantic models.
        cached_list = [
            CachedProvider(
                id=p.id,
                name=p.name,
                type=p.type,
                adapter=p.adapter,
                model=p.model,
                api_base=p.api_base,
                api_key=p.api_key,
                timeout_ms=p.timeout_ms,
                max_retries=p.max_retries,
                weight=p.weight,
                priority=p.priority,
                enabled=p.enabled,
                config_json=p.config_json,
                config_ref=p.config_ref,
            )
            for p in providers
        ]

        # Group by provider type.
        grouped: dict[ProviderType, list[CachedProvider]] = defaultdict(list)
        for cp in cached_list:
            grouped[cp.type].append(cp)

        # Highest priority first; weight breaks ties within a priority.
        for group in grouped.values():
            group.sort(key=lambda x: (x.priority, x.weight), reverse=True)

        # Refresh the local L1 cache FIRST: if the Redis write below fails,
        # the in-process fallback still serves the freshly loaded data.
        _local_cache.clear()
        _local_cache.update(grouped)

        # Serialize the whole grouped structure (Pydantic -> dict -> JSON)
        # into a single Redis key shared by all workers.
        json_data = {k: [p.model_dump() for p in v] for k, v in grouped.items()}
        redis = await get_redis()
        await redis.set(CACHE_KEY, json.dumps(json_data))

        return grouped

    except Exception as e:
        logger.error("failed_to_reload_providers", error=str(e))
        raise
|
|
|
|
|
|
async def get_providers(provider_type: ProviderType) -> list[CachedProvider]:
    """Return cached providers of the given type.

    Reads the shared Redis cache first. A populated Redis cache is
    authoritative: a type absent from it means no enabled providers of
    that type exist. On a Redis *miss* (key not set, e.g. cache never
    warmed or flushed) or any Redis error, falls back to the
    process-local L1 cache populated by ``reload_providers``.

    Args:
        provider_type: One of ``"text"``, ``"image"``, ``"tts"``,
            ``"storybook"``.

    Returns:
        Providers of that type, in the order stored by
        ``reload_providers`` (empty list when none are configured).
    """
    try:
        redis = await get_redis()
        data = await redis.get(CACHE_KEY)
        if data:
            raw_dict = json.loads(data)
            # Cache hit is authoritative, even for an absent type.
            return [CachedProvider(**item) for item in raw_dict.get(provider_type, [])]
    except Exception as e:
        # Best-effort: a Redis outage must not break provider lookup.
        logger.warning("redis_cache_read_failed", error=str(e))

    # Redis miss or failure: serve from local memory (L1 fallback).
    return _local_cache.get(provider_type, [])