信息瓶颈变体与表示学习
信息瓶颈(Information Bottleneck, IB)理论自1999年提出以来,已经发展出一个丰富的变体家族,用于解决不同场景下的表示学习问题。1本文系统梳理各类IB变体的数学形式、核心思想和实践方法,揭示其与表示学习的深层联系。
预备知识
基本设定
考虑随机变量三元组 ,满足 Markov 链 ,即:
- 完全由 决定()
- 给定 时, 与 条件独立
信息平面
以 为坐标的二维平面:
I(Y;T)
↑
│ · · · · · · IB曲线 · · · · ·
│ · ·
│ · ·
│ · ·
│· ·
│ ·
└──────────────────────────────────────→ I(X;T)
更多信息请参考 信息瓶颈理论。
1. 原始信息瓶颈(Original IB)
1.1 目标函数
原始IB的核心思想是找到压缩且任务相关的表示 。形式化为约束优化:
使用拉格朗日乘子法转化为无约束形式:
其中 控制压缩与信息保留之间的权衡:
| 值 | 行为 |
|---|---|
| 只关注信息保留, 保留所有 的信息 | |
| 只关注压缩,完全忽略 的信息 |
1.2 IB曲线的Pareto最优性
IB曲线(Information Bottleneck Curve)是 Pareto 最优前沿:
Pareto 最优性证明:
对于任意两个可行解 和 ,如果:
则两者都在 Pareto 前沿上,无法同时改进两个目标。
1.3 自洽方程推导
使用变分法求解IB优化问题。引入辅助分布 (类似变分后验),构建拉格朗日函数:
展开互信息:
对 施加变分,得到 自洽方程(Self-Consistent Equation):
其中 是归一化常数:
直观的解释:对于输入 ,编码器将高概率分配给:
- 先验 较大的区域
- 预测分布 与 接近的区域(KL散度小)
1.4 固定点迭代
自洽方程可通过迭代求解:
def ib_fixed_point_iteration(p_xy, beta, n_iter=100):
"""
原始IB的固定点迭代算法
Args:
p_xy: 联合分布 p(x,y)
beta: 拉格朗日乘子
n_iter: 迭代次数
Returns:
p_t_given_x: 编码分布 p(t|x)
"""
n_x = p_xy.shape[0]
n_y = p_xy.shape[1]
n_t = n_x # 潜在空间大小
# 初始化
p_t = np.ones(n_t) / n_t
p_y_given_t = np.ones((n_t, n_y)) / n_y
for _ in range(n_iter):
# E步:更新 p(t|x)
p_t_given_x = np.zeros((n_x, n_t))
for x in range(n_x):
kl_div = kl_divergence(p_xy[x] / p_xy[x].sum(), p_y_given_t) # D_KL(p(y|x)||p(y|t))
unnorm = p_t * np.exp(-beta * kl_div)
p_t_given_x[x] = unnorm / unnorm.sum()
# M步:更新 p(y|t) 和 p(t)
p_ty = p_xy.T @ p_t_given_x # p(y,t)
p_t = p_ty.sum(axis=0)
p_y_given_t = p_ty / (p_t[:, np.newaxis] + 1e-10)
p_t /= p_t.sum()
return p_t_given_x2. 变分信息瓶颈(Variational IB, VIB)
2.1 变分近似
原始IB的难点在于:
- 和 难以精确计算
- 自洽方程难以解析求解
- 高维连续空间中需要变分近似
VIB 使用变分推断技术,将IB目标转化为可优化的下界。
2.2 目标函数推导
原始目标(最大化):
变分下界推导:
第一步,引入变分分布 近似真实后验 :
第二步,引入先验分布 近似 :
综合得到 变分下界(Variational Lower Bound):
在实现中通常简化为:
2.3 与VAE的联系
VIB 与 变分自编码器(VAE) 有着深刻联系:
| 组件 | VIB | VAE |
|---|---|---|
| 隐变量后验 | ||
| 先验分布 | ||
| 重建损失 | ||
| 正则项 |
关键区别:
- VAE 关注重建输入
- VIB 关注预测标签
数学上,VAE 的 ELBO:
如果令 ,则 VIB 与 VAE 在形式上统一。
2.4 PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import kl_divergence, Normal
class VariationalInformationBottleneck(nn.Module):
"""
变分信息瓶颈(VIB)模块
目标函数:
L = E[-log q(y|z)] + beta * D_KL(q(z|x) || r(z))
其中:
- q(z|x): 变分编码器(近似 p(z|x))
- r(z): 先验分布(通常为 N(0, I))
- q(y|z): 分类器(近似 p(y|z))
"""
def __init__(self, input_dim, latent_dim, num_classes, beta=1e-3):
super().__init__()
self.latent_dim = latent_dim
self.beta = beta
self.num_classes = num_classes
# 变分编码器:输出高斯分布参数
self.encoder = nn.Sequential(
nn.Linear(input_dim, 512),
nn.ReLU(),
nn.BatchNorm1d(512),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 2 * latent_dim) # [mean, log_var]
)
# 先验分布(标准高斯)
self.register_buffer('prior_mean', torch.zeros(latent_dim))
self.register_buffer('prior_log_var', torch.zeros(latent_dim))
# 分类器
self.classifier = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(128, num_classes)
)
def reparameterize(self, mean, log_var):
"""重参数化技巧:z = mu + sigma * epsilon"""
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mean + eps * std
def kl_divergence(self, mean, log_var):
"""
计算变分后验与先验的 KL 散度
D_KL(N(mu, sigma) || N(0, I))
= 0.5 * (sigma^2 + mu^2 - 1 - log(sigma^2))
"""
prior = Normal(self.prior_mean, torch.ones_like(log_var).exp())
posterior = Normal(mean, log_var.exp().sqrt())
return kl_divergence(posterior, prior).sum(dim=-1).mean()
def forward(self, x, training=True):
"""
前向传播
Args:
x: 输入数据 (batch_size, input_dim)
training: 是否在训练模式
Returns:
logits: 分类 logits (batch_size, num_classes)
mean: 潜在变量均值
log_var: 潜在变量对数方差
z: 重参数化后的潜在变量
"""
# 编码
h = self.encoder(x)
mean, log_var = h.chunk(2, dim=-1)
# 重参数化采样
if training:
z = self.reparameterize(mean, log_var)
else:
z = mean # 推理时使用均值
# 分类
logits = self.classifier(z)
return logits, mean, log_var, z
def loss(self, x, y):
"""
计算 VIB 损失
L = E[-log q(y|z)] + beta * D_KL(q(z|x) || r(z))
Returns:
total_loss: 总损失
ce_loss: 交叉熵损失(信息保留项)
kl_loss: KL 散度(压缩项)
"""
logits, mean, log_var, z = self.forward(x, training=True)
# 交叉熵损失(最大化 I(Z;Y))
ce_loss = F.cross_entropy(logits, y, reduction='mean')
# KL 散度正则(最小化 I(Z;X))
kl_loss = self.kl_divergence(mean, log_var)
# 总损失
total_loss = ce_loss + self.beta * kl_loss
return total_loss, ce_loss, kl_loss
def estimate_mutual_information(self, x, y):
"""
估计信息平面坐标
I(Z; Y) >= E[log q(y|z)] + H(Y)
I(Z; X) <= D_KL(q(z|x) || r(z))
"""
with torch.no_grad():
logits, mean, log_var, z = self.forward(x, training=False)
# I(Z;Y) 的下界
log_probs = F.log_softmax(logits, dim=-1)
i_zy = torch.gather(log_probs, 1, y.unsqueeze(1)).mean()
# I(Z;X) 的上界
i_zx = self.kl_divergence(mean, log_var)
return i_zy, i_zx
class VIBResNet(nn.Module):
"""
基于 ResNet 的 VIB 模型(可用于图像分类)
将标准 ResNet 的最后层替换为 VIB 模块
"""
def __init__(self, backbone, latent_dim, num_classes, beta=1e-3):
super().__init__()
self.backbone = backbone
self.vib = VariationalInformationBottleneck(
input_dim=backbone.output_dim,
latent_dim=latent_dim,
num_classes=num_classes,
beta=beta
)
def forward(self, x, training=True):
features = self.backbone(x)
return self.vib(features, training=training)2.5 使用示例
# 训练循环
def train_vib(model, train_loader, optimizer, device):
model.train()
total_loss = 0.0
total_ce = 0.0
total_kl = 0.0
for batch_x, batch_y in train_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
optimizer.zero_grad()
loss, ce_loss, kl_loss = model.loss(batch_x, batch_y)
loss.backward()
optimizer.step()
total_loss += loss.item()
total_ce += ce_loss.item()
total_kl += kl_loss.item()
n_batches = len(train_loader)
return total_loss / n_batches, total_ce / n_batches, total_kl / n_batches
# 信息平面可视化
def plot_information_plane(model, data_loader, device):
"""可视化训练过程中的信息平面轨迹"""
import matplotlib.pyplot as plt
i_zx_list, i_zy_list = [], []
model.eval()
for batch_x, batch_y in data_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
i_zy, i_zx = model.estimate_mutual_information(batch_x, batch_y)
i_zx_list.append(i_zx.item())
i_zy_list.append(i_zy.item())
plt.figure(figsize=(8, 6))
plt.scatter(i_zx_list, i_zy_list, alpha=0.5)
plt.xlabel('$I(X; Z)$')
plt.ylabel('$I(Z; Y)$')
plt.title('Information Plane')
plt.grid(True)
plt.show()3. 条件信息瓶颈(Conditional IB, CIB)
3.1 问题背景
在许多实际场景中,我们不仅关心 ,还关心 条件 变量 的影响。例如:
- 领域适应: 表示源/目标域
- 对抗鲁棒性: 表示对抗扰动
- 因果推断: 表示混杂因素
3.2 条件互信息的优化
条件信息瓶颈的目标是最大化在给定 条件下 与 的互信息,同时最小化 与 的互信息:
其中条件互信息定义为:
I(Y; T \mid C) = \mathbb{E}_{p(c)}\left[I(Y; T \mid C=c)\right] = \mathbb{E}_{p(c)}\left[\mathbb{E}_{p(x,y \mid c)}}\left[\log \frac{p(y \mid t, c)}{p(y \mid c)}\right]\right]3.3 任务相关表示学习
条件IB的核心思想是学习任务相关的表示:
- 任务无关信息: 中与 无关的部分,应该被压缩
- 任务相关但域无关: 中与 相关但与 无关的部分,应该被保留
- 虚假相关: 中与 和 都相关的部分(虚假相关),应该被识别和处理
class ConditionalInformationBottleneck(nn.Module):
"""
条件信息瓶颈
学习在给定条件 C 下对 Y 有用的表示
"""
def __init__(self, input_dim, latent_dim, num_classes, num_domains, beta=1e-3):
super().__init__()
self.beta = beta
# 条件编码器:输出依赖输入和条件
self.encoder = nn.Sequential(
nn.Linear(input_dim + num_domains, 256), # 输入拼接条件
nn.ReLU(),
nn.Linear(256, 2 * latent_dim)
)
# 域无关分类器
self.classifier = nn.Linear(latent_dim, num_classes)
# 域判别器(用于对抗训练)
self.domain_discriminator = nn.Linear(latent_dim, num_domains)
def forward(self, x, c):
"""
Args:
x: 输入特征
c: 条件/域标签 (one-hot 或 index)
"""
# 拼接输入和条件
xc = torch.cat([x, c], dim=-1)
# 编码
h = self.encoder(xc)
mean, log_var = h.chunk(2, dim=-1)
# 重参数化
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
z = mean + eps * std
# 分类
logits = self.classifier(z)
domain_logits = self.domain_discriminator(z)
return logits, domain_logits, mean, log_var, z
def loss(self, x, y, c, alpha=0.5):
"""
条件IB损失
L = E[-log q(y|z)] + beta * D_KL(q(z|x,c) || r(z))
- alpha * E[log q(c|z)]
第三项:对抗损失,鼓励表示与条件解耦
"""
logits, domain_logits, mean, log_var, z = self.forward(x, c)
# 分类损失
ce_loss = F.cross_entropy(logits, y)
# KL 正则
prior = Normal(0, 1)
posterior = Normal(mean, log_var.exp().sqrt())
kl_loss = kl_divergence(posterior, prior).sum(dim=-1).mean()
# 对抗损失(最大化域预测误差 => 最小化 I(Z;C))
domain_loss = F.cross_entropy(domain_logits, c.argmax(dim=-1))
# 总损失
total_loss = ce_loss + self.beta * kl_loss - alpha * domain_loss
return total_loss, ce_loss, kl_loss, domain_loss4. 对比信息瓶颈(Contrastive IB, CIB)
4.1 与SimCLR、MoCo的联系
对比学习方法(如 InfoNCE)可以理解为一种特殊的IB变体。考虑对比学习中的数据增强:
对于表示 和 :
- 正样本对 应该相似 最大化
- 负样本对 应该不同 最小化虚假相关
4.2 正负样本对的信息论分析
假设增强分布 满足:
- :正样本分布
- :负样本分布(与 无关)
InfoNCE 损失的IB解释:
当温度 时,这近似于互信息的下界:
4.3 的IB解释
从IB视角,InfoNCE 同时优化两个目标:
| 目标 | InfoNCE 实现 | IB 对应 |
|---|---|---|
| 信息保留 | 最大化正样本对的相似度 | 最大化 |
| 压缩 | 使用归一化 + temperature | 限制 的增长 |
统一的理论框架:
其中第二项对应于负样本的对比正则。
4.4 PyTorch实现
class ContrastiveInformationBottleneck(nn.Module):
"""
对比信息瓶颈
将 InfoNCE 损失解释为 IB 目标的变分近似
"""
def __init__(self, encoder_dim, latent_dim, temperature=0.07, beta=1e-3):
super().__init__()
self.temperature = temperature
self.beta = beta
# 编码器
self.encoder = encoder
# 投影头
self.projection = nn.Sequential(
nn.Linear(encoder_dim, latent_dim),
nn.ReLU(),
nn.Linear(latent_dim, latent_dim)
)
def info_nce_loss(self, z_i, z_j):
"""
InfoNCE 损失
L = -log exp(s(z_i, z_j)/tau) / sum_k exp(s(z_i, z_k)/tau)
"""
batch_size = z_i.shape[0]
# 归一化
z_i = F.normalize(z_i, dim=1)
z_j = F.normalize(z_j, dim=1)
# 拼接所有表示
z = torch.cat([z_i, z_j], dim=0) # (2N, d)
# 相似度矩阵
sim = torch.mm(z, z.T) / self.temperature # (2N, 2N)
# 正样本对:(i, i+N) 和 (i+N, i)
N = batch_size
sim_i_pos = torch.diag(sim, N) # z_i 与对应 z_j
sim_j_pos = torch.diag(sim, -N) # z_j 与对应 z_i
pos = torch.cat([sim_i_pos, sim_j_pos], dim=0)
# 掩码自身
mask = torch.eye(2 * N, dtype=torch.bool, device=sim.device)
sim.masked_fill_(mask, -float('inf'))
# 负样本
neg = sim.logsumexp(dim=1)
# InfoNCE 损失
loss = -torch.mean(pos - neg)
return loss
def ib_regularization(self, z):
"""
IB 正则:鼓励表示压缩
使用方差正则近似信息压缩
"""
# 协方差正则
z = z - z.mean(dim=0)
cov = (z.T @ z) / (z.shape[0] - 1)
off_diag = cov.fill_diagonal_(0)
# 鼓励对角协方差(维度解耦)
diag_loss = off_diag.abs().mean()
# 鼓励均匀分布(信息最大化)
uniformity = torch.pdist(z, p=2).pow(2).mean()
return diag_loss + uniformity
def forward(self, x_i, x_j):
"""
Args:
x_i, x_j: 同一batch的两个增强视图
"""
# 编码
h_i = self.encoder(x_i)
h_j = self.encoder(x_j)
# 投影
z_i = self.projection(h_i)
z_j = self.projection(h_j)
# InfoNCE 损失(信息保留)
nce_loss = self.info_nce_loss(z_i, z_j)
# IB 正则(压缩)
all_z = torch.cat([z_i, z_j], dim=0)
ib_loss = self.ib_regularization(all_z)
# 总损失
total_loss = nce_loss + self.beta * ib_loss
return total_loss, nce_loss, ib_loss
# 与 MoCo 结合
class ContrastiveIBMoCo(nn.Module):
"""
对比IB与MoCo的结合
使用MoCo维护负样本队列,同时加入IB正则
"""
def __init__(self, encoder, projection_dim=128, K=65536, m=0.999, beta=1e-3):
super().__init__()
self.K = K
self.m = m
self.beta = beta
# 查询编码器
self.encoder_q = encoder
self.projection_q = nn.Linear(encoder.output_dim, projection_dim)
# 键编码器(动量更新)
self.encoder_k = copy.deepcopy(encoder)
self.projection_k = nn.Linear(encoder.output_dim, projection_dim)
# 负样本队列
self.register_buffer('queue', torch.randn(projection_dim, K))
self.queue = F.normalize(self.queue, dim=0)
self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
@torch.no_grad()
def _momentum_update_key(self):
"""动量更新键编码器"""
for param_q, param_k in zip(
self.encoder_q.parameters(), self.encoder_k.parameters()
):
param_k.data.mul_(self.m).add_(param_q.data, alpha=1 - self.m)
for param_q, param_k in zip(
self.projection_q.parameters(), self.projection_k.parameters()
):
param_k.data.mul_(self.m).add_(param_q.data, alpha=1 - self.m)
def forward(self, x_q, x_k):
# 查询表示
q = F.normalize(self.projection_q(self.encoder_q(x_q)), dim=1)
# 键表示(动量编码器)
with torch.no_grad():
k = F.normalize(self.projection_k(self.encoder_k(x_k)), dim=1)
# 正样本 logits
l_pos = torch.einsum('nc,nc->n', q, k).unsqueeze(1)
# 负样本 logits
l_neg = torch.einsum('nc,ck->nk', q, self.queue.clone())
# 总 logits
logits = torch.cat([l_pos, l_neg], dim=1) / 0.07
# 标签
labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
# 对比损失
contrastive_loss = F.cross_entropy(logits, labels)
# IB 正则:鼓励压缩(限制队列表示的熵)
queue_entropy = -(self.queue * torch.log(self.queue + 1e-10)).sum(dim=0).mean()
# 更新
self._momentum_update_key()
self._dequeue_and_enqueue(k)
return contrastive_loss + self.beta * queue_entropy5. 其他变体
5.1 任务导向IB(Task-Oriented IB, TOIB)
核心思想:不同任务需要不同级别的压缩。
传统IB对所有任务使用统一的压缩级别;TOIB根据任务特性自适应调整:
class TaskOrientedIB(nn.Module):
"""
任务导向信息瓶颈
为不同任务学习不同程度压缩的表示
"""
def __init__(self, input_dim, latent_dim, num_tasks):
super().__init__()
self.num_tasks = num_tasks
# 共享编码器
self.shared_encoder = nn.Sequential(
nn.Linear(input_dim, 512),
nn.ReLU(),
nn.Linear(512, 256)
)
# 任务特定编码器(调整压缩级别)
self.task_encoders = nn.ModuleList([
nn.Linear(256, latent_dim * task_beta) # 不同任务不同潜在维度
for task_beta in [1.0, 0.75, 0.5, 0.25] # 压缩级别
])
# 任务分类器
self.task_classifiers = nn.ModuleList([
nn.Linear(latent_dim, num_classes)
for _ in range(num_tasks)
])
def forward(self, x, task_id):
# 共享特征
h = self.shared_encoder(x)
# 任务特定编码
z = self.task_encoders[task_id](h)
# 分类
logits = self.task_classifiers[task_id](z)
return logits, z5.2 归一化IB(Normalized IB, NIB)
问题:原始IB的权衡参数 对不同数据集和模型架构敏感。
解决方案:使用归一化的互信息:
归一化后的IB具有更好的跨设置泛化能力。
5.3 不确定性感知IB(Uncertainty-Aware IB, UAIB)
核心思想:区分认知不确定性( epistemic uncertainty)和偶然不确定性( aleatoric uncertainty)。
- 认知不确定性:由训练数据不足导致,可通过更多数据减少
- 偶然不确定性:数据本身的固有噪声,不可减少
class UncertaintyAwareIB(nn.Module):
"""
不确定性感知信息瓶颈
分解表示中的认知和偶然不确定性
"""
def __init__(self, input_dim, latent_dim, num_classes):
super().__init__()
# 共享编码器
self.encoder = nn.Linear(input_dim, latent_dim)
# 认知不确定性估计器(数据依赖)
self.epistemic_head = nn.Linear(latent_dim, latent_dim)
# 偶然不确定性估计器(输入依赖)
self.aleatoric_head = nn.Sequential(
nn.Linear(input_dim, latent_dim),
nn.Linear(latent_dim, 1) # 输出 log(sigma^2)
)
# 分类器
self.classifier = nn.Linear(latent_dim, num_classes)
def forward(self, x):
z = self.encoder(x)
# 认知不确定性
epistemic_var = torch.exp(self.epistemic_head(z))
# 偶然不确定性
aleatoric_var = torch.exp(self.aleatoric_head(x))
# 总不确定性
total_var = epistemic_var + aleatoric_var
return z, epistemic_var, aleatoric_var, total_var
def loss(self, x, y, beta=1e-3):
z, epi_var, alea_var, total_var = self.forward(x)
# 带不确定性的分类损失
logits = self.classifier(z)
# NLL 损失(隐式处理不确定性)
nll_loss = F.cross_entropy(logits, y, reduction='none')
nll_loss = (nll_loss / (alea_var + 1e-6)).mean()
# 不确定性正则
# 鼓励高认知不确定性(表示不确信)但低偶然不确定性
uncertainty_loss = epi_var.mean() - alea_var.mean()
total_loss = nll_loss + beta * uncertainty_loss
return total_loss, nll_loss, uncertainty_loss5.4 变体对比总结
| 变体 | 目标函数 | 主要应用 |
|---|---|---|
| 原始IB | 理论基础 | |
| VIB | $\min \mathbb{E}[-\log q(y | z)] + \beta D_{KL}$ |
| CIB | 领域适应 | |
| CIB (Contrastive) | 自监督学习 | |
| TOIB | 多任务学习 | |
| NIB | 跨设置泛化 | |
| UAIB | $\min \text{NLL} + \beta \cdot ( | \text{epi} |
6. 统一框架
6.1 IB变体的层次结构
┌─────────────────┐
│ Information │
│ Bottleneck │
└────────┬────────┘
│
┌──────────────────┼──────────────────┐
│ │ │
▼ ▼ ▼
┌──────────┐ ┌──────────┐ ┌──────────┐
│ Variational│ │Conditional│ │Contrastive│
│ IB │ │ IB │ │ IB │
└──────────┘ └──────────┘ └──────────┘
│ │ │
▼ ▼ ▼
┌──────────┐ ┌──────────┐ ┌──────────┐
│ VAE │ │DANN/ADDA │ │ SimCLR │
└──────────┘ └──────────┘ └──────────┘
6.2 统一数学形式
所有IB变体可以统一为:
其中 是辅助正则项:
- :原始IB
- :条件IB(解耦)
- :对比IB
核心公式速查
| 变体 | 公式 |
|---|---|
| 原始IB | |
| VIB损失 | $\mathbb{E}[-\log q(y\mid z)] + \beta D_{KL}(q(z\mid x)\ |
| 条件IB | |
| 对比IB | |
| 自洽方程 | $p(t\mid x) \propto p(t) \exp(-\beta D_{KL}(p(y\mid x)\ |
| InfoNCE下界 |
参考
相关文章
- 信息瓶颈理论 — IB基本概念与深度学习联系
- 信息论基础 — 熵、互信息、KL散度
- 对比学习与InfoNCE — InfoNCE详细解析
- 变分自编码器 — VAE与VIB的联系
- 变分推断 — 变分推断技术
Footnotes
-
Tishby, N., Pereira, F.C., & Bialek, W. (1999). “The Information Bottleneck Method”. Proceedings of the 37th Annual Allerton Conference on Communication, Control, and Computing. ↩