稀疏神经网络训练方法

1. 概述

传统神经网络的训练遵循”密集初始化 → 密集训练 → 稀疏部署”的范式。然而,这种方法存在明显的低效:我们在训练时使用了全部参数,但最终只部署稀疏网络。稀疏训练 (Sparse Training) 提出了一种更优雅的方法:从一开始就保持网络稀疏,在训练过程中动态发现和优化稀疏结构1

稀疏训练的核心优势包括:

  • 内存效率:训练时只需存储和更新稀疏参数
  • 计算加速:稀疏操作可以显著减少计算量
  • 发现更好结构:动态调整可以找到比静态剪枝更好的稀疏模式
  • 理论价值:帮助我们理解神经网络的结构学习

本文将详细介绍稀疏训练的核心算法、实现技术,以及在现代架构(尤其是Transformer)中的应用。

2. 稀疏训练范式

2.1 稀疏训练 vs 密集训练 vs 静态剪枝

范式训练过程部署网络特点
密集训练密集密集或剪枝标准方法,资源浪费
静态剪枝密集稀疏先训练后剪枝,剪枝步骤可能次优
稀疏训练稀疏稀疏全程稀疏,结构自优化

2.2 稀疏模式分类

2.2.1 非结构化稀疏 (Unstructured Sparsity)

最灵活的稀疏形式,对权重矩阵中的任意位置进行稀疏化:

原始权重矩阵:          稀疏矩阵:
W = [[2.1, -0.3, 0.8],    S = [[2.1,  0, 0.8],
     [-1.2, 0.5, -0.7],         [  0, 0.5,  0 ],
     [0.3, 1.1, -0.4],          [0.3,  0,  0 ],
     [-0.6, 0.2, 1.3]]          [  0,  0, 1.3]]

优点:最高灵活性,可以达到极高的稀疏度
缺点:硬件加速困难,需要特殊的稀疏矩阵库

2.2.2 结构化稀疏 (Structured Sparsity)

按行、列、块或通道进行稀疏化,更容易硬件加速:

类型描述示例
行稀疏整行置零移除整个神经元
列稀疏整列置零移除输入维度
块稀疏m×n块置零移除神经元组
通道稀疏整个卷积核置零移除滤波器

2.2.3 N:M 结构化稀疏

NVIDIA Ampere架构引入的2:4稀疏模式:

每4个连续元素中恰好有2个为0:
[×, 0, ×, ×] ✓  [0, ×, 0, ×] ✓
[×, ×, 0, ×] ✓  [×, 0, ×, 0] ✓
[0, ×, ×, 0] ✗  [×, ×, ×, 0] ✗

这种模式可以实现2倍的理论加速,且有硬件支持。

3. 核心稀疏训练算法

3.1 RigL: Rigging the Lottery

RigL是由Google Brain提出的开创性稀疏训练算法2。其核心思想是:利用梯度信息动态发现和添加重要连接

3.1.1 算法原理

def rigl_step(model, optimizer, mask, sparsity, drop_fraction=0.3):
    """
    RigL单步更新
    """
    # 1. 正常稀疏前向/反向传播
    loss = forward_backward(model, mask)
    optimizer.step()
    
    # 2. 计算每个权重的重要性(梯度幅值)
    importance = compute_gradients(model)
    
    # 3. 移除不重要的连接(幅度最小)
    current_connections = get_active_connections(mask)
    num_to_drop = int(len(current_connections) * drop_fraction)
    weakest = argtopk(-|importance|, num_to_drop)
    mask[weakest] = 0
    
    # 4. 添加新连接(梯度最大)
    dead_connections = get_dead_connections(mask)
    num_to_add = num_to_drop
    strongest_dead = argtopk(importance[dead_connections], num_to_add)
    mask[strongest_dead] = 1
    
    return mask

3.1.2 关键设计决策

  1. drop_fraction调度

    • 初期:较高的drop率(如30%)
    • 后期:逐渐降低
    • 目标:平衡探索与利用
  2. 新连接生成策略

    • 方案A:从”死连接”(梯度长时间为0)中选择
    • 方案B:ERK初始化,允许连接到之前未活跃的位置
  3. 训练稳定性

    • 使用动量累积梯度信息
    • 避免频繁大幅度的结构变化

3.1.3 性能表现

网络数据集稀疏度Dense基线RigL
ResNet-50ImageNet80%76.8%76.4%
ResNet-50ImageNet95%76.8%74.4%
VGG-16CIFAR-1092%93.5%93.3%

3.2 SET: Sparse Evolutionary Training

