梯度检查点技术

梯度检查点(Gradient Checkpointing,也称激活重计算)是训练大型深度学习模型的关键内存优化技术。本专题详细分析其原理、实现与实践。

1. 问题背景

1.1 内存瓶颈

训练深度神经网络时,反向传播需要存储前向传播的中间激活值:

参数量 ~ 10B → 需要 ~40GB (fp32)
中间激活值 → 可能需要 ~100GB+

典型Transformer的内存分布

组件内存占用比例
模型参数40GB~25%
梯度40GB~25%
优化器状态 (Adam)80GB~50%
中间激活值可变变化大

1.2 内存-计算权衡

全存储策略

  • 内存:
  • 反向计算:
  • 其中 是层数

全重计算策略

  • 内存:
  • 反向计算:
  • 重新计算所有激活值

选择性检查点(最优)

  • 内存:
  • 计算:
  • 通过智能选择实现最优权衡

2. 核心原理

2.1 检查点思想

# 无检查点
def forward(self, x):
    h1 = self.layer1(x)      # 存储
    h2 = self.layer2(h1)     # 存储
    h3 = self.layer3(h2)     # 存储
    h4 = self.layer4(h3)     # 存储
    h5 = self.layer5(h4)     # 存储
    return h5
 
# 有检查点
def forward(self, x):
    h1 = self.layer1(x)
    h1_c = checkpoint(h1)   # 保存检查点
    h2 = self.layer2(h1_c)
    h3 = self.layer3(h2)
    h4 = self.layer4(h3)
    h5 = self.layer5(h4_c)   # 反向时需要重计算
    return h5

2.2 反向传播的内存需求

反向传播需要:

  1. 前向传播的所有中间激活值
  2. 每层的梯度

其中 是第 层的激活值内存。

2.3 检查点策略

策略:只保存部分激活值,反向时重新计算其余值。

# 保存间隔为 d 的检查点
checkpoints = [x_0, x_d, x_2d, ..., x_L]
# 反向时重新计算: x_i = recompute(x_{i-d})

最优间隔

对于 层模型和可用内存

其中 是单层激活值的内存大小。

3. PyTorch实现

3.1 基本用法

import torch
from torch.utils.checkpoint import checkpoint
 
class ModelWithCheckpointing(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            torch.nn.Linear(512, 512) for _ in range(20)
        ])
    
    def forward(self, x):
        for i, layer in enumerate(self.layers):
            if i % 4 == 0:  # 每4层保存检查点
                x = checkpoint(layer, x, use_reentrant=True)
            else:
                x = layer(x)
        return x

3.2 模块级检查点

from torch.utils.checkpoint import checkpoint_sequential
 
class DeepModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.module_list = torch.nn.Sequential(*[
            ResidualBlock() for _ in range(50)
        ])
    
    def forward(self, x):
        # 将整个Sequential包装为检查点
        return checkpoint_sequential(
            self.module_list,
            num_segments=10,  # 分成10段
            input=x
        )

3.3 自定义检查点函数

from torch.utils.checkpoint import create_checkpoint_fn
 
# 获取检查点工厂函数
checkpoint_fn = create_checkpoint_fn()
 
# 自定义检查点逻辑
class CustomModel(torch.nn.Module):
    def forward(self, x):
        # 第一段
        x = self.layers[:10](x)
        x = checkpoint_fn(
            lambda y: self.middle_layers(y),
            x,
            use_reentrant=False,  # 推荐用于精确的梯度计算
            preserve_rng_state=True
        )
        # 第二段
        x = self.layers[20:](x)
        return x

3.4 use_reentrant参数

# use_reentrant=True (默认)
# - 更快但梯度可能不精确
# - 不支持某些操作(如原地操作)
# - 适用于大多数情况
 
# use_reentrant=False
# - 更精确的梯度计算
# - 支持原地操作
# - 需要更多内存
x = checkpoint(func, x, use_reentrant=False)

4. 高级技术

4.1 选择性检查点策略

class SelectiveCheckpointing(torch.nn.Module):
    """
    根据层类型选择性地应用检查点
    """
    def __init__(self):
        super().__init__()
        self.attention_layers = torch.nn.ModuleList([...])
        self.ffn_layers = torch.nn.ModuleList([...])
    
    def forward(self, x):
        for attn, ffn in zip(self.attention_layers, self.ffn_layers):
            # Attention层:计算密集,检查点化
            x = checkpoint(attn, x, use_reentrant=False)
            # FFN层:内存密集,全存储
            x = ffn(x)
        return x

4.2 CPU卸载

from torch.utils.checkpoint import checkpoint
 
def cpu_offload_checkpoint(module, *inputs):
    """将检查点卸载到CPU"""
    def _run_on_cpu(module, *args):
        # 将输入移到CPU
        args_cpu = [a.cpu() for a in args]
        
        # 在CPU上执行
        with torch.no_grad():
            outputs = module(*args_cpu)
        
        # 移回GPU
        if isinstance(outputs, torch.Tensor):
            return outputs.cuda()
        elif isinstance(outputs, tuple):
            return tuple(o.cuda() for o in outputs)
        
        return outputs
    
    return checkpoint(
        _run_on_cpu,
        module,
        *inputs,
        use_reentrant=False
    )

