率失真理论及其在深度学习中的应用

率失真理论(Rate-Distortion Theory)是信息论的核心分支,研究在给定失真约束下能达到的最低压缩率(或等价的最高压缩率下能达到的最小失真)。1 近年来,这一理论被广泛应用于理解深度神经网络的压缩、泛化和表示学习问题,为神经网络压缩提供了坚实的理论基础。

1. 率失真函数基础

1.1 形式化定义

为原始信源, 为重构/压缩后的表示, 为失真度量函数。率失真函数定义为:

其中:

  • 是原始数据 与压缩表示 之间的互信息,即压缩所需的码率
  • 期望失真
  • 约束条件 限制了允许的最大失真

物理意义 给出了在平均失真不超过 的条件下,描述 所需的最少比特数。

1.2 率失真曲线

率失真曲线描述了压缩率与失真之间的基本权衡关系:

R(D)
  │
  │╲
  │  ╲       率失真曲线
  │    ╲    (R(D)曲线)
  │      ╲
  │        ╲___
  │           ╲___________
  │                          ╲____
  │                                   ╲____
  └──────────────────────────────────────────→ D
  0        D*                           D_max

关键性质

性质描述
非递增性
凸性 是关于 的凸函数(下凸)
零失真点,即无损压缩需要熵的比特数
最大失真点,当失真足够大时无需传输信息

1.3 Shannon 下界

率失真函数的下界由 Shannon 下界给出:

其中 失真熵(Distortion Entropy):

Shannon 下界给出了率失真函数理论上的最优值,实际的率失真函数可能偏离该下界。

1.4 典型失真度量

根据任务类型选择合适的失真度量:

任务常用失真度量公式
回归MSE
二值信号Hamming 距离
图像压缩SSIM结构相似性指数
生成模型对数似然
分布匹配KL 散度

2. 深度网络的率失真特性

2.1 各层的压缩比分析

深度神经网络可以视为一个信息处理管道,每一层都在进行不同程度的压缩。2

输入层 ──→ Layer 1 ──→ Layer 2 ──→ ... ──→ Layer N ──→ 输出
  │          │           │                     │
  │          │           │                     │
 I(X;T_0)  I(T_0;T_1)  I(T_1;T_2)          I(T_{N-1};T_N)

关键观察

  • 随层数增加呈先增后减趋势
  • 存在一个信息瓶颈层(通常在中层),该层的 最小
  • 最终层 反映了模型对输入的压缩程度

2.2 信息瓶颈与率失真的联系

信息瓶颈理论 与率失真理论有着深刻的联系。3

IB 目标

RD 目标

对比维度信息瓶颈率失真
优化目标压缩 + 预测压缩 + 保真
约束
权衡参数(或
应用表示学习数据压缩

两者可以统一在一个更一般的框架下:当失真 与标签 相关时, 存在内在联系。

2.3 表示瓶颈层的信息论约束

为网络的瓶颈表示, 为目标标签。根据率失真理论,有以下约束:

数据处理不等式的扩展

这意味着:

  • 信息损失 保留的关于 的信息不超过 本身的信息
  • 压缩极限 至少需要 比特来编码标签相关信息

泛化误差的下界

基于率失真理论,泛化误差 与压缩率 存在关系:4

其中 是表示维度, 是有效失真。

2.4 深度网络的两阶段特性

借鉴 信息瓶颈理论 中的两阶段假说,深度网络的率失真特性也可以分为两个阶段:

阶段Epochs行为
编码阶段初期快速增加提取输入特征
压缩阶段中后期逐渐减小丢弃冗余信息

这一特性为理解神经网络的训练动态提供了新的视角。


3. 压缩与泛化的联系

3.1 最小描述长度(MDL)视角

最小描述长度原则(Minimum Description Length, MDL)是压缩与泛化联系的理论基础之一。5

基本思想

给定数据集 和假设空间 ,MDL 原则认为:

  • 最好的模型是能够给出最短编码的模型
  • 编码长度 = 模型描述长度 + 数据描述长度(给定模型)

与率失真的联系

在 MDL 框架下:

  • 模型参数 是对数据的压缩表示
  • 残差 描述了压缩损失
  • 泛化能力强的模型具有简洁的参数表示小的压缩失真

3.2 率失真理论预测泛化误差

基于率失真理论,可以推导出泛化误差的上界:4

其中:

  • 是率失真函数
  • 是假设空间的复杂度
  • 是样本数量

核心洞察压缩程度越高( 越小),泛化误差的上界越小

3.3 压缩比与泛化能力的权衡

压缩比定义为:

压缩比 特征泛化表现
保留全部信息过拟合风险高
适中丢弃冗余,保留关键最佳泛化
过度压缩欠拟合
泛化误差
    │
    │      ╱╲
    │     ╱  ╲
    │    ╱    ╲      最优区域
    │   ╱      ╲___/
    │  ╱
    │ ╱
    └─────────────────────→ 压缩比 γ
        欠拟合  最优  过拟合

3.4 奥卡姆剃刀原则的信息论解释

从率失真角度,奥卡姆剃刀原则可以解释为:

简单模型对应较低的率失真 ,因而在相同泛化误差约束下需要更少的参数比特数。

这为为什么正则化(如权重衰减、dropout)能提升泛化提供了信息论解释:它们鼓励学习更「压缩」的表示。


4. 应用场景

4.1 神经网络压缩

剪枝(Pruning)的率失真分析

神经网络剪枝可以建模为率失真优化问题:6

其中 是二值掩码, 是剪枝后的网络输出。

剪枝策略的率失真准则

策略率失真优化目标
幅度剪枝移除权重幅度最小的连接
显著性剪枝最小化对 的影响
结构化剪枝移除神经元/通道组

量化(Quantization)

量化将连续权重映射到离散表示,降低存储和计算成本。

率失真视角

  • 码率:量化等级数 为码本大小)
  • 失真:量化误差