SET是最早的稀疏训练方法之一,由Mostafa和Wang提出3。其核心思想简单而优雅:移除弱连接,添加随机新连接

3.2.1 算法原理

def set_step(model, optimizer, mask, sparsity, epsilon=1e-5):
    """
    SET单步更新
    """
    # 1. 正常训练
    loss = forward_backward(model, mask)
    optimizer.step()
    
    # 2. 移除最小幅度权重
    weights = get_weights(model)
    threshold = np.percentile(|weights|, sparsity * 100)
    mask[|weights| < threshold] = 0
    
    # 3. 添加随机新连接(保持稀疏度恒定)
    num_removed = initial_connections * sparsity
    dead_indices = where(mask == 0)
    new_indices = random.choice(dead_indices, num_removed, replace=False)
    mask[new_indices] = 1
    
    return mask

3.2.2 稀疏模式初始化:ERK

SET使用Erdős–Rényi–Kleinberg (ERK) 稀疏模式初始化:

def erk_init(layer_sizes, density):
    """
    ERK稀疏初始化
    
    对于第i层,稀疏度与输入输出维度的乘积成反比
    """
    n_layers = len(layer_sizes) - 1
    total_params = sum(layer_sizes[i] * layer_sizes[i+1] for i in range(n_layers))
    
    # 每个参数的"概率"与连接的维度成正比
    layer_fractions = [layer_sizes[i] + layer_sizes[i+1] for i in range(n_layers)]
    layer_fractions = [f / sum(layer_fractions) for f in layer_fractions]
    
    masks = []
    for i, (in_dim, out_dim) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        # ERK公式:每个连接被保留的概率
        p = density * (in_dim + out_dim) / (in_dim * out_dim)
        p = min(p, 1.0)  # 限制最大概率
        mask = (random.random(in_dim, out_dim) < p).astype(float)
        masks.append(mask)
    
    return masks

为什么ERK有效? ERK基于这样的直觉:早期层(靠近输入)应该更密集,因为它们需要捕获更多样的特征。

3.3 SNFS: Sparse Networks From Scratch

SNFS是SET的改进版本,提出了固定稀疏度但优化的初始化方法4

3.3.1 与SET的关键区别

方面SETSNFS
新连接完全随机基于当前权重分布
初始化ERKSNFS初始化
训练稳定性一般更好

3.4 算法比较

算法剪枝策略重生长策略优点缺点
RigL梯度幅值梯度幅值利用学习信息计算开销较高
SET权重幅值完全随机简单高效随机性太强
SNFS权重幅值权重感知训练更稳定实现较复杂
DS-T多种策略多种策略通用框架需要调参

4. 动态稀疏训练框架

4.1 训练周期设计

标准的稀疏训练采用周期性结构

┌─────────────────────────────────────────────────────────────┐
│  Epoch 0-9      Epoch 10-19     Epoch 20-29    Epoch 30-39 │
│  ┌─────────┐    ┌─────────┐     ┌─────────┐    ┌─────────┐  │
│  │ Train   │ →  │ Prune   │ →  │ Train   │ →  │ Prune   │  │
│  │ (Sparse)│    │ 30%     │    │ (Sparse)│    │ 30%     │  │
│  └─────────┘    └─────────┘     └─────────┘    └─────────┘  │
└─────────────────────────────────────────────────────────────┘

4.2 剪枝调度策略

均匀调度 (Uniform Schedule)

def uniform_schedule(total_epochs, n_steps, sparsity_target):
    """
    均匀调度:每n_steps均匀剪枝
    """
    sparsity_per_step = sparsity_target / n_steps
    schedule = []
    for step in range(n_steps):
        current_sparsity = step * sparsity_per_step
        schedule.append(current_sparsity)
    return schedule

指数调度 (Exponential Schedule)

def exponential_schedule(total_epochs, n_steps, s0=0.1, sT=0.9):
    """
    指数调度:初期变化小,后期变化大
    """
    schedule = []
    for step in range(n_steps):
        t = step / n_steps
        # s(t) = s0 * (sT/s0)^t
        current_sparsity = s0 * (sT / s0) ** t
        schedule.append(current_sparsity)
    return schedule

梯度感知调度

更高级的方法是根据训练动态调整调度

  • 当梯度范数较低时,减少剪枝
  • 当训练损失停滞时,增加探索
  • 使用验证集性能作为反馈信号

5. Transformer中的稀疏训练

5.1 注意力稀疏性

