信息瓶颈理论

信息瓶颈理论(Information Bottleneck, IB)由 Tishby、Pereira 和 Bialek 于1999年提出,是理解深度学习和数据表示学习的重要理论框架。1

核心思想

信息瓶颈理论的核心目标是找到关于目标 信息最丰富、同时对输入 压缩最多的表示

输入 X ──┬──→ 表示 T ──→ 目标 Y
        │         ↑
        │         │
        └─────────┘
       信息流:最大化 I(T;Y),最小化 I(T;X)

直观理解

想象一个通信场景:

  • 是原始数据(如图片)
  • 是我们关心的标签(如类别)
  • 是传输的压缩表示(如神经网络的中间层)

目标:在保证足够信息量用于预测 的同时,尽可能压缩 的信息,丢弃不相关的细节。


形式化定义

基本设定

考虑随机变量三元组 ,满足 Markov 链

这意味着:

  • 给定 条件独立(
  • 完全由 决定

IB 优化问题

原始的约束优化形式:

等价地,使用拉格朗日形式

其中 控制压缩与信息保留之间的权衡:

行为
只关注信息保留, 保留所有 的信息
只关注压缩,完全忽略 的信息
适中最佳权衡

信息平面(Information Plane)

定义

信息平面是以 为坐标的二维平面:

I(Y;T)
  ↑
  │    · · · · · · IB曲线 · · · · ·
  │   ·                              ·
  │  ·                                ·
  │ ·                                  ·
  │·                                    ·
  │                                      ·
  └──────────────────────────────────────→ I(X;T)

IB 曲线

对于给定的数据分布 ,不同 值对应平面上的不同点,这些点构成 IB 曲线

IB 曲线的特点

  • 曲线上任意点都是 Pareto 最优的
  • 无法在增加 的同时减少 (反之亦然)

自洽方程

最优编码满足以下自洽方程:

其中 是归一化常数。


深度学习视角

Tishby 的训练阶段假说

Tishby 等人提出深度神经网络的训练过程可以理解为两个阶段:2

Loss
  │
  │╲
  │  ╲        拟合阶段
  │    ╲     (Fitting Phase)
  │      ╲
  │       ╲________________
  │                           ╲
  │                              ╲
  │                               ╲____
  │                                    ╲  压缩阶段
  │                                      ╲(Compression)
  └────────────────────────────────────────→ Epoch

1. 拟合阶段(Fitting Phase)

  • 快速增加
  • 网络学习预测标签
  • 神经元响应变得与标签相关

2. 压缩阶段(Compression Phase)

  • 逐渐减小
  • 网络丢弃冗余信息
  • 表示变得更加高效和通用

实验证据

Shwartz-Ziv 和 Tishby(2017)在实验中观察到:2

阶段Epochs现象
拟合0-200快速增加快速增加学习标签
过渡200-400缓慢变化缓慢变化平衡
压缩400+逐渐减小基本稳定泛化

深度变分信息瓶颈(Deep VIB)

基本思想

Deep VIB 将信息瓶颈目标应用于深度神经网络,使用变分近似来实现高效的优化。3

目标函数

原始目标:

变分下界

其中:

  • 随机编码器(近似
  • 先验分布(通常取标准高斯分布
  • 分类器

PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class VIBModule(nn.Module):
    """
    深度变分信息瓶颈模块
    
    核心思想:通过变分近似实现信息瓶颈目标
    - 重构/分类损失:最大化 I(Z; Y)
    - KL 正则项:最小化 I(Z; X)
    """
    def __init__(self, input_dim, latent_dim, num_classes, beta=1e-3):
        super().__init__()
        # 随机编码器:输出均值和对数方差
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * latent_dim)  # mean and log_var
        )
        # 分类器
        self.classifier = nn.Linear(latent_dim, num_classes)
        # 先验分布(标准高斯)
        self.prior_mean = torch.zeros(latent_dim)
        self.prior_log_var = torch.zeros(latent_dim)
        
        self.latent_dim = latent_dim
        self.beta = beta
        
    def reparameterize(self, mu, log_var):
        """重参数化技巧:使梯度可以通过随机采样反向传播"""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def kl_divergence(self, mu, log_var):
        """计算与先验分布的 KL 散度"""
        prior_mean = self.prior_mean.to(mu.device)
        prior_log_var = self.prior_log_var.to(log_var.device)
        
        # D_KL(N(mu, sigma) || N(0, I))
        kl = 0.5 * torch.sum(
            prior_log_var - log_var + 
            (log_var.exp() + (mu - prior_mean).pow(2)) / prior_log_var.exp() - 1
        )
        return kl
        
    def forward(self, x):
        # 编码
        h = self.encoder(x)
        mu, log_var = h.chunk(2, dim=-1)
        z = self.reparameterize(mu, log_var)
        
        # 分类
        logits = self.classifier(z)
        
        return logits, mu, log_var, z
    
    def loss(self, x, y):
        """
        VIB 损失函数
        
        L = E[-log q(y|z)] + beta * D_KL(q(z|x) || r(z))
        
        第一项:重构/分类损失(最大化 I(Z;Y) 的下界)
        第二项:KL 正则项(最小化 I(Z;X))
        """
        logits, mu, log_var, z = self.forward(x)
        
        # 交叉熵分类损失
        ce_loss = F.cross_entropy(logits, y, reduction='mean')
        
        # KL 散度正则项
        kl_loss = self.kl_divergence(mu, log_var)
        
        # 总损失
        total_loss = ce_loss + self.beta * kl_loss
        
        return total_loss, ce_loss, kl_loss
    
    def get_mutual_info(self, x, y):
        """
        估算互信息的变分下界
        
        I(Z; Y) >= E_z~q(z|x)[log q(y|z)] + H(Y)
        I(Z; X) <= D_KL(q(z|x) || r(z)) + 常数
        """
        with torch.no_grad():
            logits, mu, log_var, z = self.forward(x)
            
            # I(Z;Y) 的下界估计
            log_probs = F.log_softmax(logits, dim=-1)
            i_zy = torch.gather(log_probs, 1, y.unsqueeze(1)).mean()
            
            # I(Z;X) 的上界估计
            kl = self.kl_divergence(mu, log_var)
            
        return i_zy, kl

使用示例

# 创建模型
model = VIBModule(input_dim=784, latent_dim=32, num_classes=10, beta=1e-3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 
# 训练循环
for epoch in range(100):
    for batch_x, batch_y in dataloader:
        loss, ce_loss, kl_loss = model.loss(batch_x, batch_y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 记录信息平面坐标
        i_zy, i_zx = model.get_mutual_info(batch_x, batch_y)
        info_plane_history.append((i_zx.item(), i_zy.item()))

VIB 的优势与应用

理论优势

特性描述
更好的泛化压缩表示减少过拟合风险
对抗鲁棒性随机编码增加对对抗样本的抵抗力
解耦表示促进学习独立的语义因子
可解释性信息平面可视化展示类别分离

实际应用

1. 对抗鲁棒性

Alemi 等人的实验表明:3

模型标准准确率对抗准确率(FGSM)
标准 CNN98.5%43.2%
VIB CNN97.8%71.3%

VIB 通过限制 使得对抗扰动难以影响表示

2. 表示解耦

通过 VIB 学习到的表示 通常:

  • 各维度之间相关性更低
  • 每个维度对应独立的语义概念
  • 便于可控生成和编辑

3. 主动学习

在标注数据有限的情况下,VIB 可以帮助选择最有信息量的样本进行标注。


信息瓶颈与注意力机制

注意力作为信息瓶颈

注意力机制可以被理解为在信息瓶颈框架下工作:4

从 IB 视角分析:

IB 组件注意力对应
Query 与 Key 交互后的压缩信息量
Softmax 权重分布与 Value 信息的保留程度
温度参数 控制 sharp/soft 程度

Sparse Attention

稀疏注意力通过限制注意力范围实现信息瓶颈效果:

class SparseAttention(nn.Module):
    def __init__(self, d_model, num_heads, sparsity=0.5):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads)
        self.sparsity = sparsity
    
    def forward(self, x, mask=None):
        attn_output, attn_weights = self.attn(x, x, x, mask)
        
        # 稀疏化:只保留 top-k 重要的注意力
        if self.training:
            # 随机丢弃部分注意力(类似 Dropout)
            keep_mask = torch.rand_like(attn_weights) > self.sparsity
            attn_weights = attn_weights * keep_mask / (1 - self.sparsity)
        
        return attn_output, attn_weights

核心公式速查

概念公式
IB 目标
约束形式
VIB 损失
自洽方程

参考

相关文章

Footnotes

  1. Tishby, N., Pereira, F.C., & Bialek, W. (1999). “The Information Bottleneck Method”. Proceedings of the 37th Annual Allerton Conference on Communication, Control, and Computing.

  2. Shwartz-Ziv, R., & Tishby, N. (2017). “Opening the Black Box of Deep Neural Networks via Information”. arXiv:1703.00810. 2

  3. Alemi, A.A., Fischer, I., Dillon, J.V., & Murphy, K. (2017). “Deep Variational Information Bottleneck”. ICLR. 2

  4. Zhao, H., et al. (2020). “Entropy-Lens: Understanding Transformers via Information”. NeurIPS 2020 Workshop.