4.2 表示学习的最优编码

信息瓶颈编码

最优表示应该最大化 同时最小化 ,这正是率失真优化的目标:3

直觉:丢弃 中与 无关的「冗余信息」,只保留「相关知识」。

去噪自编码器的率失真解释

去噪自编码器(DAE)的目标可以写为:

这对应于率失真优化,其中:

  • 是加噪版本(对应
  • 是瓶颈表示
  • 是原始信号(重构目标)

4.3 知识蒸馏的信息论解释

知识蒸馏(Knowledge Distillation)将大模型(教师)的知识转移给小模型(学生)。7

信息论框架

设:

  • :教师模型的输出/表示
  • :学生模型的输出/表示
  • :真实标签

知识蒸馏的目标可以分解为:

组件含义
学生学习教师知识的程度
学生对真实标签的学习
权衡参数

蒸馏温度的信息论意义

温度参数 控制 softmax 的平滑程度:8

  • 低温度):提取「硬」知识,关注最大 logit
  • 高温度):提取「软」知识,保留类别间关系
  • 最优温度:最大化 的同时保持合理的
def distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.5):
    """
    知识蒸馏损失
    
    L = alpha * CE(student, labels) + (1-alpha) * T^2 * KL(student_t, teacher_t)
    
    其中 T^2 是为了补偿温度对梯度尺度的影响
    """
    # 软目标损失(KL 散度)
    student_soft = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_soft = F.softmax(teacher_logits / temperature, dim=-1)
    distill_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (temperature ** 2)
    
    # 硬目标损失(交叉熵)
    hard_loss = F.cross_entropy(student_logits, labels)
    
    # 加权组合
    return alpha * hard_loss + (1 - alpha) * distill_loss

5. 代码实现

5.1 率失真优化的基本框架

下面实现一个通用的率失真优化框架,支持多种失真度量和压缩策略:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Callable, Tuple, Optional
from dataclasses import dataclass
 
 
@dataclass
class RateDistortionConfig:
    """率失真优化配置"""
    latent_dim: int                    # 潜在表示维度
    beta: float = 1.0                  # 率失真权衡参数
    distortion_type: str = "mse"        # 失真类型:mse, bernoulli, gaussian
    prior_type: str = "gaussian"        # 先验分布类型
    target_rate: Optional[float] = None # 目标码率(比特/样本)
    target_distortion: Optional[float] = None  # 目标失真
 
 