Transformer的注意力机制存在固有的稀疏性:

  1. softmax的峰值特性:经过训练后,注意力分布往往集中于少数位置
  2. 长距离依赖:大部分位置之间的注意力权重接近零
  3. 冗余模式:不同头可能学习相似的注意力模式

5.1.1 Sparse Attention Patterns

模式描述示例
局部窗口只关注近邻Swin Transformer
随机稀疏随机选择k个位置Sparse Transformer
固定模式步长采样Longformer
学习模式数据驱动学习Routing Transformer

5.1.2 EcoSpa: LLM联合稀疏

EcoSpa是针对LLM的联合稀疏方法,同时稀疏化注意力层和FFN层5

def ecospa_train_step(model, inputs, sparsity_attn=0.5, sparsity_ffn=0.5):
    """
    EcoSpa: 联合稀疏训练
    
    关键创新:评估注意力层和FFN层之间的交互模式
    """
    # 1. 前向传播(稀疏)
    outputs = sparse_forward(model, inputs)
    loss = compute_loss(outputs, targets)
    
    # 2. 计算联合重要性
    attn_importance = compute_attention_importance(model)
    ffn_importance = compute_ffn_importance(model)
    
    # 3. 联合稀疏化决策
    # 考虑注意力-FFN交互模式
    joint_importance = attn_importance * ffn_importance
    mask = update_sparse_mask(joint_importance, 
                              sparsity_attn, sparsity_ffn)
    
    # 4. 更新掩码
    model.apply_mask(mask)
    
    return loss

性能结果

  • LLaMA-7B: 训练内存减少50%,速度提升21%
  • GPT-2-Medium: 困惑度降低2.4倍

5.2 FFN稀疏性

FFN(前馈网络)占据了Transformer参数的大部分(通常约2/3):

Transformer层参数分布:
┌─────────────────────────────────────────┐
│  FFN层          ████████████████████  ~67% │
│  注意力QKV      ██████                ~20% │
│  注意力输出     ████                   ~13% │
└─────────────────────────────────────────┘

5.2.1 MoE作为稀疏FFN

Mixture of Experts (MoE) 是实现FFN稀疏性的主要方法:

  • 每个输入只激活少数”专家”(FFN)
  • 路由机制决定使用哪些专家
  • 可以实现极高的稀疏度(如100B参数中只激活10B)

5.3 注意力头剪枝

类似神经元剪枝,我们可以剪枝不重要的注意力头:

def compute_head_importance(model, dataloader):
    """
    计算注意力头的重要性
    使用注意力头部对损失的影响作为重要性指标
    """
    importance = {}
    
    for batch in dataloader:
        outputs = model(batch)
        loss = outputs.loss
        
        # 计算每个头对损失的梯度
        for layer in model.layers:
            for head_idx, head in enumerate(layer.attention.heads):
                grad = torch.autograd.grad(
                    loss, head.output, retain_graph=True
                )[0]
                # 使用梯度范数作为重要性
                importance[f"layer{layer.id}_head{head_idx}"] = grad.norm()
    
    return importance

6. 硬件支持与加速

6.1 NVIDIA 2:4 结构化稀疏

NVIDIA Ampere架构引入了对2:4稀疏性的硬件支持6

// cuSPARSELt API示例
#include <cusparse.h>
 
void sparse_matrix_multiply(
    const float* dense_A,    // 输入密集矩阵
    const int* mask_B,       // 2:4稀疏掩码
    float* output_C,         // 输出
    int M, int N, int K,
    float alpha, float beta
);

性能提升:矩阵乘法吞吐量提升2倍

6.2 稀疏训练的系统设计

class SparseTrainingSystem:
    """
    高效稀疏训练系统
    """
    def __init__(self, model, sparsity=0.9):
        self.model = model
        self.sparsity = sparsity
        self.mask = self._init_mask()
        
    def _init_mask(self):
        """初始化稀疏掩码"""
        mask = {}
        for name, param in self.model.named_parameters():
            if 'weight' in name:
                # 使用ERK初始化
                mask[name] = self._erk_mask(param.shape)
        return mask
    
    def sparse_forward(self, inputs):
        """稀疏前向传播"""
        # 只计算非零位置的乘积
        for name, param in self.model.named_parameters():
            if name in self.mask:
                param.mul_(self.mask[name])
        return self.model(inputs)
    
    def update_mask(self):
        """根据训练动态更新掩码"""
        for name, param in self.model.named_parameters():
            if name not in self.mask:
                continue
            
            # 计算梯度幅值
            grad_mag = param.grad.abs()
            
            # 计算新的稀疏掩码
            threshold = torch.quantile(
                grad_mag.flatten(), 
                self.sparsity
            )
            new_mask = (grad_mag > threshold).float()
            
            # 保持稀疏度恒定
            self.mask[name] = self._balance_mask(
                new_mask, self.sparsity
            )

