PAC-Bayes深度网络风险认证
1 引言
随着深度学习模型(尤其是大型基础模型)越来越多地被部署在安全关键场景(自动驾驶、医疗诊断、金融风控)中,为其提供可证明的风险保证成为至关重要的问题。1
传统的PAC-Bayes边界提供的是整体风险的认证——即模型在数据分布上的期望性能保证。但在实际应用中,我们往往需要个体预测级别的认证:对每一个具体的输入 ,给出该预测的可信度上界。
例如:
- “模型对这个肺部CT图像的诊断是恶性的,出错概率不超过 2%”
- “自动驾驶系统在这个帧中检测到行人,误检率不超过 0.1%”
这些个体级别的风险认证无法通过标准PAC-Bayes边界获得,需要新的技术框架。
本文基于WEN DONG(2025/ICLR 2026)的工作,系统介绍基于PAC-Bayes Loss的深度网络个体预测认证方法,通过**局部化先验(Localized Priors)**为每一个输入提供可证明的风险保证。
2 问题定义
2.1 个体风险认证的形式化
给定:
- 神经网络 (或随机化网络 ,)
- 输入
- 认证置信度
目标:找到一个可计算的函数 ,使得:
其中 是真实标签, 是损失函数。
2.2 与传统PAC-Bayes的区别
| 维度 | 标准PAC-Bayes | 个体PAC-Bayes认证 |
|---|---|---|
| 认证对象 | 期望风险 | 个体损失 |
| 认证粒度 | 全局(整个数据分布) | 局部(每个输入 ) |
| 先验类型 | 全局(对所有输入相同) | 局部化(依赖输入 ) |
| 可计算性 | 可计算 | 需要新的技术 |
| 应用场景 | 理论研究 | 安全关键部署 |
2.3 关键挑战
挑战1:局部化先验的构造
对于每个输入 ,如何构造一个局部化的先验 ,使得:
- 在 附近的信息最丰富
- KL散度 仍然可计算
挑战2:个体边界的计算效率
如何高效地为每个输入 计算个体风险上界,而不需要为每个输入训练单独的模型?
挑战3:PAC-Bayes Loss的定义
如何定义一个可计算的个体PAC-Bayes Loss ,使得:
3 局部化先验的理论框架
3.1 局部化先验的定义
定义1(-局部化先验):设 为全局先验,。如果先验 满足:
其中 是以 为中心的局部化核函数:
则称 为 -局部化先验。
3.2 局部化核的设计
WEN DONG(2025)提出了几种局部化核的设计方案:
方案1:基于参数相似度的核
其中 是对输入 最优的参数, 是参数空间的距离度量。
方案2:基于激活模式的核
其中 是网络在第 层的激活向量。
方案3:基于Fisher信息的核
其中 是输入 处的Fisher信息矩阵:
3.3 局部化PAC-Bayes边界
定理1(局部化PAC-Bayes边界):设 为局部化先验, 为后验。对于输入 和标签 ,以至少 的概率:
其中 是 处的经验损失(如果只有一个样本则为真实损失), 是局部化修正项:
这里 是 的邻域假设空间。
4 PAC-Bayes Loss的设计
4.1 标准PAC-Bayes Loss的问题
传统的PAC-Bayes边界无法直接用于个体认证,因为:
- 期望 vs 点估计:标准边界给出 的上界,而非 的上界
- 全局先验:标准先验 对所有输入相同,无法利用输入特定的局部信息
- 分布外输入:对于训练集中未见过的输入,全局先验可能导致边界失效
4.2 PAC-Bayes Loss的定义
定义2(PAC-Bayes Loss):设 为局部化先验, 为后验,。PAC-Bayes Loss定义为:
直观理解:PAC-Bayes Loss是使尾部风险最小的损失阈值。尾部风险通过Markov不等式从KL散度上界控制。
4.3 PAC-Bayes Loss的计算
引理1(Markov上界):对任意非负随机变量 和 :
应用Markov不等式:
结合PAC-Bayes边界:
PAC-Bayes Loss的闭式近似:
4.4 校准的PAC-Bayes Loss
定义3(校准的PAC-Bayes Loss):引入校准因子 :
性质:校准后的PAC-Bayes Loss满足:
5 算法实现
5.1 完整算法流程
import torch
import torch.nn as nn
import numpy as np
from torch.func import grad, vmap
class LocalizedPrior:
"""局部化先验构造器"""
def __init__(self, base_prior_std=1.0, localization_strength=0.5):
self.base_prior_std = base_prior_std
self.lambda_x = localization_strength
def compute_fisher_information(self, model, x, reduction='mean'):
"""计算输入x处的Fisher信息矩阵对角近似"""
def log_likelihood(theta):
logits = model(x)
return torch.log_softmax(logits, dim=-1).sum()
# Fisher信息:I(x) = E[∇log p(x|θ) ∇log p(x|θ)^T]
# 使用对角近似以提高计算效率
params = {n: p for n, p in model.named_parameters()}
grads = torch.autograd.grad(
log_likelihood(params),
[p for p in params.values()],
create_graph=True
)
# 对角Fisher信息
fisher_diag = [g.pow(2).mean() for g in grads]
return fisher_diag
def construct_localized_prior(self, model, x, optimizer_state=None):
"""为输入x构造局部化先验"""
# 计算Fisher信息
fisher_diag = self.compute_fisher_information(model, x)
# 构造局部化先验参数
localized_std = {
name: self.base_prior_std * torch.exp(-0.5 * self.lambda_x * f)
for (name, _), f in zip(model.named_parameters(), fisher_diag)
}
return localized_std
class PACBayesRiskCertifier:
"""PAC-Bayes个体风险认证器"""
def __init__(self, model, prior_std=1.0, n_samples=1000):
self.model = model
self.base_prior_std = prior_std
self.n_samples = n_samples
self.localized_prior_builder = LocalizedPrior(prior_std)
def certify(self, x, y, delta=0.05, calibration_factor=0.8):
"""
为输入x提供个体风险认证
Returns:
loss_bound: 损失上界
confidence: 置信度 1-delta
kl_divergence: KL(Q||P_x)
"""
# Step 1: 构造局部化先验 P_x
localized_std = self.localized_prior_builder.construct_localized_prior(
self.model, x
)
# Step 2: 计算经验损失(Monte Carlo估计)
with torch.no_grad():
logits = self.model(x)
emp_loss = torch.nn.functional.cross_entropy(logits, y)
# Step 3: 估计KL散度(使用参数方差的解析形式)
kl_div = 0.0
for name, param in self.model.named_parameters():
if 'weight' in name or 'bias' in name:
post_var = param.data.var().item()
prior_var = localized_std[name].item() ** 2
kl_div += 0.5 * (
post_var / prior_var
- 1
+ np.log(prior_var / post_var)
)
# Step 4: 计算PAC-Bayes Loss
m = 1 # 单样本
complexity_term = np.sqrt(np.log(1/delta) / (2 * m))
pac_loss = (emp_loss + np.sqrt(kl_div / (2 * m))) / delta
# Step 5: 校准
calibrated_loss = calibration_factor * pac_loss
return {
'loss_bound': calibrated_loss.item(),
'confidence': 1 - delta,
'kl_divergence': kl_div,
'empirical_loss': emp_loss.item()
}
def batch_certify(self, X, Y, delta=0.05):
"""
批量认证:对一组输入提供统一的风险保证
"""
results = []
for i in range(len(X)):
cert = self.certify(X[i:i+1], Y[i:i+1], delta)
results.append(cert)
# 聚合统计
max_loss = max(r['loss_bound'] for r in results)
avg_kl = np.mean([r['kl_divergence'] for r in results])
return {
'individual_results': results,
'batch_loss_bound': max_loss,
'average_kl_divergence': avg_kl,
'confidence': 1 - delta
}
def certify_medical_diagnosis(model, ct_image, diagnosis, delta=0.02):
"""
医疗诊断场景的风险认证示例
"""
certifier = PACBayesRiskCertifier(model, prior_std=0.1, n_samples=500)
result = certifier.certify(ct_image, diagnosis, delta=delta)
print(f"PAC-Bayes Loss (风险上界): {result['loss_bound']:.4f}")
print(f"置信度: {result['confidence']:.2%}")
print(f"KL散度: {result['kl_divergence']:.4f}")
# 如果风险上界低于阈值,则认证通过
RISK_THRESHOLD = 0.05
if result['loss_bound'] < RISK_THRESHOLD:
print(f"✅ 诊断认证通过(风险 < {RISK_THRESHOLD:.2%})")
else:
print(f"⚠️ 诊断风险较高,建议人工复核")
return result5.2 训练局部化后验
为了获得更紧的PAC-Bayes Loss,需要训练局部化的后验:
def train_localized_posterior(model, train_loader, localized_prior_builder,
lr=1e-3, n_epochs=50, lambda_kl=0.01):
"""
训练用于风险认证的局部化后验
损失函数:L = NLL + λ_KL * KL(Q||P_x) + λ_cert * PAC-Loss
"""
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(n_epochs):
total_loss = 0.0
for x, y in train_loader:
optimizer.zero_grad()
# 前向传播
logits = model(x)
nll = torch.nn.functional.cross_entropy(logits, y)
# 局部化先验
localized_std = localized_prior_builder.construct_localized_prior(model, x)
# KL散度项
kl_div = 0.0
for name, param in model.named_parameters():
post_var = param.data.var()
prior_var = localized_std[name].item() ** 2
kl_div += 0.5 * (post_var / prior_var - 1 + np.log(prior_var / post_var))
# 总损失
loss = nll + lambda_kl * kl_div
loss.backward()
optimizer.step()
total_loss += loss.item()
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}")
return model6 实验结果
6.1 个体认证精度
| 数据集 | 方法 | 风险上界 () | 实际风险 | 认证覆盖率 |
|---|---|---|---|---|
| MNIST | 标准PAC-Bayes | 0.45 | 0.023 | 100% |
| MNIST | PAC-Bayes Loss (无校准) | 0.18 | 0.023 | 100% |
| MNIST | PAC-Bayes Loss (校准) | 0.06 | 0.023 | 97% |
| CIFAR-10 | 标准PAC-Bayes | 0.72 | 0.12 | 100% |
| CIFAR-10 | PAC-Bayes Loss (校准) | 0.28 | 0.12 | 94% |
| SVHN | PAC-Bayes Loss (校准) | 0.19 | 0.08 | 96% |
6.2 对抗鲁棒性认证
在对抗样本上的认证效果:
| 对抗攻击 | 方法 | 认证准确率 | 经验准确率 | |
|---|---|---|---|---|
| PGD | 4/255 | 标准PAC-Bayes | 0.31 | 0.58 |
| PGD | 4/255 | PAC-Bayes Loss | 0.52 | 0.58 |
| PGD | 8/255 | PAC-Bayes Loss | 0.41 | 0.47 |
| AutoAttack | - | PAC-Bayes Loss | 0.48 | 0.55 |
6.3 安全关键应用场景
| 应用 | 任务 | 风险阈值 | 认证通过率 | 实际风险 |
|---|---|---|---|---|
| 自动驾驶 | 行人检测 | 1% | 89% | 0.7% |
| 医疗影像 | 肿瘤诊断 | 2% | 82% | 1.4% |
| 金融风控 | 欺诈检测 | 5% | 94% | 3.2% |
| 工业质检 | 缺陷检测 | 0.5% | 76% | 0.3% |
7 与其他认证方法的比较
7.1 对比概述
| 方法类别 | 代表方法 | 认证类型 | 计算成本 | 可扩展性 |
|---|---|---|---|---|
| 随机平滑 | Randomized Smoothing | 类别鲁棒性 | 高 | 中 |
| IBP | CROWN-IBP | 逐样本 | 高 | 低 |
| 线性规划 | SDP认证 | 局部鲁棒性 | 极高 | 低 |
| PAC-Bayes Loss | 本文方法 | 个体风险 | 中 | 高 |
| 置信度校准 | ECE校准 | 概率校准 | 低 | 高 |
7.2 PAC-Bayes Loss的独特优势
- 个体级别认证:不同于鲁棒性认证,PAC-Bayes Loss关注的是预测本身的可靠性,而非对抗扰动下的稳定性
- 可扩展性:不需要对每个输入求解优化问题,直接通过KL散度计算
- 与贝叶斯推断统一:认证过程就是贝叶斯推断过程,无需额外计算
- 灵活的概率解释:结果可以自然地解释为”在概率 下,损失不超过 “
8 局限性与未来方向
8.1 当前局限性
- 边界仍然宽松:尽管优于标准PAC-Bayes,个体认证边界仍然比经验风险大2-10倍
- 局部化核的选择:核函数的设计依赖于领域知识,不同核函数可能导致不同的认证结果
- 假设后验为高斯:对复杂的后验分布(如多峰分布),高斯假设可能过于简化
- 计算成本:Fisher信息的计算在大模型上仍然昂贵
8.2 未来研究方向
- 自适应局部化:根据输入的复杂度动态调整局部化强度
- 层次化PAC-Bayes:为Transformer等层级架构设计多层次的认证
- 与安全验证的结合:将PAC-Bayes Loss与形式化验证方法结合
- 分布式认证:为联邦学习等分布式场景设计认证方法
9 总结
9.1 核心贡献
- 提出了局部化先验的概念,将PAC-Bayes边界从全局扩展到个体级别
- 定义了PAC-Bayes Loss,为深度网络提供个体预测的风险认证
- 通过校准机制,在保证认证覆盖率的同时收紧风险上界
- 展示了在自动驾驶、医疗诊断、金融风控等安全关键场景的应用
9.2 与本 Wiki 其他内容的联系
- 参见 PAC-Bayes边界理论 获取基础框架
- 参见 认证鲁棒性理论 了解对抗鲁棒性认证方法
- 参见 BNN不确定性量化 了解贝叶斯网络的不确定性估计
Footnotes
-
Dong, W. (2025/2026). “Certifying Deep Network Risks and Individual Predictions with PAC-Bayes Loss via Localized Priors.” OpenReview, submitted to ICLR 2026. ↩