class RateDistortionModel(nn.Module):
    """
    率失真优化模型
    
    目标函数:
        L = R(z|x) + beta * D(x, x_recon)
    
    其中:
        R(z|x) = D_KL(q(z|x) || p(z))  码率(近似互信息 I(X;Z))
        D(x, x_recon) = E[d(x, x_recon)]  失真
    """
    
    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        config: RateDistortionConfig
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        
        # 初始化先验分布参数
        self._init_prior()
    
    def _init_prior(self):
        """初始化先验分布"""
        if self.config.prior_type == "gaussian":
            self.prior_mean = nn.Parameter(
                torch.zeros(self.config.latent_dim), 
                requires_grad=False
            )
            self.prior_log_var = nn.Parameter(
                torch.zeros(self.config.latent_dim), 
                requires_grad=False
            )
        elif self.config.prior_type == "uniform":
            self.prior_mean = nn.Parameter(
                torch.zeros(self.config.latent_dim) - 0.5, 
                requires_grad=False
            )
            self.prior_scale = nn.Parameter(
                torch.ones(self.config.latent_dim) / np.sqrt(12), 
                requires_grad=False
            )
    
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        编码:返回均值和对数方差
        
        Returns:
            mu: 均值
            log_var: 对数方差
        """
        h = self.encoder(x)
        mu, log_var = h.chunk(2, dim=-1)
        return mu, log_var
    
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """解码:从潜在表示恢复原始信号"""
        return self.decoder(z)
    
    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """重参数化技巧"""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def compute_rate(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """
        计算码率:q(z|x) 与先验 p(z) 之间的 KL 散度
        
        R(z|x) ≈ D_KL(q(z|x) || p(z)) ≈ I(X; Z) 的上界
        """
        if self.config.prior_type == "gaussian":
            # 高斯先验下的 KL 散度
            prior_var = torch.exp(self.prior_log_var).to(mu.device)
            latent_var = torch.exp(log_var)
            
            rate = 0.5 * torch.sum(
                self.prior_log_var.to(mu.device) - log_var +
                (latent_var + mu.pow(2)) / prior_var - 1
            )
        else:
            # 其他先验的实现
            rate = torch.tensor(0.0, device=mu.device)
        
        return rate / mu.size(0)  # 按样本平均
    
    def compute_distortion(
        self, 
        x: torch.Tensor, 
        x_recon: torch.Tensor
    ) -> torch.Tensor:
        """
        计算失真
        
        支持多种失真度量
        """
        if self.config.distortion_type == "mse":
            distortion = F.mse_loss(x_recon, x, reduction='none').sum(dim=list(range(1, x.ndim)))
        elif self.config.distortion_type == "bernoulli":
            distortion = F.binary_cross_entropy(
                x_recon, x, reduction='none'
            ).sum(dim=list(range(1, x.ndim)))
        elif self.config.distortion_type == "gaussian":
            distortion = 0.5 * ((x - x_recon) ** 2).sum(dim=list(range(1, x.ndim)))
        else:
            raise ValueError(f"Unknown distortion type: {self.config.distortion_type}")
        
        return distortion.mean()
    
    def forward(self, x: torch.Tensor) -> dict:
        """
        前向传播
        
        Returns:
            dict: 包含重构 x_recon, 码率 rate, 失真 distortion 等
        """
        # 编码
        mu, log_var = self.encode(x)
        
        # 重参数化采样
        z = self.reparameterize(mu, log_var)
        
        # 解码
        x_recon = self.decode(z)
        
        # 计算损失
        rate = self.compute_rate(mu, log_var)
        distortion = self.compute_distortion(x, x_recon)
        
        return {
            'x_recon': x_recon,
            'mu': mu,
            'log_var': log_var,
            'z': z,
            'rate': rate,
            'distortion': distortion,
            'loss': rate + self.config.beta * distortion
        }
 
 
class RateDistortionTrainer:
    """
    率失真模型训练器
    
    支持:
    - 固定 beta 训练
    - 变分率失真训练(随训练动态调整 beta)
    - 目标码率/失真约束优化
    """
    
    def __init__(
        self,
        model: RateDistortionModel,
        optimizer: torch.optim.Optimizer,
        device: str = 'cuda'
    ):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.device = device
        self.history = {
            'rate': [],
            'distortion': [],
            'loss': []
        }
    
    def train_step(self, x: torch.Tensor) -> dict:
        """单步训练"""
        x = x.to(self.device)
        
        # 前向传播
        outputs = self.model(x)
        
        # 反向传播
        self.optimizer.zero_grad()
        outputs['loss'].backward()
        self.optimizer.step()
        
        # 记录历史
        self.history['rate'].append(outputs['rate'].item())
        self.history['distortion'].append(outputs['distortion'].item())
        self.history['loss'].append(outputs['loss'].item())
        
        return outputs
    
    def compute_rd_curve(
        self, 
        test_loader: torch.utils.data.DataLoader,
        beta_range: list = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        计算率失真曲线
        
        在不同的 beta 值下计算对应的 (rate, distortion) 点
        """
        if beta_range is None:
            beta_range = [0.001, 0.01, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0]
        
        original_beta = self.model.config.beta
        rates = []
        distortions = []
        
        for beta in beta_range:
            self.model.config.beta = beta
            
            # 收集该 beta 下的 rate 和 distortion
            batch_rates = []
            batch_distortions = []
            
            with torch.no_grad():
                for x, _ in test_loader:
                    x = x.to(self.device)
                    outputs = self.model(x)
                    batch_rates.append(outputs['rate'].item())
                    batch_distortions.append(outputs['distortion'].item())
            
            rates.append(np.mean(batch_rates))
            distortions.append(np.mean(batch_distortions))
        
        # 恢复原始 beta
        self.model.config.beta = original_beta
        
        return np.array(rates), np.array(distortions)
    
    def fit(
        self,
        train_loader: torch.utils.data.DataLoader,
        epochs: int,
        val_loader: Optional[torch.utils.data.DataLoader] = None
    ):
        """完整训练流程"""
        for epoch in range(epochs):
            epoch_rates = []
            epoch_distortions = []
            
            for x, _ in train_loader:
                outputs = self.train_step(x)
                epoch_rates.append(outputs['rate'].item())
                epoch_distortions.append(outputs['distortion'].item())
            
            # 打印训练进度
            if (epoch + 1) % 10 == 0:
                print(
                    f"Epoch {epoch+1}/{epochs} | "
                    f"Rate: {np.mean(epoch_rates):.4f} | "
                    f"Distortion: {np.mean(epoch_distortions):.4f}"
                )

