"""Encrypted storage service for provider secrets.

Values are encrypted with Fernet (symmetric encryption); the Fernet key is
derived from ``settings.secret_key``.
"""

import base64
import hashlib
from typing import TYPE_CHECKING

from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.core.logging import get_logger
from app.db.admin_models import ProviderSecret

if TYPE_CHECKING:
    pass

logger = get_logger(__name__)


class SecretEncryptionError(Exception):
    """Error encrypting/decrypting a secret (raised by ``SecretService.decrypt``)."""


class SecretService:
    """Encrypted storage service for provider secrets."""

    # Lazily-built Fernet instance, cached on the class (see _get_fernet).
    _fernet: Fernet | None = None

    @classmethod
    def _get_fernet(cls) -> Fernet:
        """Return the cached Fernet instance, creating it on first use.

        The 32-byte Fernet key is the SHA-256 digest of ``settings.secret_key``,
        url-safe base64 encoded as Fernet requires. Because the instance is
        cached on the class, a later change to ``settings.secret_key`` within
        the same process is not picked up.
        """
        if cls._fernet is None:
            # NOTE: plain SHA-256 derivation (no salt / KDF). Changing it would
            # make every previously stored ciphertext undecryptable.
            key_bytes = hashlib.sha256(settings.secret_key.encode()).digest()
            fernet_key = base64.urlsafe_b64encode(key_bytes)
            cls._fernet = Fernet(fernet_key)
        return cls._fernet

    @classmethod
    def encrypt(cls, plaintext: str) -> str:
        """Encrypt *plaintext* and return the Fernet token as a string.

        Args:
            plaintext: Text to encrypt.

        Returns:
            The url-safe base64 Fernet token, or ``""`` when *plaintext* is
            empty (empty values are passed through unencrypted).
        """
        if not plaintext:
            return ""
        return cls._get_fernet().encrypt(plaintext.encode()).decode()

    @classmethod
    def decrypt(cls, ciphertext: str) -> str:
        """Decrypt a Fernet token produced by :meth:`encrypt`.

        Args:
            ciphertext: url-safe base64 Fernet token.

        Returns:
            The decrypted text, or ``""`` when *ciphertext* is empty.

        Raises:
            SecretEncryptionError: *ciphertext* is not a valid token for the
                current key (e.g. ``settings.secret_key`` changed, or the value
                was never encrypted). The failure is also logged as an error.
        """
        if not ciphertext:
            return ""
        try:
            decrypted = cls._get_fernet().decrypt(ciphertext.encode())
        except InvalidToken as e:
            logger.error("secret_decrypt_failed", error=str(e))
            raise SecretEncryptionError("密钥解密失败,可能是 SECRET_KEY 已更改") from e
        return decrypted.decode()

    @classmethod
    def _try_decrypt(cls, value: str) -> str | None:
        """Decrypt *value* if it is a valid Fernet token, else return None.

        Unlike :meth:`decrypt`, this neither logs nor raises on an invalid
        token, because callers use it to probe values that may legitimately be
        plaintext.
        """
        try:
            return cls._get_fernet().decrypt(value.encode()).decode()
        except InvalidToken:
            return None

    @classmethod
    async def get_secret(cls, db: AsyncSession, name: str) -> str | None:
        """Load the secret called *name* from the database and decrypt it.

        Args:
            db: Database session.
            name: Secret name.

        Returns:
            The decrypted value, or None if no secret with that name exists.

        Raises:
            SecretEncryptionError: The stored value cannot be decrypted.
        """
        result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
        secret = result.scalar_one_or_none()
        if secret is None:
            return None
        return cls.decrypt(secret.encrypted_value)

    @classmethod
    async def set_secret(cls, db: AsyncSession, name: str, value: str) -> ProviderSecret:
        """Create or update the secret called *name*, storing *value* encrypted.

        Commits the session and refreshes the returned row.

        Args:
            db: Database session.
            name: Secret name.
            value: Plaintext secret value.

        Returns:
            The persisted ProviderSecret instance.
        """
        encrypted = cls.encrypt(value)
        result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
        secret = result.scalar_one_or_none()
        if secret is None:
            secret = ProviderSecret(name=name, encrypted_value=encrypted)
            db.add(secret)
        else:
            secret.encrypted_value = encrypted
        await db.commit()
        await db.refresh(secret)
        logger.info("secret_stored", name=name)
        return secret

    @classmethod
    async def delete_secret(cls, db: AsyncSession, name: str) -> bool:
        """Delete the secret called *name* and commit.

        Args:
            db: Database session.
            name: Secret name.

        Returns:
            True if a secret was deleted, False if none existed.
        """
        result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
        secret = result.scalar_one_or_none()
        if secret is None:
            return False
        await db.delete(secret)
        await db.commit()
        logger.info("secret_deleted", name=name)
        return True

    @classmethod
    async def list_secrets(cls, db: AsyncSession) -> list[str]:
        """Return the names of all stored secrets (never their values).

        Args:
            db: Database session.

        Returns:
            List of secret names.
        """
        result = await db.execute(select(ProviderSecret.name))
        return list(result.scalars().all())

    @classmethod
    async def get_api_key(
        cls,
        db: AsyncSession,
        provider_api_key: str | None,
        config_ref: str | None,
    ) -> str | None:
        """Resolve a provider's API key, trying sources in priority order.

        Priority:
            1. ``provider_api_key`` (Provider.api_key), either a Fernet token
               or plaintext.
            2. The ProviderSecret row whose name is ``config_ref``.
            3. The attribute of ``settings`` named ``config_ref.lower()``.

        Args:
            db: Database session.
            provider_api_key: The Provider row's ``api_key`` column.
            config_ref: The Provider row's ``config_ref`` column.

        Returns:
            The first non-empty value found, or None.

        Raises:
            SecretEncryptionError: The ProviderSecret referenced in step 2
                exists but cannot be decrypted.
        """
        # 1. Provider.api_key wins. It may hold a Fernet token or plaintext;
        #    a value that does not decrypt is used as-is. Plaintext is an
        #    expected case here, so no error is logged for it.
        if provider_api_key:
            return cls._try_decrypt(provider_api_key) or provider_api_key

        # Steps 2 and 3 both need a reference name.
        if not config_ref:
            return None

        # 2. A ProviderSecret row named config_ref.
        secret_value = await cls.get_secret(db, config_ref)
        if secret_value:
            return secret_value

        # 3. A settings attribute named config_ref.lower() (intended as the
        #    environment-variable fallback; presumably settings is loaded from
        #    the environment — confirm in app.core.config).
        # NOTE(review): any settings attribute can be selected this way,
        # including non-credential ones such as secret_key; confirm config_ref
        # is only writable by trusted admins.
        env_value = getattr(settings, config_ref.lower(), None)
        if env_value:
            return env_value
        return None