7. 实践指南

7.1 何时使用稀疏训练

场景推荐方法原因
极高稀疏度 (>90%)RigL + ERK动态调整效果好
中等稀疏度 (50-80%)SET简单高效
硬件部署结构化稀疏硬件加速友好
LLM训练EcoSpa/MoE内存效率关键

7.2 超参数设置

# RigL推荐配置
config = {
    'drop_fraction': 0.3,          # 每步drop的比例
    'drop_fraction_end': 0.0,      # 最终drop比例
    'epoch_prune_start': 10,       # 开始剪枝的epoch
    'epoch_prune_end': 160,        # 结束剪枝的epoch
    'dense_allocation': 'erk',     # 初始稀疏分布
    'sparsity': 0.9,              # 目标稀疏度
}

7.3 常见问题与解决

问题原因解决方案
训练不稳定结构变化太频繁降低drop_fraction,使用动量
性能下降稀疏度过高降低目标稀疏度或增加训练时间
收敛慢学习率不合适初期使用较大学习率
局部最小随机性不足增加探索,使用ERK初始化

8. 未来方向

8.1 自动化稀疏模式发现

利用神经网络架构搜索(NAS)的思想,自动发现最优的稀疏模式:

# 自动化稀疏模式搜索
def auto_sparse_search(model, search_space, budget):
    """
    搜索最优的稀疏模式配置
    """
    best_config = None
    best_performance = 0
    
    for config in sample_configs(search_space, budget):
        # 训练并评估
        performance = train_and_evaluate(model, config)
        
        if performance > best_performance:
            best_performance = performance
            best_config = config
    
    return best_config

8.2 可微分稀疏化

将稀疏化决策参数化,使其可微:

class DifferentiableSparsity(torch.nn.Module):
    """
    可微分稀疏化
    """
    def __init__(self, n_params, target_sparsity):
        super().__init__()
        self.scores = torch.nn.Parameter(torch.randn(n_params))
        self.target_sparsity = target_sparsity
    
    def forward(self, weights):
        # 软化版的top-k选择
        # 使用Gumbel-Softmax等技术
        temperature = 1.0
        weights = weights * self.scores
        
        # Straight-through estimator
        hard_mask = (weights > torch.topk(
            weights, 
            int(self.target_sparsity * len(weights))
        )[0][-1]).float()
        
        return weights * hard_mask

8.3 跨任务稀疏迁移

在多个任务上联合训练,发现可迁移的稀疏结构:

  • 基础模型:学习通用的稀疏表示
  • 下游任务:微调稀疏掩码而非权重
  • 效率提升:大幅减少微调成本

9. 总结

稀疏神经网络训练代表了深度学习优化的一个重要方向。通过从一开始就保持网络稀疏,并在训练中动态发现和优化结构,我们可以在保持性能的同时显著提高训练和推理效率。

核心要点总结

  1. RigL利用梯度信息动态调整稀疏结构,是当前最有效的方法之一
  2. SET简单高效,适合资源受限的场景
  3. 结构化稀疏对于硬件加速至关重要
  4. LLM稀疏训练(如EcoSpa)可以显著降低训练成本
  5. 稀疏训练与彩票假说有深刻联系

参考资料

Footnotes

  1. Evci, U., et al. (2022). RigL: A Sparse Training Method for Neural Networks. Journal of Machine Learning Research (JMLR). https://arxiv.org/abs/1911.11134

  2. Evci, U., et al. (2020). Rigging the Lottery Ticket. International Conference on Learning Representations (ICLR). https://arxiv.org/abs/1911.11134

  3. Mostafa, H., & Wang, X. (2019). Parameter Efficient Training of Large Neural Networks. International Conference on Machine Learning (ICML). https://arxiv.org/abs/1912.00821

  4. Mocanu, D. C., et al. (2018). Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity. Scientific Reports. https://www.nature.com/articles/s41598-018-26290-4

  5. Zhang, J., et al. (2025). EcoSpa: Efficient Training of LLMs with Joint HHH Sparsity. arXiv. https://arxiv.org/abs/2511.11641

  6. NVIDIA. (2020). NVIDIA Ampere Architecture In-Depth. https://developer.nvidia.com/blog/nvidia-ampere-architecture-in-depth/