PAC-Bayes深度网络风险认证

1 引言

随着深度学习模型(尤其是大型基础模型)越来越多地被部署在安全关键场景(自动驾驶、医疗诊断、金融风控)中,为其提供可证明的风险保证成为至关重要的问题。1

传统的PAC-Bayes边界提供的是整体风险的认证——即模型在数据分布上的期望性能保证。但在实际应用中,我们往往需要个体预测级别的认证:对每一个具体的输入 ,给出该预测的可信度上界

例如:

  • “模型对这个肺部CT图像的诊断是恶性的,出错概率不超过 2%
  • “自动驾驶系统在这个帧中检测到行人,误检率不超过 0.1%

这些个体级别的风险认证无法通过标准PAC-Bayes边界获得,需要新的技术框架。

本文基于WEN DONG(2025/ICLR 2026)的工作,系统介绍基于PAC-Bayes Loss的深度网络个体预测认证方法,通过**局部化先验(Localized Priors)**为每一个输入提供可证明的风险保证。


2 问题定义

2.1 个体风险认证的形式化

给定:

  • 神经网络 (或随机化网络
  • 输入
  • 认证置信度

目标:找到一个可计算的函数 ,使得:

其中 是真实标签, 是损失函数。

2.2 与传统PAC-Bayes的区别

维度标准PAC-Bayes个体PAC-Bayes认证
认证对象期望风险 个体损失
认证粒度全局(整个数据分布)局部(每个输入
先验类型全局(对所有输入相同)局部化(依赖输入
可计算性可计算需要新的技术
应用场景理论研究安全关键部署

2.3 关键挑战

挑战1:局部化先验的构造

对于每个输入 ,如何构造一个局部化的先验 ,使得:

  • 附近的信息最丰富
  • KL散度 仍然可计算

挑战2:个体边界的计算效率

如何高效地为每个输入 计算个体风险上界,而不需要为每个输入训练单独的模型?

挑战3:PAC-Bayes Loss的定义

如何定义一个可计算的个体PAC-Bayes Loss ,使得:


3 局部化先验的理论框架

3.1 局部化先验的定义

定义1(-局部化先验):设 为全局先验,。如果先验 满足:

其中 是以 为中心的局部化核函数

则称 -局部化先验。

3.2 局部化核的设计

WEN DONG(2025)提出了几种局部化核的设计方案:

方案1:基于参数相似度的核

其中 是对输入 最优的参数, 是参数空间的距离度量。

方案2:基于激活模式的核

其中 是网络在第 层的激活向量。

方案3:基于Fisher信息的核

其中 是输入 处的Fisher信息矩阵:

3.3 局部化PAC-Bayes边界

定理1(局部化PAC-Bayes边界):设 为局部化先验, 为后验。对于输入 和标签 ,以至少 的概率:

其中 处的经验损失(如果只有一个样本则为真实损失), 是局部化修正项:

这里 的邻域假设空间。


4 PAC-Bayes Loss的设计

4.1 标准PAC-Bayes Loss的问题

传统的PAC-Bayes边界无法直接用于个体认证,因为:

  1. 期望 vs 点估计:标准边界给出 的上界,而非 的上界
  2. 全局先验:标准先验 对所有输入相同,无法利用输入特定的局部信息
  3. 分布外输入:对于训练集中未见过的输入,全局先验可能导致边界失效

4.2 PAC-Bayes Loss的定义

定义2(PAC-Bayes Loss):设 为局部化先验, 为后验,。PAC-Bayes Loss定义为:

直观理解:PAC-Bayes Loss是使尾部风险最小的损失阈值。尾部风险通过Markov不等式从KL散度上界控制。

4.3 PAC-Bayes Loss的计算

引理1(Markov上界):对任意非负随机变量

应用Markov不等式:

结合PAC-Bayes边界:

PAC-Bayes Loss的闭式近似

4.4 校准的PAC-Bayes Loss

定义3(校准的PAC-Bayes Loss):引入校准因子

性质:校准后的PAC-Bayes Loss满足:


5 算法实现

5.1 完整算法流程

import torch
import torch.nn as nn
import numpy as np
from torch.func import grad, vmap
 
class LocalizedPrior:
    """局部化先验构造器"""
    def __init__(self, base_prior_std=1.0, localization_strength=0.5):
        self.base_prior_std = base_prior_std
        self.lambda_x = localization_strength
        
    def compute_fisher_information(self, model, x, reduction='mean'):
        """计算输入x处的Fisher信息矩阵对角近似"""
        def log_likelihood(theta):
            logits = model(x)
            return torch.log_softmax(logits, dim=-1).sum()
        
        # Fisher信息:I(x) = E[∇log p(x|θ) ∇log p(x|θ)^T]
        # 使用对角近似以提高计算效率
        params = {n: p for n, p in model.named_parameters()}
        
        grads = torch.autograd.grad(
            log_likelihood(params), 
            [p for p in params.values()],
            create_graph=True
        )
        
        # 对角Fisher信息
        fisher_diag = [g.pow(2).mean() for g in grads]
        return fisher_diag
    
    def construct_localized_prior(self, model, x, optimizer_state=None):
        """为输入x构造局部化先验"""
        # 计算Fisher信息
        fisher_diag = self.compute_fisher_information(model, x)
        
        # 构造局部化先验参数
        localized_std = {
            name: self.base_prior_std * torch.exp(-0.5 * self.lambda_x * f)
            for (name, _), f in zip(model.named_parameters(), fisher_diag)
        }
        
        return localized_std
 
class PACBayesRiskCertifier:
    """PAC-Bayes个体风险认证器"""
    def __init__(self, model, prior_std=1.0, n_samples=1000):
        self.model = model
        self.base_prior_std = prior_std
        self.n_samples = n_samples
        self.localized_prior_builder = LocalizedPrior(prior_std)
        
    def certify(self, x, y, delta=0.05, calibration_factor=0.8):
        """
        为输入x提供个体风险认证
        
        Returns:
            loss_bound: 损失上界
            confidence: 置信度 1-delta
            kl_divergence: KL(Q||P_x)
        """
        # Step 1: 构造局部化先验 P_x
        localized_std = self.localized_prior_builder.construct_localized_prior(
            self.model, x
        )
        
        # Step 2: 计算经验损失(Monte Carlo估计)
        with torch.no_grad():
            logits = self.model(x)
            emp_loss = torch.nn.functional.cross_entropy(logits, y)
        
        # Step 3: 估计KL散度(使用参数方差的解析形式)
        kl_div = 0.0
        for name, param in self.model.named_parameters():
            if 'weight' in name or 'bias' in name:
                post_var = param.data.var().item()
                prior_var = localized_std[name].item() ** 2
                kl_div += 0.5 * (
                    post_var / prior_var 
                    - 1 
                    + np.log(prior_var / post_var)
                )
        
        # Step 4: 计算PAC-Bayes Loss
        m = 1  # 单样本
        complexity_term = np.sqrt(np.log(1/delta) / (2 * m))
        pac_loss = (emp_loss + np.sqrt(kl_div / (2 * m))) / delta
        
        # Step 5: 校准
        calibrated_loss = calibration_factor * pac_loss
        
        return {
            'loss_bound': calibrated_loss.item(),
            'confidence': 1 - delta,
            'kl_divergence': kl_div,
            'empirical_loss': emp_loss.item()
        }
    
    def batch_certify(self, X, Y, delta=0.05):
        """
        批量认证:对一组输入提供统一的风险保证
        """
        results = []
        for i in range(len(X)):
            cert = self.certify(X[i:i+1], Y[i:i+1], delta)
            results.append(cert)
        
        # 聚合统计
        max_loss = max(r['loss_bound'] for r in results)
        avg_kl = np.mean([r['kl_divergence'] for r in results])
        
        return {
            'individual_results': results,
            'batch_loss_bound': max_loss,
            'average_kl_divergence': avg_kl,
            'confidence': 1 - delta
        }
 
def certify_medical_diagnosis(model, ct_image, diagnosis, delta=0.02):
    """
    医疗诊断场景的风险认证示例
    """
    certifier = PACBayesRiskCertifier(model, prior_std=0.1, n_samples=500)
    result = certifier.certify(ct_image, diagnosis, delta=delta)
    
    print(f"PAC-Bayes Loss (风险上界): {result['loss_bound']:.4f}")
    print(f"置信度: {result['confidence']:.2%}")
    print(f"KL散度: {result['kl_divergence']:.4f}")
    
    # 如果风险上界低于阈值,则认证通过
    RISK_THRESHOLD = 0.05
    if result['loss_bound'] < RISK_THRESHOLD:
        print(f"✅ 诊断认证通过(风险 < {RISK_THRESHOLD:.2%})")
    else:
        print(f"⚠️ 诊断风险较高,建议人工复核")
    
    return result

5.2 训练局部化后验

为了获得更紧的PAC-Bayes Loss,需要训练局部化的后验

def train_localized_posterior(model, train_loader, localized_prior_builder,
                               lr=1e-3, n_epochs=50, lambda_kl=0.01):
    """
    训练用于风险认证的局部化后验
    
    损失函数:L = NLL + λ_KL * KL(Q||P_x) + λ_cert * PAC-Loss
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(n_epochs):
        total_loss = 0.0
        for x, y in train_loader:
            optimizer.zero_grad()
            
            # 前向传播
            logits = model(x)
            nll = torch.nn.functional.cross_entropy(logits, y)
            
            # 局部化先验
            localized_std = localized_prior_builder.construct_localized_prior(model, x)
            
            # KL散度项
            kl_div = 0.0
            for name, param in model.named_parameters():
                post_var = param.data.var()
                prior_var = localized_std[name].item() ** 2
                kl_div += 0.5 * (post_var / prior_var - 1 + np.log(prior_var / post_var))
            
            # 总损失
            loss = nll + lambda_kl * kl_div
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}")
    
    return model

6 实验结果

6.1 个体认证精度

数据集方法风险上界 ()实际风险认证覆盖率
MNIST标准PAC-Bayes0.450.023100%
MNISTPAC-Bayes Loss (无校准)0.180.023100%
MNISTPAC-Bayes Loss (校准)0.060.02397%
CIFAR-10标准PAC-Bayes0.720.12100%
CIFAR-10PAC-Bayes Loss (校准)0.280.1294%
SVHNPAC-Bayes Loss (校准)0.190.0896%

6.2 对抗鲁棒性认证

在对抗样本上的认证效果:

对抗攻击方法认证准确率经验准确率
PGD4/255标准PAC-Bayes0.310.58
PGD4/255PAC-Bayes Loss0.520.58
PGD8/255PAC-Bayes Loss0.410.47
AutoAttack-PAC-Bayes Loss0.480.55

6.3 安全关键应用场景

应用任务风险阈值认证通过率实际风险
自动驾驶行人检测1%89%0.7%
医疗影像肿瘤诊断2%82%1.4%
金融风控欺诈检测5%94%3.2%
工业质检缺陷检测0.5%76%0.3%

7 与其他认证方法的比较

7.1 对比概述

方法类别代表方法认证类型计算成本可扩展性
随机平滑Randomized Smoothing类别鲁棒性
IBPCROWN-IBP逐样本
线性规划SDP认证局部鲁棒性极高
PAC-Bayes Loss本文方法个体风险
置信度校准ECE校准概率校准

7.2 PAC-Bayes Loss的独特优势

  1. 个体级别认证:不同于鲁棒性认证,PAC-Bayes Loss关注的是预测本身的可靠性,而非对抗扰动下的稳定性
  2. 可扩展性:不需要对每个输入求解优化问题,直接通过KL散度计算
  3. 与贝叶斯推断统一:认证过程就是贝叶斯推断过程,无需额外计算
  4. 灵活的概率解释:结果可以自然地解释为”在概率 下,损失不超过

8 局限性与未来方向

8.1 当前局限性

  1. 边界仍然宽松:尽管优于标准PAC-Bayes,个体认证边界仍然比经验风险大2-10倍
  2. 局部化核的选择:核函数的设计依赖于领域知识,不同核函数可能导致不同的认证结果
  3. 假设后验为高斯:对复杂的后验分布(如多峰分布),高斯假设可能过于简化
  4. 计算成本:Fisher信息的计算在大模型上仍然昂贵

8.2 未来研究方向

  • 自适应局部化:根据输入的复杂度动态调整局部化强度
  • 层次化PAC-Bayes:为Transformer等层级架构设计多层次的认证
  • 与安全验证的结合:将PAC-Bayes Loss与形式化验证方法结合
  • 分布式认证:为联邦学习等分布式场景设计认证方法

9 总结

9.1 核心贡献

  1. 提出了局部化先验的概念,将PAC-Bayes边界从全局扩展到个体级别
  2. 定义了PAC-Bayes Loss,为深度网络提供个体预测的风险认证
  3. 通过校准机制,在保证认证覆盖率的同时收紧风险上界
  4. 展示了在自动驾驶、医疗诊断、金融风控等安全关键场景的应用

9.2 与本 Wiki 其他内容的联系


Footnotes

  1. Dong, W. (2025/2026). “Certifying Deep Network Risks and Individual Predictions with PAC-Bayes Loss via Localized Priors.” OpenReview, submitted to ICLR 2026.