5.2 变分率失真(Variational Rate Distortion)

变分率失真是率失真理论在大规模深度学习中的实用实现:

class VariationalRateDistortion(nn.Module):
    """
    变分率失真模型
    
    使用变分推断近似难以直接计算的率失真目标:
    
    L = E_q[log q(z|x)] - E_q[log p(z)] + beta * E_q[log p(x|z)]
    
    其中:
        - 第一项:编码器的负熵 H(Z|X)
        - 第二项:先验的负熵 H(Z)
        - 第三项:重构损失(负对数似然)
    
    关系推导:
        I(X; Z) = H(Z) - H(Z|X)
        R(D) ≈ I(X; Z) = -E_q[log p(z)] - H(Z|X)
    """
    
    def __init__(
        self,
        input_dim: int,
        latent_dim: int,
        hidden_dim: int = 256,
        beta: float = 1.0,
        beta_annealing: bool = True,
        beta_start: float = 0.001,
        beta_end: float = 1.0,
        annealing_epochs: int = 100
    ):
        super().__init__()
        
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.beta = beta
        self.beta_annealing = beta_annealing
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.annealing_epochs = annealing_epochs
        self.current_epoch = 0
        
        # 编码器:输入 -> 潜在分布参数
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, 2 * latent_dim)  # mean 和 log_var
        )
        
        # 解码器:潜在 -> 重构
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # 假设输入在 [0, 1]
        )
        
        # 先验分布参数(标准高斯)
        self.register_buffer('prior_mean', torch.zeros(latent_dim))
        self.register_buffer('prior_log_var', torch.zeros(latent_dim))
    
    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """重参数化"""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def kl_divergence(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """
        计算与先验的 KL 散度
        
        D_KL(N(mu, sigma) || N(0, I)) = 0.5 * (sigma^2 + mu^2 - 1 - log(sigma^2))
        """
        kl = 0.5 * torch.sum(
            log_var.exp() + mu.pow(2) - 1 - log_var,
            dim=-1
        )
        return kl.mean()
    
    def get_current_beta(self) -> float:
        """获取当前 epoch 的 beta 值(用于退火)"""
        if not self.beta_annealing:
            return self.beta
        
        if self.current_epoch < self.annealing_epochs:
            progress = self.current_epoch / self.annealing_epochs
            return self.beta_start + (self.beta_end - self.beta_start) * progress
        return self.beta_end
    
    def forward(self, x: torch.Tensor) -> dict:
        """前向传播"""
        # 编码
        h = self.encoder(x)
        mu, log_var = h.chunk(2, dim=-1)
        
        # 采样
        z = self.reparameterize(mu, log_var)
        
        # 解码
        x_recon = self.decoder(z)
        
        # 计算各部分损失
        kl_loss = self.kl_divergence(mu, log_var)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='mean')
        
        # 获取当前 beta(支持退火)
        current_beta = self.get_current_beta()
        
        # 总损失:L = recon_loss + beta * kl_loss
        total_loss = recon_loss + current_beta * kl_loss
        
        return {
            'recon': x_recon,
            'z': z,
            'mu': mu,
            'log_var': log_var,
            'kl_loss': kl_loss,
            'recon_loss': recon_loss,
            'total_loss': total_loss,
            'beta': current_beta,
            'rate': kl_loss,  # KL 散度作为码率上界
            'distortion': recon_loss
        }
    
    def get_rate_distortion_metrics(self, x: torch.Tensor) -> dict:
        """计算率失真指标"""
        with torch.no_grad():
            outputs = self.forward(x)
            
        return {
            'rate_bits': outputs['rate'].item(),
            'distortion': outputs['distortion'].item(),
            'rate_distortion': outputs['rate'].item() + outputs['distortion'].item(),
            'compression_ratio': (
                self.input_dim / self.latent_dim if self.input_dim > 0 
                else float('inf')
            )
        }

