Model Pruning Techniques

Model pruning reduces a model's parameter count and computation by removing unimportant weights or neurons; it is one of the core techniques of model compression.

1. A Taxonomy of Pruning

Model pruning
├── Unstructured pruning
│   └── Remove individual weights at arbitrary positions
├── Structured pruning
│   ├── Neuron pruning
│   ├── Kernel (filter) pruning
│   └── Layer pruning
└── Iterative (progressive) pruning
    └── Remove weights gradually over multiple rounds

2. Unstructured Pruning

2.1 Magnitude Pruning

The simplest pruning method: a weight's importance is judged by its absolute magnitude:

import torch
import torch.nn as nn

def magnitude_pruning(model, sparsity=0.5):
    """
    Magnitude pruning.

    Args:
        model: model to prune
        sparsity: fraction of weights to remove (0.5 = remove 50%)
    """
    for name, param in model.named_parameters():
        if 'weight' in name:
            # Threshold = the `sparsity` quantile of the absolute weights
            threshold = torch.quantile(param.abs(), sparsity)

            # Binary mask keeping only weights above the threshold
            mask = (param.abs() > threshold).float()

            # Apply the mask in place
            param.data = param.data * mask

    return model
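For production use, PyTorch's built-in `torch.nn.utils.prune` module implements the same magnitude criterion via a persistent weight mask; a minimal sketch on a single layer:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(64, 32)

# L1 (magnitude) unstructured pruning: zero the 50% smallest-magnitude weights
prune.l1_unstructured(layer, name='weight', amount=0.5)

# `weight` is now recomputed on the fly as weight_orig * weight_mask
sparsity = (layer.weight == 0).float().mean().item()

# Make the pruning permanent (drops the mask reparameterization)
prune.remove(layer, 'weight')
```

Unlike the one-shot manual masking above, the mask survives optimizer updates until `prune.remove` is called.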

2.2 Gradient-Magnitude Pruning

Importance is judged by gradient magnitude instead:

def gradient_magnitude_pruning(model, inputs, targets, sparsity=0.5):
    """
    Pruning based on gradient magnitude.
    """
    # Compute gradients with one forward/backward pass
    model.zero_grad()
    outputs = model(inputs)
    loss = nn.functional.cross_entropy(outputs, targets)
    loss.backward()

    # Record each parameter's gradient magnitude
    grad_magnitudes = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_magnitudes[name] = param.grad.abs()

    # Mask out the weights with the smallest gradient magnitudes
    for name, param in model.named_parameters():
        if name in grad_magnitudes:
            threshold = torch.quantile(grad_magnitudes[name], sparsity)
            mask = (grad_magnitudes[name] > threshold).float()
            param.data = param.data * mask

    return model

3. Structured Pruning

3.1 Neuron Pruning

Remove whole neurons (entire weight rows) at once:

class NeuronPruning:
    """Neuron-level pruning"""

    @staticmethod
    def compute_neuron_importance(layer):
        """
        Compute neuron importance (based on weight norms).
        """
        # Linear layer: L2 norm of each weight row
        if isinstance(layer, nn.Linear):
            # Weight shape: (out_features, in_features)
            # Each output neuron corresponds to one row
            importance = torch.norm(layer.weight.data, dim=1)
        elif isinstance(layer, nn.Conv2d):
            # Conv layer: each filter acts as one "neuron"
            # Weight shape: (out_channels, in_channels, kH, kW)
            importance = torch.norm(
                layer.weight.data.view(layer.out_channels, -1), 
                dim=1
            )
        else:
            raise TypeError(f"Unsupported layer type: {type(layer)}")
        return importance

    @staticmethod
    def prune_neurons(layer, importance, threshold):
        """Prune unimportant neurons (this sketch handles nn.Linear only)"""
        mask = (importance > threshold).float()

        if not isinstance(layer, nn.Linear):
            raise TypeError("prune_neurons only supports nn.Linear here")

        # Keep the weight rows of the important neurons
        new_out_features = int(mask.sum().item())
        new_weight = layer.weight.data[mask.bool()].clone()
        new_bias = layer.bias.data[mask.bool()].clone() if layer.bias is not None else None

        # Build the smaller replacement layer
        new_layer = nn.Linear(layer.in_features, new_out_features)
        new_layer.weight.data = new_weight
        if new_bias is not None:
            new_layer.bias.data = new_bias

        return new_layer, mask
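A quick self-contained check of the row-norm criterion above (toy sizes): keep the three highest-norm output neurons of a small Linear layer and confirm the resulting shapes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(8, 6)

# Importance of each output neuron = L2 norm of its weight row
importance = torch.norm(layer.weight.data, dim=1)

# Keep the top 3 neurons
keep = torch.topk(importance, k=3).indices
pruned = nn.Linear(8, 3)
pruned.weight.data = layer.weight.data[keep].clone()
pruned.bias.data = layer.bias.data[keep].clone()

out = pruned(torch.randn(4, 8))
```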

3.2 Kernel (Filter) Pruning

Remove entire convolution kernels:

def prune_conv_kernels(model, sparsity=0.3):
    """
    Kernel pruning: remove the least important filters.
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # Importance of each filter = L2 norm of its weights
            kernel_importance = torch.norm(
                module.weight.data.view(module.out_channels, -1), 
                dim=1
            )

            # Number of filters to keep
            num_keep = int(module.out_channels * (1 - sparsity))

            # Keep only the most important filters
            _, keep_indices = torch.topk(kernel_importance, num_keep)
            keep_mask = torch.zeros(module.out_channels).scatter_(
                0, keep_indices, 1
            ).bool()

            # Update the weights (module.out_channels is now stale;
            # a full implementation would rebuild the layer)
            module.weight.data = module.weight.data[keep_mask]
            if module.bias is not None:
                module.bias.data = module.bias.data[keep_mask]

            # Note: the next layer's in_channels must be updated to match

    return model
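The caveat in the final comment is the tricky part in practice: once output channels are removed from one conv, the next conv must drop the same channels along its input dimension (dim 1 of its weight). A minimal two-layer sketch with hypothetical sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)

# Keep the 4 most important filters of conv1 (L2-norm criterion)
importance = torch.norm(conv1.weight.data.view(8, -1), dim=1)
keep = torch.topk(importance, k=4).indices

pruned1 = nn.Conv2d(3, 4, kernel_size=3, padding=1)
pruned1.weight.data = conv1.weight.data[keep].clone()
pruned1.bias.data = conv1.bias.data[keep].clone()

# conv2 must drop the SAME channels from its input dimension (dim 1)
pruned2 = nn.Conv2d(4, 16, kernel_size=3, padding=1)
pruned2.weight.data = conv2.weight.data[:, keep].clone()
pruned2.bias.data = conv2.bias.data.clone()

out = pruned2(pruned1(torch.randn(1, 3, 32, 32)))
```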

3.3 Layer Pruning

class LayerPruning:
    """Layer-level pruning"""
    
    def __init__(self, model):
        self.model = model
    
    def compute_layer_importance(self, dataloader, device='cuda'):
        """
        Estimate per-layer importance from the loss gradient.

        A first-order (Taylor-expansion) proxy for the effect of
        removing a layer on the validation loss.
        """
        importance = {}
        
        # Collect each layer's output gradient
        hooks = []
        for name, module in self.model.named_modules():
            if len(list(module.children())) == 0:  # leaf modules
                # Bind `name` via a default argument: a plain closure would
                # capture the loop variable and record every layer under the
                # last name seen.
                handle = module.register_full_backward_hook(
                    lambda m, g_in, g_out, name=name: self._hook_fn(name, g_out, importance)
                )
                hooks.append(handle)
        
        # One forward/backward pass
        batch = next(iter(dataloader))
        inputs, targets = batch[0].to(device), batch[1].to(device)
        
        self.model.zero_grad()
        outputs = self.model(inputs)
        loss = nn.functional.cross_entropy(outputs, targets)
        loss.backward()
        
        # Remove the hooks
        for h in hooks:
            h.remove()
        
        return importance
    
    @staticmethod
    def _hook_fn(name, grad_out, importance):
        """Backward-hook callback: record the mean |gradient| of the output"""
        if grad_out[0] is not None:
            importance[name] = grad_out[0].abs().mean()
    
    def prune_layers(self, importance, sparsity):
        """
        Remove the least important layers.
        """
        # Sort layers by importance, ascending
        sorted_layers = sorted(importance.items(), key=lambda x: x[1])
        
        # Drop the least important fraction
        num_prune = int(len(sorted_layers) * sparsity)
        prune_names = set(name for name, _ in sorted_layers[:num_prune])
        
        # Rebuild the model without them (simplified: assumes a flat
        # nn.Sequential whose children are the leaf modules)
        new_model = nn.Sequential()
        for name, module in self.model.named_children():
            if name not in prune_names:
                new_model.add_module(name, module)
        
        return new_model

4. Iterative Pruning

4.1 The Lottery Ticket Hypothesis

The Lottery Ticket Hypothesis[^1]: a dense network contains a sparse subnetwork (a "winning ticket") that can be trained in isolation, from the original initialization, to match the dense network's performance.

def lottery_ticket_pruning(model, train_loader, test_loader, 
                          sparsity=0.9, iterations=5, lr=0.01):
    """
    Lottery-ticket pruning loop:

    1. Train the model to convergence
    2. Prune weights by magnitude
    3. Reset the surviving weights to their initial values
    4. Repeat
    """
    # Save the initial weights (theta_0)
    original_weights = {}
    for name, param in model.named_parameters():
        original_weights[name] = param.data.clone()
    
    current_sparsity = 0
    
    for iteration in range(iterations):
        # Target sparsity for this round
        target_sparsity = ((iteration + 1) / iterations) * sparsity
        
        # Train the model
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        
        for epoch in range(10):
            for batch in train_loader:
                inputs, targets = batch
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = nn.functional.cross_entropy(outputs, targets)
                loss.backward()
                optimizer.step()
        
        # Evaluate (evaluate() is an external helper)
        accuracy = evaluate(model, test_loader)
        print(f"Iteration {iteration}: Sparsity={target_sparsity:.2%}, Accuracy={accuracy:.2%}")
        
        # Prune
        if target_sparsity > current_sparsity:
            model = magnitude_pruning(model, target_sparsity)
            current_sparsity = target_sparsity
            
            # Reset the SURVIVING weights to their initial values.
            # The mask must come from the pruned (current) weights: the
            # dense initial weights are almost all non-zero, so masking
            # on them would silently undo the pruning.
            for name, param in model.named_parameters():
                mask = (param.data != 0).float()
                param.data = original_weights[name] * mask
    
    return model
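The crucial step, masking by trained-weight magnitude but rewinding the survivors to their initial values, can be isolated in a few lines (toy sizes; training replaced by a random perturbation):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(16, 16)
init_weight = layer.weight.data.clone()  # theta_0, saved before "training"

# Stand-in for training: perturb the weights
with torch.no_grad():
    layer.weight.add_(0.1 * torch.randn_like(layer.weight))

# Prune 80% by magnitude of the TRAINED weights...
threshold = torch.quantile(layer.weight.data.abs(), 0.8)
mask = (layer.weight.data.abs() > threshold).float()

# ...but rewind the surviving weights to their ORIGINAL initialization
layer.weight.data = init_weight * mask

sparsity = (layer.weight.data == 0).float().mean().item()
```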

4.2 Progressive Pruning Schedules

import numpy as np

class ProgressivePruning:
    """Progressive (gradual) pruning scheduler"""
    
    def __init__(self, initial_sparsity=0.0, final_sparsity=0.9, 
                 total_steps=10000, schedule='cubic'):
        self.initial_sparsity = initial_sparsity
        self.final_sparsity = final_sparsity
        self.total_steps = total_steps
        self.schedule = schedule
    
    def get_sparsity(self, step):
        """
        Target sparsity at the given step.
        """
        progress = step / self.total_steps
        
        if self.schedule == 'linear':
            return self.initial_sparsity + (self.final_sparsity - self.initial_sparsity) * progress
        
        elif self.schedule == 'cubic':
            # Grows slowly at first, then accelerates
            return self.initial_sparsity + (self.final_sparsity - self.initial_sparsity) * (progress ** 3)
        
        elif self.schedule == 'exponential':
            return self.final_sparsity * (1 - (1 - progress) ** 3)
        
        elif self.schedule == 'sinusoidal':
            return self.final_sparsity * (1 - (1 + np.cos(np.pi * progress)) / 2)
        
        else:
            return self.final_sparsity * progress
    
    def apply_progressive_pruning(self, model, step):
        """Apply the scheduled sparsity via magnitude masking"""
        sparsity = self.get_sparsity(step)
        
        # Per-tensor threshold at the scheduled quantile
        for name, param in model.named_parameters():
            if 'weight' in name:
                threshold = torch.quantile(param.abs(), sparsity)
                mask = (param.abs() > threshold).float()
                param.data = param.data * mask
        
        return model
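A quick sanity check of the cubic branch above, reimplemented as a standalone function: sparsity stays near zero for the first half of training, then climbs quickly to the target.

```python
def cubic_sparsity(step, total_steps, s0=0.0, sf=0.9):
    # Same formula as the 'cubic' branch: sparsity grows as progress**3
    progress = step / total_steps
    return s0 + (sf - s0) * progress ** 3

start = cubic_sparsity(0, 10000)       # 0.0
halfway = cubic_sparsity(5000, 10000)  # 0.9 * 0.5**3 = 0.1125
end = cubic_sparsity(10000, 10000)     # 0.9
```

Note that the widely cited gradual-pruning schedule of Zhu and Gurijala instead decays as (1 - progress)^3, pruning aggressively early and tapering off; which shape works better is task-dependent.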

5. Neuron Pruning and Equivalent Network Transformations

5.1 Conv-BatchNorm Fusion

def fuse_conv_bn(model):
    """
    Fuse convolution and BatchNorm layers.
    
    Folds each Conv -> BN pair into a single convolution.
    Assumes a flat container (e.g. nn.Sequential) of child modules.
    """
    new_model = nn.Sequential()
    modules = list(model.children())
    
    for i, module in enumerate(modules):
        if isinstance(module, nn.Conv2d):
            # Is the next module a BatchNorm?
            if i + 1 < len(modules) and isinstance(modules[i + 1], nn.BatchNorm2d):
                bn = modules[i + 1]
                
                # Fold the BN statistics into the conv weights
                fused_conv = _fuse_conv_bn(module, bn)
                new_model.add_module(f'conv_{i}', fused_conv)
            else:
                new_model.add_module(f'conv_{i}', module)
        
        elif isinstance(module, nn.BatchNorm2d):
            continue  # already fused, skip
        else:
            new_model.add_module(str(i), module)
    
    return new_model

def _fuse_conv_bn(conv, bn):
    """Fuse a single Conv-BN pair (valid for inference, i.e. eval mode)"""
    # Fold the BN parameters into a per-channel scale and shift
    bn_std = torch.sqrt(bn.running_var + bn.eps)
    gamma = bn.weight / bn_std
    beta = bn.bias - bn.running_mean * gamma
    
    # Fold the scale into the conv weight, the shift into the bias
    fused_conv = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        conv.stride, conv.padding, conv.dilation, conv.groups, bias=True
    )
    
    fused_conv.weight.data = conv.weight.data * gamma.view(-1, 1, 1, 1)
    # If the conv has its own bias, fold it in as well
    if conv.bias is not None:
        fused_conv.bias.data = (conv.bias.data - bn.running_mean) * gamma + bn.bias
    else:
        fused_conv.bias.data = beta
    
    return fused_conv
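The fusion can be verified numerically: with BatchNorm in eval mode (running statistics frozen), the fused convolution must reproduce Conv → BN to within floating-point error. A self-contained check using the same folding formulas:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
bn.eval()  # fusion is only valid with fixed (running) statistics

# Give BN non-trivial parameters and statistics so the check is meaningful
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)

# Fold: W' = W * gamma/std,  b' = beta - running_mean * gamma/std
std = torch.sqrt(bn.running_var + bn.eps)
scale = bn.weight / std
fused = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
fused.weight.data = conv.weight.data * scale.view(-1, 1, 1, 1)
fused.bias.data = bn.bias - bn.running_mean * scale

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    diff = (fused(x) - bn(conv(x))).abs().max().item()
```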

5.2 A Unified Framework for Channel Pruning

class ChannelPruning:
    """A unified channel-pruning implementation"""
    
    def __init__(self, model, prune_ratio=0.5):
        self.model = model
        self.prune_ratio = prune_ratio
    
    def channel_importance_l1(self, layer):
        """Channel importance based on the L1 norm"""
        if isinstance(layer, nn.Conv2d):
            # Weight shape: (C_out, C_in, K, K)
            # L1 norm of each output channel: sum of |w| over C_in, K, K
            importance = layer.weight.data.abs().sum(dim=(1, 2, 3))
        return importance
    
    def channel_importance_taylor(self, layer, grad):
        """Channel importance based on a first-order Taylor expansion"""
        if isinstance(layer, nn.Conv2d):
            # |weight * gradient| summed per output channel
            importance = (layer.weight.data * grad[0]).abs().sum(dim=(1, 2, 3))
        return importance
    
    def prune(self, importance_fn='l1'):
        """Perform channel pruning (simplified: assumes a flat model)"""
        new_model = nn.Sequential()
        module_count = 0
        
        for name, module in self.model.named_children():
            if isinstance(module, nn.Conv2d):
                # Score the output channels
                if importance_fn == 'l1':
                    importance = self.channel_importance_l1(module)
                elif importance_fn == 'taylor':
                    # requires a forward/backward pass to obtain gradients
                    importance = self.channel_importance_taylor(module, ...)
                
                # Number of channels to keep
                num_keep = int(module.out_channels * (1 - self.prune_ratio))
                _, keep_indices = torch.topk(importance, num_keep)
                
                # Build the smaller conv layer
                new_conv = nn.Conv2d(
                    module.in_channels, num_keep,
                    module.kernel_size, module.stride, 
                    module.padding, module.dilation, module.groups
                )
                new_conv.weight.data = module.weight.data[keep_indices]
                if module.bias is not None:
                    new_conv.bias.data = module.bias.data[keep_indices]
                
                new_model.add_module(f'conv_{module_count}', new_conv)
            else:
                # Names must be unique, or add_module silently overwrites
                new_model.add_module(f'module_{module_count}', module)
            module_count += 1
        
        return new_model

6. Practical Guidelines

6.1 Choosing a Pruning Strategy

| Scenario | Recommended strategy | Compression |
| --- | --- | --- |
| Fast inference | Structured pruning | 2-4x |
| Extreme compression | Unstructured pruning + sparse storage | 10x+ |
| Pruning after fine-tuning | Iterative pruning | 5-10x |
| Resource-constrained | Layer pruning | Tunable |
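The "unstructured pruning + sparse storage" row only pays off if the zeroed weights are actually stored in a sparse layout; a sketch with PyTorch's COO sparse tensors:

```python
import torch

torch.manual_seed(0)
w = torch.randn(256, 256)

# Unstructured magnitude pruning at 90% sparsity
threshold = torch.quantile(w.abs(), 0.9)
w = w * (w.abs() > threshold).float()

# COO format stores only the ~10% surviving entries (plus their indices)
sparse_w = w.to_sparse()
density = sparse_w.values().numel() / w.numel()
```

Real speedups additionally require sparse compute kernels (or hardware support such as 2:4 structured sparsity); zeros left in a dense tensor still compute at full cost.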

6.2 The Train-Prune-Finetune Pipeline

def prune_and_finetune(model, train_loader, val_loader, 
                       prune_ratio=0.5, finetune_epochs=10):
    """
    The full prune-and-finetune pipeline.
    (train_model / finetune_model are external training helpers.)
    """
    # 1. Train the original model
    print("Step 1: Training original model...")
    model = train_model(model, train_loader, epochs=50)
    
    # 2. Prune
    print("Step 2: Pruning...")
    model = magnitude_pruning(model, sparsity=prune_ratio)
    
    # 3. Finetune to recover accuracy
    print("Step 3: Finetuning...")
    model = finetune_model(model, train_loader, val_loader, epochs=finetune_epochs)
    
    return model

6.3 Evaluation Metrics

| Metric | Meaning |
| --- | --- |
| Compression ratio | Reduction in parameter count |
| Speedup | Improvement in inference speed |
| Accuracy drop | Degradation in task performance |
| Sparsity | Fraction of zero-valued weights |
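The first and last rows of the table can be computed directly from the parameter tensors; a small helper (the name `sparsity_and_compression` is illustrative):

```python
import torch
import torch.nn as nn

def sparsity_and_compression(model):
    # Sparsity = fraction of zero-valued parameters;
    # compression ratio counts only the non-zeros that must be stored
    total = sum(p.numel() for p in model.parameters())
    zeros = sum((p == 0).sum().item() for p in model.parameters())
    sparsity = zeros / total
    compression = total / (total - zeros)
    return sparsity, compression

model = nn.Linear(10, 10, bias=False)
with torch.no_grad():
    model.weight[:5].zero_()  # zero half the rows to simulate pruning

sparsity, compression = sparsity_and_compression(model)
```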

7. References

Footnotes

  1. Frankle J, Carbin M. The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. ICLR, 2019.