Deep VIB:变分信息瓶颈实战
Deep VIB(Deep Variational Information Bottleneck)由 Alemi 等人于 2017 年提出,将信息瓶颈理论应用于深度学习,通过变分近似实现可扩展的端到端训练。1
核心思想回顾
信息瓶颈目标
其中:
- :输入与表示的互信息(压缩)
- :标签与表示的互信息(信息保留)
- :权衡参数
Deep VIB 解决方案
由于 和 难以直接计算,Deep VIB 使用变分近似:
完整 PyTorch 实现
基础模块
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from typing import Tuple, Optional
import matplotlib.pyplot as plt
class VIBModule(nn.Module):
"""
深度变分信息瓶颈模块
Architecture:
┌─────────────┐ ┌──────────────┐ ┌───────────────┐
│ Encoder │────▶│ Reparameterize│────▶│ Classifier │
│ q(z|μ,σ) │ │ z=μ+σε │ │ q(y|z) │
└─────────────┘ └──────────────┘ └───────────────┘
│ │
│ │
▼ ▼
┌──────────────┐ ┌───────────────┐
│ KL Div(z||r) │ │ Prior r(z) │
│ (正则项) │ │ N(0,I) │
└──────────────┘ └───────────────┘
Loss = CrossEntropy + β * KL
"""
def __init__(
self,
input_dim: int,
latent_dim: int,
num_classes: int,
hidden_dims: list = [1024, 512, 256],
beta: float = 1e-3,
use_batch_norm: bool = True
):
super().__init__()
self.latent_dim = latent_dim
self.beta = beta
# ==================== 编码器 ====================
encoder_layers = []
prev_dim = input_dim
for h_dim in hidden_dims:
encoder_layers.append(nn.Linear(prev_dim, h_dim))
if use_batch_norm:
encoder_layers.append(nn.BatchNorm1d(h_dim))
encoder_layers.append(nn.ReLU())
prev_dim = h_dim
# 最后输出均值和对数方差
encoder_layers.append(nn.Linear(prev_dim, 2 * latent_dim))
self.encoder = nn.Sequential(*encoder_layers)
# ==================== 分类器 ====================
self.classifier = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, num_classes)
)
# 先验分布参数
self.register_buffer('prior_mean', torch.zeros(latent_dim))
self.register_buffer('prior_log_var', torch.zeros(latent_dim))
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
"""
重参数化技巧
使得梯度可以通过随机采样反向传播:
z = μ + σ * ε, where ε ~ N(0, I)
"""
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""编码:返回均值和对数方差"""
h = self.encoder(x)
mu, log_var = h.chunk(2, dim=-1)
return mu, log_var
def decode(self, z: torch.Tensor) -> torch.Tensor:
"""解码:返回分类 logits"""
return self.classifier(z)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
前向传播
Returns:
logits: 分类 logits
mu: 均值
log_var: 对数方差
z: 重参数化后的隐变量
"""
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
logits = self.decode(z)
return logits, mu, log_var, z
def kl_divergence(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
"""
计算 KL(q(z|x) || r(z))
对于 q = N(μ, σ²) 和 r = N(0, I):
D_KL = 0.5 * (σ² + μ² - 1 - log(σ²))
"""
# D_KL(N(μ,σ) || N(0,I))
kl = 0.5 * torch.sum(
log_var.exp() + mu.pow(2) - 1 - log_var,
dim=-1
)
return kl.mean()
def loss(
self,
x: torch.Tensor,
y: torch.Tensor,
reduction: str = 'mean'
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
VIB 损失函数
L = E[-log q(y|z)] + β * D_KL(q(z|x) || r(z))
第一项:分类损失(最大化 I(Z;Y))
第二项:KL 正则项(最小化 I(Z;X))
"""
logits, mu, log_var, z = self.forward(x)
# 交叉熵分类损失
ce_loss = F.cross_entropy(logits, y, reduction=reduction)
# KL 散度正则项
kl_loss = self.kl_divergence(mu, log_var)
# 总损失
total_loss = ce_loss + self.beta * kl_loss
return total_loss, ce_loss, kl_loss
def get_mutual_info_bounds(
self,
x: torch.Tensor,
y: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
估算互信息的变分界
I(Z;Y) >= E_z~q(z|x)[log q(y|z)] + H(Y) (下界)
I(Z;X) <= D_KL(q(z|x) || r(z)) (上界)
"""
with torch.no_grad():
logits, mu, log_var, z = self.forward(x)
# I(Z;Y) 的下界
log_probs = F.log_softmax(logits, dim=-1)
i_zy_lower = torch.gather(log_probs, 1, y.unsqueeze(1)).mean()
# I(Z;X) 的上界
i_zx_upper = self.kl_divergence(mu, log_var)
return i_zy_lower, i_zx_upper训练器
class VIBTrainer:
"""VIB 模型训练器"""
def __init__(
self,
model: VIBModule,
optimizer: torch.optim.Optimizer,
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
):
self.model = model.to(device)
self.optimizer = optimizer
self.device = device
self.history = {
'train_loss': [],
'train_ce': [],
'train_kl': [],
'train_acc': [],
'i_zy': [],
'i_zx': [],
}
def train_epoch(self, dataloader: DataLoader) -> dict:
"""训练一个 epoch"""
self.model.train()
total_loss = 0
total_ce = 0
total_kl = 0
correct = 0
total = 0
i_zy_sum = 0
i_zx_sum = 0
for batch_x, batch_y in dataloader:
batch_x = batch_x.to(self.device)
batch_y = batch_y.to(self.device)
# 前向传播
loss, ce_loss, kl_loss = self.model.loss(batch_x, batch_y)
# 反向传播
self.optimizer.zero_grad()
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
# 统计
total_loss += loss.item() * batch_x.size(0)
total_ce += ce_loss.item() * batch_x.size(0)
total_kl += kl_loss.item() * batch_x.size(0)
_, predicted = self.model.decode(
self.model.reparameterize(*self.model.encode(batch_x))
).max(1)
correct += predicted.eq(batch_y).sum().item()
total += batch_y.size(0)
# 互信息估计
i_zy, i_zx = self.model.get_mutual_info_bounds(batch_x, batch_y)
i_zy_sum += i_zy.item()
i_zx_sum += i_zx.item()
n = len(dataloader.dataset)
metrics = {
'loss': total_loss / n,
'ce': total_ce / n,
'kl': total_kl / n,
'acc': correct / total,
'i_zy': i_zy_sum / len(dataloader),
'i_zx': i_zx_sum / len(dataloader),
}
for k, v in metrics.items():
self.history[k].append(v)
return metrics
def fit(
self,
train_loader: DataLoader,
epochs: int,
val_loader: Optional[DataLoader] = None
) -> dict:
"""完整训练流程"""
best_acc = 0
for epoch in range(epochs):
train_metrics = self.train_epoch(train_loader)
msg = f"Epoch {epoch+1}/{epochs} | "
msg += f"Loss: {train_metrics['loss']:.4f} | "
msg += f"CE: {train_metrics['ce']:.4f} | "
msg += f"KL: {train_metrics['kl']:.4f} | "
msg += f"Acc: {train_metrics['acc']:.4f}"
if val_loader:
val_metrics = self.evaluate(val_loader)
msg += f" | Val Acc: {val_metrics['acc']:.4f}"
if val_metrics['acc'] > best_acc:
best_acc = val_metrics['acc']
print(msg)
return self.history
def evaluate(self, dataloader: DataLoader) -> dict:
"""评估模型"""
self.model.eval()
correct = 0
total = 0
total_loss = 0
with torch.no_grad():
for batch_x, batch_y in dataloader:
batch_x = batch_x.to(self.device)
batch_y = batch_y.to(self.device)
logits, mu, log_var, z = self.model(batch_x)
loss, _, _ = self.model.loss(batch_x, batch_y, reduction='sum')
total_loss += loss.item()
_, predicted = logits.max(1)
correct += predicted.eq(batch_y).sum().item()
total += batch_y.size(0)
return {
'loss': total_loss / total,
'acc': correct / total
}
def plot_information_plane(self, save_path: Optional[str] = None):
"""绘制信息平面"""
i_zx = self.history['i_zx']
i_zy = self.history['i_zy']
plt.figure(figsize=(10, 6))
plt.scatter(i_zx, i_zy, c=range(len(i_zx)), cmap='viridis', s=50)
plt.colorbar(label='Epoch')
plt.xlabel('$I(Z; X)$ (Compressed)', fontsize=12)
plt.ylabel('$I(Z; Y)$ (Relevant)', fontsize=12)
plt.title('Information Plane Trajectory', fontsize=14)
plt.grid(True, alpha=0.3)
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.show()使用示例
# 准备数据
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# 加载数据
X, y = load_breast_cancer(return_X_y=True)
X = StandardScaler().fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# 转换为 Tensor
X_train_t = torch.FloatTensor(X_train)
y_train_t = torch.LongTensor(y_train)
X_test_t = torch.FloatTensor(X_test)
y_test_t = torch.LongTensor(y_test)
# 创建数据加载器
train_dataset = TensorDataset(X_train_t, y_train_t)
test_dataset = TensorDataset(X_test_t, y_test_t)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)
# 创建模型
model = VIBModule(
input_dim=30,
latent_dim=16,
num_classes=2,
hidden_dims=[256, 128],
beta=1e-3 # 信息瓶颈权衡参数
)
# 创建优化器
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 训练
trainer = VIBTrainer(model, optimizer)
history = trainer.fit(train_loader, epochs=100, val_loader=test_loader)
# 可视化信息平面
trainer.plot_information_plane()对抗鲁棒性分析
原理
VIB 的对抗鲁棒性源于其对 的限制:1
- 限制输入敏感性:通过 KL 正则项, 被限制在先验附近
- 对抗扰动被过滤:微小扰动 对 的影响被压缩
- 信息瓶颈效应:只有与 高度相关的信息才能通过瓶颈
形式化分析
对于输入扰动 :
VIB 通过限制 间接限制了 ,使得对抗扰动难以影响表示 。
实验对比
def attack_fgsm(model, x, y, epsilon=0.03):
"""FGSM 对抗攻击"""
x_adv = x.clone().detach().requires_grad_(True)
output = model.decode(
model.reparameterize(*model.encode(x_adv))
)
loss = F.cross_entropy(output, y)
loss.backward()
# 生成对抗样本
with torch.no_grad():
x_adv = x_adv + epsilon * x_adv.grad.sign()
x_adv = torch.clamp(x_adv, x - epsilon, x + epsilon)
return x_adv
def evaluate_adversarial(model, loader, device, epsilon=0.03):
"""评估对抗鲁棒性"""
model.eval()
correct_clean = 0
correct_adv = 0
total = 0
for x, y in loader:
x, y = x.to(device), y.to(device)
# 干净样本准确率
with torch.no_grad():
logits = model.decode(
model.reparameterize(*model.encode(x))
)
correct_clean += (logits.argmax(1) == y).sum().item()
# 对抗样本准确率
x_adv = attack_fgsm(model, x, y, epsilon)
with torch.no_grad():
logits_adv = model.decode(
model.reparameterize(*model.encode(x_adv))
)
correct_adv += (logits_adv.argmax(1) == y).sum().item()
total += y.size(0)
return {
'clean_acc': correct_clean / total,
'adv_acc': correct_adv / total
}
# 对比实验
results = {
'Standard': evaluate_adversarial(standard_model, test_loader, device),
'VIB (β=0.001)': evaluate_adversarial(vib_model, test_loader, device),
'VIB (β=0.01)': evaluate_adversarial(vib_model_strong, test_loader, device),
}
print("对抗鲁棒性对比 (ε=0.03):")
for name, metrics in results.items():
print(f"{name}: 干净准确率={metrics['clean_acc']:.4f}, 对抗准确率={metrics['adv_acc']:.4f}")典型结果
| 模型 | 干净准确率 | 对抗准确率(FGSM) | 提升 |
|---|---|---|---|
| Standard CNN | 98.5% | 43.2% | - |
| VIB (β=10⁻³) | 97.8% | 61.4% | +18.2% |
| VIB (β=10⁻²) | 96.5% | 71.3% | +28.1% |
β 调度策略
固定 β 的问题
- 太小:压缩不足,过拟合风险
- 太大:压缩过度,信息丢失
渐进式 β 调度
class VIBWithBetaSchedule(VIBModule):
"""带 β 调度的 VIB"""
def __init__(self, *args, warmup_epochs=10, max_beta=1e-2, **kwargs):
super().__init__(*args, beta=0, **kwargs) # 初始 β=0
self.warmup_epochs = warmup_epochs
self.max_beta = max_beta
def get_beta(self, epoch):
"""线性 warmup + 常数"""
if epoch < self.warmup_epochs:
return self.max_beta * epoch / self.warmup_epochs
return self.max_beta
def loss(self, x, y, epoch=0):
# 动态调整 beta
self.beta = self.get_beta(epoch)
return super().loss(x, y)KL annealing
def train_with_kl_annealing(model, loader, epochs):
"""KL annealing 训练策略"""
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(epochs):
# annealing 进度 (0 -> 1)
kl_weight = min(1.0, epoch / (epochs * 0.5))
for x, y in loader:
loss, ce, kl = model.loss(x, y)
# 加权的总损失
total_loss = ce + kl_weight * model.beta * kl
optimizer.zero_grad()
total_loss.backward()
optimizer.step()实战技巧
1. β 选择
| β 值 | 适用场景 |
|---|---|
| 10⁻⁴ ~ 10⁻³ | 任务难度高,需要更多信息 |
| 10⁻³ ~ 10⁻² | 标准设置,平衡性能与鲁棒性 |
| 10⁻² ~ 10⁻¹ | 高鲁棒性需求,对抗场景 |
2. 隐变量维度
隐变量维度 影响表示容量:
- 太小:信息瓶颈过强,无法捕获任务相关信息
- 太大:压缩不足,过拟合风险
建议从 开始尝试。
3. 先验选择
# 标准高斯先验
r(z) = N(0, I)
# 混合高斯先验(鼓励解耦)
r(z) = 0.5 * N(-1, 0.5²) + 0.5 * N(1, 0.5²)
# 均匀先验(最大熵)
r(z) = Uniform(-a, a)核心公式速查
| 概念 | 公式 |
|---|---|
| VIB 目标 | |
| KL 散度 | |
| 重参数化 | |
| I(Z;Y) 下界 | |
| I(Z;X) 上界 |