# Model Pruning Techniques
Model pruning removes unimportant weights or neurons to shrink a model's parameter count and compute cost, and is one of the core techniques of model compression.
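To make "fewer parameters" concrete, sparsity can be checked by counting zero entries directly. A minimal sketch (the helper name `count_nonzero_params` and the toy MLP are ours):

```python
import torch
import torch.nn as nn

def count_nonzero_params(model: nn.Module):
    """Report total vs. nonzero parameters, i.e. the effect of pruning."""
    total = sum(p.numel() for p in model.parameters())
    nonzero = sum(int((p != 0).sum()) for p in model.parameters())
    return total, nonzero

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
total, nonzero = count_nonzero_params(model)
print(f"total={total}, nonzero={nonzero}, sparsity={1 - nonzero / total:.2%}")
```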
## 1. Pruning Taxonomy
```text
Model pruning
├── Unstructured pruning
│   └── Remove individual weights at arbitrary positions
├── Structured pruning
│   ├── Neuron pruning
│   ├── Kernel (filter) pruning
│   └── Layer pruning
└── Progressive pruning
    └── Remove weights iteratively over multiple rounds
```
## 2. Unstructured Pruning
### 2.1 Magnitude Pruning
The simplest pruning method: a weight's importance is judged by its absolute magnitude.

```python
import torch
import torch.nn as nn

def magnitude_pruning(model, sparsity=0.5):
    """
    Magnitude pruning.

    Args:
        model: the model to prune
        sparsity: fraction of weights to remove (0.5 = drop 50% of weights)
    """
    for name, param in model.named_parameters():
        if 'weight' in name:
            # Threshold = the `sparsity` quantile of the absolute weights
            threshold = torch.quantile(param.abs(), sparsity)
            # Binary mask keeping only large-magnitude weights
            mask = (param.abs() > threshold).float()
            # Zero out the pruned weights in place
            param.data = param.data * mask
    return model
```

### 2.2 Gradient-Magnitude Pruning
Here importance is judged by gradient magnitude instead:

```python
def gradient_magnitude_pruning(model, inputs, targets, sparsity=0.5):
    """
    Pruning based on gradient magnitude.
    """
    # One forward/backward pass to obtain gradients
    model.zero_grad()
    outputs = model(inputs)
    loss = nn.functional.cross_entropy(outputs, targets)
    loss.backward()
    # Record each parameter's gradient magnitudes
    grad_magnitudes = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_magnitudes[name] = param.grad.abs()
    # Mask out weights whose gradient magnitude falls below the quantile
    for name, param in model.named_parameters():
        if name in grad_magnitudes:
            threshold = torch.quantile(grad_magnitudes[name], sparsity)
            mask = (grad_magnitudes[name] > threshold).float()
            param.data = param.data * mask
    return model
```

## 3. Structured Pruning
### 3.1 Neuron Pruning
Remove entire neurons (whole weight rows):

```python
class NeuronPruning:
    """Neuron-level pruning."""

    @staticmethod
    def compute_neuron_importance(layer):
        """
        Compute per-neuron importance from weight L2 norms.
        """
        if isinstance(layer, nn.Linear):
            # Weight shape: (out_features, in_features);
            # each output neuron owns one row.
            importance = torch.norm(layer.weight.data, dim=1)
        elif isinstance(layer, nn.Conv2d):
            # For conv layers, each filter acts as a "neuron".
            # Weight shape: (out_channels, in_channels, kH, kW)
            importance = torch.norm(
                layer.weight.data.view(layer.out_channels, -1),
                dim=1
            )
        return importance

    @staticmethod
    def prune_neurons(layer, importance, threshold):
        """Drop neurons below the importance threshold (Linear layers only here)."""
        mask = (importance > threshold).float()
        if isinstance(layer, nn.Linear):
            # Keep the weight rows of the surviving neurons
            new_out_features = mask.sum().int().item()
            new_weight = layer.weight.data[mask.bool()].clone()
            new_bias = layer.bias.data[mask.bool()].clone() if layer.bias is not None else None
            # Build a smaller replacement layer
            new_layer = nn.Linear(layer.in_features, new_out_features)
            new_layer.weight.data = new_weight
            if new_bias is not None:
                new_layer.bias.data = new_bias
            return new_layer, mask
```
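A short usage sketch for a single linear layer (the 50% ratio and the layer sizes are arbitrary choices of ours):

```python
# Hypothetical usage: prune half the output neurons of one Linear layer.
layer = nn.Linear(128, 64)
importance = NeuronPruning.compute_neuron_importance(layer)
threshold = torch.quantile(importance, 0.5)   # drop the weakest 50%
new_layer, mask = NeuronPruning.prune_neurons(layer, importance, threshold)
print(layer.out_features, "->", new_layer.out_features)  # 64 -> 32
```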
### 3.2 Kernel (Filter) Pruning

Remove entire convolution kernels:

```python
def prune_conv_kernels(model, sparsity=0.3):
    """
    Kernel pruning: remove the least important filters of each conv layer.
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # Importance of each filter = L2 norm of its weights
            kernel_importance = torch.norm(
                module.weight.data.view(module.out_channels, -1),
                dim=1
            )
            # How many filters to keep
            num_keep = int(module.out_channels * (1 - sparsity))
            # Keep the most important ones
            _, keep_indices = torch.topk(kernel_importance, num_keep)
            keep_mask = torch.zeros(module.out_channels).scatter_(
                0, keep_indices, 1
            ).bool()
            # Slice the weights (and bias) down to the kept filters
            module.weight.data = module.weight.data[keep_mask]
            if module.bias is not None:
                module.bias.data = module.bias.data[keep_mask]
            # Note: the next layer's in_channels must be updated to match;
            # see the sketch below.
    return model
```
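The closing note matters in practice: once a conv's output channels are removed, the following conv's `in_channels` must shrink to match. A minimal sketch of that adjustment (the helper name `shrink_next_conv` is ours; it assumes `groups=1` and reuses the `keep_mask` from above):

```python
def shrink_next_conv(next_conv: nn.Conv2d, keep_mask: torch.Tensor) -> nn.Conv2d:
    """Rebuild the following conv so its in_channels match a pruned predecessor."""
    new_conv = nn.Conv2d(
        int(keep_mask.sum()), next_conv.out_channels,
        next_conv.kernel_size, next_conv.stride,
        next_conv.padding, next_conv.dilation,
        bias=next_conv.bias is not None
    )
    # Weight shape (C_out, C_in, kH, kW): slice the input-channel dimension
    new_conv.weight.data = next_conv.weight.data[:, keep_mask].clone()
    if next_conv.bias is not None:
        new_conv.bias.data = next_conv.bias.data.clone()
    return new_conv
```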
### 3.3 Layer-Level Pruning

```python
class LayerPruning:
"""层级剪枝"""
def __init__(self, model):
self.model = model
def compute_layer_importance(self, dataloader, device='cuda'):
"""
基于验证损失计算每层的重要性
使用泰勒展开近似移除该层对损失的影响
"""
importance = {}
# 收集每层的梯度
hooks = []
for name, module in self.model.named_modules():
if len(list(module.children())) == 0: # 叶子模块
handle = module.register_backward_hook(
lambda m, g_in, g_out: self._hook_fn(name, g_out, importance)
)
hooks.append(handle)
# 一次前向反向
batch = next(iter(dataloader))
inputs, targets = batch[0].to(device), batch[1].to(device)
self.model.zero_grad()
outputs = self.model(inputs)
loss = nn.functional.cross_entropy(outputs, targets)
loss.backward()
# 移除hooks
for h in hooks:
h.remove()
return importance
@staticmethod
def _hook_fn(name, grad_out, importance):
"""梯度钩子函数"""
if grad_out[0] is not None:
importance[name] = (grad_out[0].abs()).mean()
def prune_layers(self, importance, sparsity):
"""
剪枝不重要的层
"""
# 按重要性排序
sorted_layers = sorted(importance.items(), key=lambda x: x[1])
# 移除最不重要的层
num_prune = int(len(sorted_layers) * sparsity)
prune_names = set([name for name, _ in sorted_layers[:num_prune]])
# 构建新模型(简化实现)
new_model = nn.Sequential()
for name, module in self.model.named_children():
if name not in prune_names:
new_model.add_module(name, module)
        return new_model
```

## 4. Progressive Pruning
### 4.1 The Lottery Ticket Hypothesis
The lottery ticket hypothesis[^1]: a dense network contains a sparse subnetwork that, trained from scratch (from the same initialization), can reach the same performance.

```python
def lottery_ticket_pruning(model, train_loader, test_loader,
                           sparsity=0.9, iterations=5, lr=0.01):
    """
    Lottery-ticket-style iterative pruning:
    1. Train the model to convergence
    2. Prune weights
    3. Reset the surviving weights to their initial values
    4. Repeat
    """
    # Save the initial weights
    original_weights = {}
    for name, param in model.named_parameters():
        original_weights[name] = param.data.clone()
    current_sparsity = 0
    for iteration in range(iterations):
        # Target sparsity for this round
        target_sparsity = ((iteration + 1) / iterations) * sparsity
        # Train
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        for epoch in range(10):
            for batch in train_loader:
                inputs, targets = batch
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = nn.functional.cross_entropy(outputs, targets)
                loss.backward()
                optimizer.step()
        # Report accuracy (assumes an `evaluate` helper; see the sketch below)
        accuracy = evaluate(model, test_loader)
        print(f"Iteration {iteration}: Sparsity={target_sparsity:.2%}, Accuracy={accuracy:.2%}")
        # Prune
        if target_sparsity > current_sparsity:
            model = magnitude_pruning(model, target_sparsity)
            current_sparsity = target_sparsity
        # Rewind: take the mask from the just-pruned weights, then reset
        # the survivors to their original initialization
        for name, param in model.named_parameters():
            mask = (param.data != 0).float()
            param.data = original_weights[name] * mask
    return model
```
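The loop above calls an `evaluate` helper that this section does not define; a minimal top-1-accuracy sketch of what it assumes:

```python
@torch.no_grad()
def evaluate(model, dataloader):
    """Plain top-1 accuracy over a dataloader (our assumed helper)."""
    model.eval()
    correct, total = 0, 0
    for inputs, targets in dataloader:
        preds = model(inputs).argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += targets.numel()
    model.train()
    return correct / total
```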
### 4.2 Progressive Pruning Schedules

```python
import numpy as np

class ProgressivePruning:
"""渐进式剪枝调度器"""
def __init__(self, initial_sparsity=0.0, final_sparsity=0.9,
total_steps=10000, schedule='cubic'):
self.initial_sparsity = initial_sparsity
self.final_sparsity = final_sparsity
self.total_steps = total_steps
self.schedule = schedule
def get_sparsity(self, step):
"""
获取当前步的稀疏度
"""
progress = step / self.total_steps
if self.schedule == 'linear':
return self.initial_sparsity + (self.final_sparsity - self.initial_sparsity) * progress
elif self.schedule == 'cubic':
return self.initial_sparsity + (self.final_sparsity - self.initial_sparsity) * (progress ** 3)
elif self.schedule == 'exponential':
return self.final_sparsity * (1 - (1 - progress) ** 3)
elif self.schedule == 'sinusoidal':
return self.final_sparsity * (1 - (1 + np.cos(np.pi * progress)) / 2)
else:
return self.final_sparsity * progress
def apply_progressive_pruning(self, model, step):
"""应用渐进式剪枝"""
sparsity = self.get_sparsity(step)
# 为每层计算阈值
for name, param in model.named_parameters():
if 'weight' in name:
threshold = torch.quantile(param.abs(), sparsity)
mask = (param.abs() > threshold).float()
param.data = param.data * mask
        return model
```
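How the scheduler might sit inside a training loop; a sketch where `model` and `train_loader` are assumed to exist and the every-100-steps cadence is an arbitrary choice:

```python
scheduler = ProgressivePruning(final_sparsity=0.9, total_steps=10000, schedule='cubic')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for step, (inputs, targets) in enumerate(train_loader):
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:  # re-apply the mask at the scheduled sparsity
        scheduler.apply_progressive_pruning(model, step)
```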
## 5. Neuron Pruning and Equivalent Network Transforms

### 5.1 Conv-BatchNorm Fusion

```python
def fuse_conv_bn(model):
    """
    Fuse Conv -> BN pairs into single conv layers.
    """
    new_model = nn.Sequential()
    # Flat list of leaf modules (assumes a sequential architecture; iterating
    # model.modules() directly would also yield the containers themselves)
    modules = [m for m in model.modules() if len(list(m.children())) == 0]
    for i, module in enumerate(modules):
        if isinstance(module, nn.Conv2d):
            # Is the next module a BatchNorm?
            if i + 1 < len(modules) and isinstance(modules[i + 1], nn.BatchNorm2d):
                bn = modules[i + 1]
                # Fold the BN statistics into the conv
                fused_conv = _fuse_conv_bn(module, bn)
                new_model.add_module(f'conv_{i}', fused_conv)
            else:
                new_model.add_module(f'conv_{i}', module)
        elif isinstance(module, nn.BatchNorm2d):
            continue  # already fused, skip
        else:
            new_model.add_module(str(i), module)
    return new_model

def _fuse_conv_bn(conv, bn):
    """Fuse a single Conv-BN pair."""
    # BN folding coefficients
    bn_std = torch.sqrt(bn.running_var + bn.eps)
    gamma = bn.weight.data / bn_std
    beta = bn.bias.data - bn.running_mean * gamma
    # Fold into the conv weight and bias
    fused_conv = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        conv.stride, conv.padding, conv.dilation, conv.groups, bias=True
    )
    fused_conv.weight.data = conv.weight.data * gamma.view(-1, 1, 1, 1)
    # Fold the original conv bias too, if there is one
    conv_bias = conv.bias.data if conv.bias is not None else 0.0
    fused_conv.bias.data = beta + gamma * conv_bias
    return fused_conv
```
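A quick numerical check that the fusion is output-preserving in eval mode (our sketch; the shapes are arbitrary):

```python
conv = nn.Conv2d(3, 8, 3, padding=1)
bn = nn.BatchNorm2d(8).eval()   # fusion relies on the running statistics
x = torch.randn(1, 3, 16, 16)
fused = _fuse_conv_bn(conv, bn)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # expect True
```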
### 5.2 A Unified Channel-Pruning Framework

```python
class ChannelPruning:
"""通道剪枝的统一实现"""
def __init__(self, model, prune_ratio=0.5):
self.model = model
self.prune_ratio = prune_ratio
def channel_importance_l1(self, layer):
"""基于L1范数的通道重要性"""
if isinstance(layer, nn.Conv2d):
# 权重形状: (C_out, C_in, K, K)
# 计算每个输出通道的L1范数
importance = torch.norm(
layer.weight.data,
dim=(1, 2, 3) # 跨C_in, K, K
)
return importance
def channel_importance_taylor(self, layer, grad):
"""基于泰勒展开的重要性"""
if isinstance(layer, nn.Conv2d):
# 权重 * 梯度 的绝对值作为重要性
importance = (layer.weight.data * grad[0]).abs().sum(dim=(1, 2, 3))
return importance
def prune(self, importance_fn='l1'):
"""执行通道剪枝"""
new_model = nn.Sequential()
prune_count = 0
for name, module in self.model.named_modules():
if isinstance(module, nn.Conv2d):
# 计算重要性
if importance_fn == 'l1':
importance = self.channel_importance_l1(module)
elif importance_fn == 'taylor':
# 需要前向反向计算梯度
importance = self.channel_importance_taylor(module, ...)
# 确定保留的通道数
num_keep = int(module.out_channels * (1 - self.prune_ratio))
_, keep_indices = torch.topk(importance, num_keep)
# 创建新卷积层
new_conv = nn.Conv2d(
module.in_channels, num_keep,
module.kernel_size, module.stride,
module.padding, module.dilation, module.groups
)
new_conv.weight.data = module.weight.data[keep_indices]
if module.bias is not None:
new_conv.bias.data = module.bias.data[keep_indices]
new_model.add_module(f'conv_{prune_count}', new_conv)
prune_count += 1
elif isinstance(module, nn.Linear):
new_model.add_module(str(prune_count), module)
else:
new_model.add_module(str(prune_count), module)
        return new_model
```
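A toy usage sketch (the model is ours; note that the second conv's `in_channels` still has to be adjusted afterwards, exactly as in the note in Section 3.2):

```python
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
pruned = ChannelPruning(model, prune_ratio=0.5).prune(importance_fn='l1')
print([m.out_channels for m in pruned if isinstance(m, nn.Conv2d)])  # [8, 16]
```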
## 6. Practical Guide

### 6.1 Choosing a Pruning Strategy
| Scenario | Recommended strategy | Compression |
|---|---|---|
| Fast inference | Structured pruning | 2-4x |
| Maximum compression | Unstructured pruning + sparse storage | 10x+ |
| Pruning a finetuned model | Progressive pruning | 5-10x |
| Tight resource budgets | Layer pruning | tunable |
### 6.2 The Train-Prune-Finetune Pipeline

```python
def prune_and_finetune(model, train_loader, val_loader,
                       prune_ratio=0.5, finetune_epochs=10):
    """
    Full prune-and-finetune pipeline. Assumes `train_model` and
    `finetune_model` training helpers are defined elsewhere.
    """
    # 1. Train the dense model
    print("Step 1: Training original model...")
    model = train_model(model, train_loader, epochs=50)
    # 2. Prune
    print("Step 2: Pruning...")
    model = magnitude_pruning(model, sparsity=prune_ratio)
    # 3. Finetune to recover accuracy
    print("Step 3: Finetuning...")
    model = finetune_model(model, train_loader, val_loader, epochs=finetune_epochs)
    return model
```

### 6.3 Evaluation Metrics
| Metric | Meaning |
|---|---|
| Compression ratio | Reduction in parameter count |
| Speedup | Improvement in inference latency |
| Accuracy drop | Loss in task performance |
| Sparsity | Fraction of zero-valued weights |
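Compression ratio and sparsity fall straight out of parameter counts; a minimal sketch (the helper name `pruning_report` is ours; speedup and accuracy drop require actual timing and evaluation runs):

```python
def pruning_report(dense_model: nn.Module, pruned_model: nn.Module):
    """Compression ratio and sparsity from parameter counts (our sketch)."""
    dense = sum(p.numel() for p in dense_model.parameters())
    nonzero = sum(int((p != 0).sum()) for p in pruned_model.parameters())
    print(f"compression: {dense / max(nonzero, 1):.1f}x, "
          f"sparsity: {1 - nonzero / dense:.2%}")
```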
## 7. References
[^1]: Frankle, J., and Carbin, M. The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. ICLR, 2019.