5.3 深度神经网络的信息瓶颈分析工具

class InformationBottleneckAnalyzer:
    """
    深度神经网络的信息瓶颈分析工具
    
    用于估算各层之间的互信息 I(T_i; T_{i+1}) 和 I(T_i; Y)
    """
    
    def __init__(
        self,
        model: nn.Module,
        layer_names: list,
        data_loader: torch.utils.data.DataLoader,
        device: str = 'cuda'
    ):
        self.model = model
        self.layer_names = layer_names
        self.data_loader = data_loader
        self.device = device
        self.model.eval()
        
        # 存储各层激活
        self.activations = {name: [] for name in layer_names}
    
    @torch.no_grad()
    def collect_activations(self, x: torch.Tensor, y: torch.Tensor):
        """收集指定层的激活"""
        hooks = []
        
        def get_activation(name):
            def hook(module, input, output):
                self.activations[name].append(output.cpu())
            return hook
        
        # 注册 hook
        for name, module in self.model.named_modules():
            if name in self.layer_names:
                hooks.append(module.register_forward_hook(get_activation(name)))
        
        # 前向传播
        x = x.to(self.device)
        self.model(x)
        
        # 移除 hook
        for hook in hooks:
            hook.remove()
    
    def estimate_mutual_info(self, z: torch.Tensor, y: torch.Tensor, k: int = 10) -> float:
        """
        使用 KNN 估计 I(Z; Y)
        
        I(Z; Y) ≈ log(k) - ψ(k) + ψ(N_z) - ψ(N_{z|y})
        
        其中 ψ 是 digamma 函数
        """
        from scipy.special import digamma
        from sklearn.neighbors import NearestNeighbors
        
        z = z.numpy()
        y = y.numpy()
        n = len(z)
        
        # 估计 H(Z)
        nn = NearestNeighbors(n_neighbors=k + 1)
        nn.fit(z)
        dist, _ = nn.kneighbors(z)
        h_z = np.mean(np.log(dist[:, -1]) + np.log(n - 1) - digamma(k))
        
        # 估计 H(Z|Y)
        h_z_y = 0
        for label in np.unique(y):
            mask = y == label
            z_y = z[mask]
            if len(z_y) > k:
                nn_y = NearestNeighbors(n_neighbors=k + 1)
                nn_y.fit(z_y)
                dist_y, _ = nn_y.kneighbors(z_y)
                h_z_y += (len(z_y) / n) * np.mean(
                    np.log(dist_y[:, -1]) + np.log(len(z_y) - 1) - digamma(k)
                )
        
        return max(0, h_z - h_z_y)
    
    def analyze(self, max_batches: int = 100) -> dict:
        """
        分析网络各层的信息瓶颈特性
        
        Returns:
            dict: 各层的 I(X; T) 和 I(Y; T) 估计
        """
        results = {name: {'i_xt': [], 'i_ty': []} for name in self.layer_names}
        
        batch_count = 0
        for x, y in self.data_loader:
            if batch_count >= max_batches:
                break
            
            # 收集激活
            self.collect_activations(x, y)
            
            batch_count += 1
        
        # 计算互信息估计
        for name in self.layer_names:
            if len(self.activations[name]) == 0:
                continue
            
            acts = torch.cat(self.activations[name], dim=0)
            
            # 将激活展平或处理为适合估计的形式
            # 这里使用简化的处理方式
            z = acts.view(len(acts), -1).float()
            
            # 只用部分样本估算以节省计算
            sample_size = min(1000, len(z))
            indices = np.random.choice(len(z), sample_size, replace=False)
            
            try:
                i_ty = self.estimate_mutual_info(
                    z[indices].cpu(), 
                    y[indices % len(y)].cpu()
                )
                results[name]['i_ty'] = [i_ty]
            except Exception as e:
                results[name]['i_ty'] = [0.0]
        
        return results

