Some checks are pending
Build and Push Docker Images / changes (push) Waiting to run
Build and Push Docker Images / build-backend (push) Blocked by required conditions
Build and Push Docker Images / build-frontend (push) Blocked by required conditions
Build and Push Docker Images / build-admin-frontend (push) Blocked by required conditions
208 lines
6.0 KiB
Python
208 lines
6.0 KiB
Python
"""供应商密钥加密存储服务。
|
||
|
||
使用 Fernet 对称加密,密钥从 SECRET_KEY 派生。
|
||
"""
|
||
|
||
import base64
|
||
import hashlib
|
||
from typing import TYPE_CHECKING
|
||
|
||
from cryptography.fernet import Fernet, InvalidToken
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.config import settings
|
||
from app.core.logging import get_logger
|
||
from app.db.admin_models import ProviderSecret
|
||
|
||
if TYPE_CHECKING:
    # No type-only imports are currently needed; the empty guard is kept so the
    # `TYPE_CHECKING` import at the top of the file remains referenced.
    pass

# Module-level logger, obtained from the project's `get_logger` helper and
# keyed by this module's name.
logger = get_logger(__name__)
|
||
|
||
|
||
class SecretEncryptionError(Exception):
    """Raised when a secret cannot be encrypted or decrypted."""
|
||
|
||
|
||
class SecretService:
    """Encrypted storage service for provider secrets.

    Values are encrypted with Fernet using a key derived from
    ``settings.secret_key`` and persisted as ``ProviderSecret`` rows. Every
    method is a classmethod; the Fernet instance is cached on the class.
    """

    # Lazily created Fernet instance, cached on the class by `_get_fernet`.
    _fernet: Fernet | None = None

    @classmethod
    def _get_fernet(cls) -> Fernet:
        """Return the cached Fernet instance, creating it on first use.

        The Fernet key is the URL-safe base64 encoding of the SHA-256 digest
        of ``settings.secret_key``.
        """
        if cls._fernet is None:
            # SHA-256 produces the 32 raw bytes that a Fernet key is built from.
            key_bytes = hashlib.sha256(settings.secret_key.encode()).digest()
            fernet_key = base64.urlsafe_b64encode(key_bytes)
            cls._fernet = Fernet(fernet_key)
        return cls._fernet

    @classmethod
    def encrypt(cls, plaintext: str) -> str:
        """Encrypt ``plaintext`` and return the Fernet token as text.

        Args:
            plaintext: The value to encrypt.

        Returns:
            The encrypted token decoded to ``str``, or ``""`` when
            ``plaintext`` is empty (empty input is never encrypted).
        """
        if not plaintext:
            return ""
        token = cls._get_fernet().encrypt(plaintext.encode())
        return token.decode()

    @classmethod
    def decrypt(cls, ciphertext: str) -> str:
        """Decrypt a Fernet token and return the plaintext.

        Args:
            ciphertext: The token previously produced by ``encrypt``.

        Returns:
            The decrypted plaintext, or ``""`` when ``ciphertext`` is empty.

        Raises:
            SecretEncryptionError: If the token is rejected by Fernet
                (``InvalidToken``), e.g. because SECRET_KEY has changed.
        """
        if not ciphertext:
            return ""
        try:
            fernet = cls._get_fernet()
            decrypted = fernet.decrypt(ciphertext.encode())
            return decrypted.decode()
        except InvalidToken as e:
            logger.error("secret_decrypt_failed", error=str(e))
            raise SecretEncryptionError("密钥解密失败,可能是 SECRET_KEY 已更改") from e

    @classmethod
    async def get_secret(cls, db: AsyncSession, name: str) -> str | None:
        """Load the secret called ``name`` from the database and decrypt it.

        Args:
            db: Database session.
            name: Secret name.

        Returns:
            The decrypted value, or ``None`` when no such secret exists.

        Raises:
            SecretEncryptionError: If the stored value cannot be decrypted.
        """
        result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
        secret = result.scalar_one_or_none()
        if secret is None:
            return None
        return cls.decrypt(secret.encrypted_value)

    @classmethod
    async def set_secret(cls, db: AsyncSession, name: str, value: str) -> ProviderSecret:
        """Create or update the encrypted secret called ``name``.

        The session is committed and the row refreshed before returning.

        Args:
            db: Database session.
            name: Secret name.
            value: Plaintext value; an empty value is stored as ``""``
                because ``encrypt`` does not encrypt empty input.

        Returns:
            The persisted ``ProviderSecret`` instance.
        """
        encrypted = cls.encrypt(value)

        result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
        secret = result.scalar_one_or_none()

        if secret is None:
            secret = ProviderSecret(name=name, encrypted_value=encrypted)
            db.add(secret)
        else:
            secret.encrypted_value = encrypted

        await db.commit()
        await db.refresh(secret)
        # Only the name is logged; the value never reaches the log.
        logger.info("secret_stored", name=name)
        return secret

    @classmethod
    async def delete_secret(cls, db: AsyncSession, name: str) -> bool:
        """Delete the secret called ``name`` and commit the session.

        Args:
            db: Database session.
            name: Secret name.

        Returns:
            ``True`` if a secret was deleted, ``False`` if none existed.
        """
        result = await db.execute(select(ProviderSecret).where(ProviderSecret.name == name))
        secret = result.scalar_one_or_none()
        if secret is None:
            return False

        await db.delete(secret)
        await db.commit()
        logger.info("secret_deleted", name=name)
        return True

    @classmethod
    async def list_secrets(cls, db: AsyncSession) -> list[str]:
        """List the names of all stored secrets (values are not returned).

        Args:
            db: Database session.

        Returns:
            The secret names.
        """
        result = await db.execute(select(ProviderSecret.name))
        # A single-column select: scalars() yields the names directly.
        return list(result.scalars().all())

    @classmethod
    async def get_api_key(
        cls,
        db: AsyncSession,
        provider_api_key: str | None,
        config_ref: str | None,
    ) -> str | None:
        """Resolve a provider's API key, trying sources in priority order.

        Priority:
            1. ``provider_api_key`` itself (encrypted token or plaintext).
            2. The ``ProviderSecret`` row named by ``config_ref``.
            3. The attribute of ``settings`` named ``config_ref.lower()``
               (presumably populated from environment variables — confirm
               against the settings class).

        Args:
            db: Database session.
            provider_api_key: The ``api_key`` field of the provider record.
            config_ref: The ``config_ref`` field of the provider record.

        Returns:
            The API key, or ``None`` when no source yields a value.

        Raises:
            SecretEncryptionError: If the ``ProviderSecret`` found in step 2
                cannot be decrypted.
        """
        # 1. Use provider.api_key directly when it is set.
        if provider_api_key:
            # Try to decrypt; on failure the value is deliberately treated as
            # plaintext, so the error is swallowed here.
            # NOTE(review): decrypt() logs "secret_decrypt_failed" at error
            # level on every plaintext key that passes through this path.
            try:
                decrypted = cls.decrypt(provider_api_key)
                if decrypted:
                    return decrypted
            except SecretEncryptionError:
                pass
            return provider_api_key

        # Steps 2 and 3 both need a reference name. Bail out early so that a
        # missing config_ref can never reach `config_ref.lower()` below.
        if not config_ref:
            return None

        # 2. Look the reference up in the ProviderSecret table.
        secret_value = await cls.get_secret(db, config_ref)
        if secret_value:
            return secret_value

        # 3. Fall back to the settings attribute named after the reference.
        # NOTE(review): any settings attribute can be read this way, including
        # `secret_key`; consider restricting the allowed names.
        env_value = getattr(settings, config_ref.lower(), None)
        if env_value:
            return env_value

        return None
|