4.3 分块检查点

def chunked_checkpoint(module, x, chunk_size=32):
    """
    分块处理长序列的检查点
    """
    seq_len = x.shape[1]
    outputs = []
    
    for i in range(0, seq_len, chunk_size):
        chunk = x[:, i:i+chunk_size]
        chunk_out = checkpoint(module, chunk, use_reentrant=False)
        outputs.append(chunk_out)
    
    return torch.cat(outputs, dim=1)

4.4 与其他优化的结合

# 与混合精度训练结合
from torch.cuda.amp import autocast, GradScaler
 
class MixedPrecisionCheckpointModel(torch.nn.Module):
    def forward(self, x):
        with autocast():
            x = checkpoint(self.encoder, x)
            x = self.decoder(x)
        return x
 
# 与DistributedDataParallel结合
model = DistributedDataParallel(
    CheckpointedModel(),
    device_ids=[local_rank]
)

5. 内存-计算分析

5.1 理论分析

对于 层网络,层 的激活值大小为

全存储

选择性检查点(间隔

最优间隔

5.2 实际测量

import torch
import gc
 
def measure_memory(model, input_tensor):
    """测量模型前向传播的内存使用"""
    torch.cuda.empty_cache()
    gc.collect()
    
    torch.cuda.reset_peak_memory_stats()
    
    output = model(input_tensor)
    
    peak_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB
    return peak_memory
 
# 测试不同检查点策略
for num_checkpoints in [1, 5, 10, 20]:
    model = CheckpointedModel(num_segments=num_checkpoints)
    memory = measure_memory(model, input_tensor)
    print(f"检查点数: {num_checkpoints}, 内存: {memory:.2f}GB")

5.3 典型收益

模型无检查点有检查点节省
ResNet-1527.2GB3.1GB~57%
ViT-B/1612.5GB5.8GB~54%
GPT-2 (774M)38GB16GB~58%
LLaMA-7B56GB24GB~57%

6. 实践指南

6.1 检查点位置选择

推荐策略

  1. 计算密集层优先:Attention层、卷积层
  2. 大内存层优先:全连接层、大嵌入表
  3. 均匀分布:总能找到接近最优的策略
# 自动选择最优检查点数
def find_optimal_checkpointing(model, input_tensor, target_memory_gb=10):
    """二分搜索最优检查点数量"""
    num_layers = count_parameters(model)
    
    low, high = 1, num_layers
    best = num_layers
    
    while low <= high:
        mid = (low + high) // 2
        memory = measure_memory(model, mid, input_tensor)
        
        if memory <= target_memory_gb:
            best = mid
            high = mid - 1
        else:
            low = mid + 1
    
    return best

6.2 常见问题与解决方案

问题1:梯度不正确

# 解决方案:使用 use_reentrant=False
x = checkpoint(func, x, use_reentrant=False)

问题2:原地操作导致错误

# 避免原地操作
x = x + self.bias  # 替代 x.add_(self.bias)

问题3:RNG状态不一致

# 使用 preserve_rng_state
x = checkpoint(func, x, preserve_rng_state=True)

6.3 性能调优

# 1. 使用非重新进入模式(更精确)
checkpoint(fn, *args, use_reentrant=False)
 
# 2. 减少检查点数量(减少重计算)
# 3. 使用梯度累积(增大有效batch size)
# 4. 使用CPU卸载处理超长序列

7. 框架支持

7.1 PyTorch

from torch.utils.checkpoint import checkpoint, checkpoint_sequential
 
# 单函数检查点
out = checkpoint(my_func, input)
 
# 序列检查点
out = checkpoint_sequential(my_seq, num_segments=4, input)

7.2 JAX

from jax.checkpoint import checkpoint as jax_checkpoint
 
# 使用remat装饰器
@partial(jax_checkpoint, policy=jax.checkpoint_policies.everyTHING)
def model(x):
    return layers(x)

7.3 TensorFlow

import tensorflow as tf
 
# 使用GradientTape的检查点功能
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        # 使用tf.recompute_grad
        output = tf.recompute_grad(model)(x)
        loss = compute_loss(output, y)
    
    grads = tape.gradient(loss, model.trainable_variables)
    return grads

8. 与其他技术的比较

8.1 CPU Offloading

特性梯度检查点CPU卸载
内存节省2-5x5-10x
计算开销~30%~100%+
实现复杂度
适用场景通用超长序列

8.2 量化

特性梯度检查点量化
内存节省2-5x4-8x
精度损失可变
计算开销~30%最小
兼容性独立使用可叠加

9. 总结

9.1 核心要点

  1. 核心思想:用计算换内存
  2. 最优策略:选择性检查点
  3. PyTorch APIcheckpointcheckpoint_sequential
  4. 关键参数use_reentrant=False 推荐
  5. 典型收益:50%+ 内存节省,~30% 计算开销

9.2 相关专题

参考资料