5.4 使用示例

# 示例:MNIST 数据集上的率失真实验
 
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
 
 
def main():
    # 配置
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    latent_dim = 32
    input_dim = 784  # MNIST 28x28
    
    # 创建模型
    model = VariationalRateDistortion(
        input_dim=input_dim,
        latent_dim=latent_dim,
        hidden_dim=256,
        beta=1.0,
        beta_annealing=True,
        beta_start=0.001,
        beta_end=1.0,
        annealing_epochs=50
    ).to(device)
    
    # 数据加载
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1))
    ])
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    
    # 训练器
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    # 训练循环
    epochs = 100
    for epoch in range(epochs):
        model.current_epoch = epoch
        epoch_losses = []
        
        for x, _ in train_loader:
            x = x.to(device)
            
            outputs = model(x)
            
            optimizer.zero_grad()
            outputs['total_loss'].backward()
            optimizer.step()
            
            epoch_losses.append({
                'total': outputs['total_loss'].item(),
                'recon': outputs['recon_loss'].item(),
                'kl': outputs['kl_loss'].item(),
                'beta': outputs['beta']
            })
        
        if (epoch + 1) % 10 == 0:
            avg = {k: np.mean([l[k] for l in epoch_losses]) for k in epoch_losses[0]}
            print(
                f"Epoch {epoch+1:3d} | "
                f"β={avg['beta']:.4f} | "
                f"Recon={avg['recon']:.4f} | "
                f"KL={avg['kl']:.4f} | "
                f"Total={avg['total']:.4f}"
            )
    
    # 分析率失真曲线
    print("\n分析率失真曲线...")
    
    # 在不同 beta 下评估
    beta_values = [0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0]
    rd_curve = []
    
    for beta in beta_values:
        model.beta = beta
        model.beta_annealing = False
        
        rates = []
        distortions = []
        
        with torch.no_grad():
            for x, _ in train_loader:
                x = x.to(device)
                outputs = model(x)
                rates.append(outputs['rate'].item())
                distortions.append(outputs['distortion'].item())
        
        rd_curve.append({
            'beta': beta,
            'rate': np.mean(rates),
            'distortion': np.mean(distortions)
        })
        
        print(f"β={beta:.2f}: Rate={np.mean(rates):.4f}, Distortion={np.mean(distortions):.4f}")
    
    return model, rd_curve
 
 
if __name__ == "__main__":
    model, rd_curve = main()

6. 核心公式速查

概念公式
率失真函数
RD 优化目标
Shannon 下界
码率(高斯先验)
MSE 失真
IB 目标
压缩比
MDL 原理
蒸馏损失

参考

相关文章

Footnotes

  1. Berger, T. (1971). Rate-Distortion Theory: A Mathematical Basis for Data Compression. Prentice-Hall.

  2. Shwartz-Ziv, R., & Tishby, N. (2017). “Opening the Black Box of Deep Neural Networks via Information”. arXiv:1703.00810.

  3. Tishby, N., Pereira, F.C., & Bialek, W. (1999). “The Information Bottleneck Method”. Proceedings of the 37th Annual Allerton Conference. 2

  4. Hauhubei, P., & Tishby, N. (2019). “The Information Bottleneck and the Geometry of Deep Learning”. NeurIPS 2019. 2

  5. Rissanen, J. (1978). “Modeling by shortest data description”. Automatica, 14(5), 465-471.

  6. Aghajanyan, A., et al. (2020). “Intrinsic Dimensionality Explains the Effectiveness of Language Model Fine-Tuning”. ACL 2020.

  7. Hinton, G., Vinyals, O., & Dean, J. (2015). “Distilling the Knowledge in a Neural Network”. NIPS Workshop.

  8. Tang, J., et al. (2022). “On the Energy-Efficiency of Deep Learning”. JMLR.