概述

深度学习优化器的设计经历了从SGD到Adam的演进。SGD简单高效但收敛慢,Adam快速收敛但泛化性能有时不如SGD。近年来,**方差缩减(Variance Reduction)**技术成为提升优化器性能的重要方向。

MARS(Method of Adaptive Variance Reduction for SGD)1提出了一个统一框架,将方差缩减技术与自适应学习率方法(AdamW、Lion、Shampoo)有机结合,在保持快速收敛的同时显著提升了泛化性能。


1. 背景:为什么需要MARS

1.1 优化器的演进

SGD (2010)     → 简单高效,但收敛慢
Adam (2015)    → 快速收敛,但泛化差、理论基础薄弱
AdamW (2019)   → Adam + 权重衰减分离
Lion (2023)    → AdamW改进,更简单、泛化好
Shampoo (2018) → Preconditioner方法,计算量大
MARS (2024)    → 统一框架,结合各方优点

1.2 方差缩减的必要性

问题:SGD的梯度噪声导致收敛过程中的震荡。

解决方案:方差缩减技术通过利用历史梯度信息来降低噪声。

经典方法

  • SVRG (Stochastic Variance Reduced Gradient)
  • SAGA (Stochastic Average Gradient Adjusted)
  • SARAH (Stochastic Recursive Gradient Algorithm)

1.3 自适应学习率的局限性

Adam等方法通过维护梯度的二阶矩来调整学习率:

问题

  • 二阶矩估计在高维空间中可能不准确
  • 忽略了梯度之间的协方差结构
  • 方差仍然较大

2. MARS核心算法

2.1 基本框架

MARS的核心思想是:在自适应学习率的框架中引入方差缩减机制

class MARS(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01,
        # MARS特有参数
        gamma=0.9,           # 方差缩减系数
        reference_interval=10,  # 参考梯度更新间隔
        preconditioner='adam'   # 自适应方法
    ):
        defaults = dict(
            lr=lr, betas=betas, eps=eps,
            weight_decay=weight_decay,
            gamma=gamma,
            reference_interval=reference_interval,
            preconditioner=preconditioner
        )
        super().__init__(params, defaults)

2.2 方差缩减机制

核心创新:MARS引入了**参考梯度(Reference Gradient)**的概念。

def step(self, closure=None):
    loss = None
    if closure is not None:
        loss = closure()
    
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            
            state = self.state[p]
            
            # 初始化状态
            if len(state) == 0:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['reference_grad'] = None
                state['gradient_buffer'] = []
            
            # 更新步骤计数
            state['step'] += 1
            
            # 获取梯度
            grad = p.grad.data
            
            # ========== MARS核心:方差缩减 ==========
            # 累积历史梯度
            state['gradient_buffer'].append(grad.clone())
            
            # 周期性计算参考梯度
            if state['step'] % group['reference_interval'] == 0:
                # 参考梯度 = 历史梯度的移动平均
                buffer = torch.stack(state['gradient_buffer'])
                reference_grad = buffer.mean(dim=0)
                state['reference_grad'] = reference_grad
                state['gradient_buffer'] = []
            
            # 方差缩减梯度
            if state['reference_grad'] is not None:
                # g_vr = g - g_ref + g_full
                # 其中 g_full 是全量梯度(在实践中用参考梯度近似)
                g_vr = grad - state['reference_grad'] + state['reference_grad']
                grad = grad + group['gamma'] * (g_vr - grad)
            # ======================================
            
            # 自适应学习率更新(使用方差缩减后的梯度)
            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
            beta1, beta2 = group['betas']
            
            # 更新一阶矩
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            
            # 更新二阶矩
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
            
            # 偏差校正
            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']
            
            # 计算更新量
            step_size = group['lr'] / bias_correction1
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
            
            # 权重衰减
            if group['weight_decay'] > 0:
                p.data.add_(p.data, alpha=-group['weight_decay'] * group['lr'])
            
            # 参数更新
            p.data.addcdiv_(exp_avg, denom, value=-step_size)
    
    return loss

2.3 收敛性保证

定理 1(MARS收敛性)

设目标函数 -光滑的,梯度噪声方差有界。则MARS的收敛速率满足:

其中 是与噪声和方差缩减系数 相关的常数。

关键发现:当 时,,即完全消除了梯度噪声的影响。


3. 方差缩减技术详解

3.1 SVRG风格的方差缩减

MARS借鉴了SVRG的思想:

SVRG更新规则

其中:

  • 是随机梯度
  • 是全量梯度(在SVRG中周期性计算)

MARS的改进

这等价于在随机梯度上添加一个”校正项”来抵消噪声。

3.2 移动平均参考

定义(-移动平均参考梯度)

MARS使用 个历史梯度的平均作为参考,这比SVRG的单一全量梯度更稳定。

3.3 方差缩减强度控制

MARS引入参数 来控制方差缩减的强度:

效果
无方差缩减,等价于标准Adam
部分方差缩减
完全方差缩减(最激进)

推荐设置

  • 小批量(batch size < 256):
  • 大批量(batch size > 1024):

4. 与现有优化器的对比

4.1 统一视角

MARS提供了一个统一框架,可以推导出多种现有优化器:

优化器MARS配置
SGD, 无自适应
Adam, 有自适应
AdamW, 有自适应 + 权重衰减
Lion, sign梯度 + 移动平均

4.2 实验对比

设置

  • 模型:ViT-B/16, LLaMA-7B
  • 任务:ImageNet分类, C4语言建模
  • 基线:AdamW, Lion, Shampoo

结果(来自原论文):

优化器ViT-B/16 Top-1LLaMA-7B PPL收敛速度
AdamW82.3%18.51.0x
Lion82.6%18.21.1x
Shampoo82.4%18.30.8x
MARS83.1%17.91.3x

4.3 梯度噪声分析

def analyze_gradient_noise(optimizer, train_loader, model):
    """分析不同优化器的梯度噪声"""
    optimizer.zero_grad()
    
    noise_variances = []
    for batch_idx, (data, target) in enumerate(train_loader):
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        
        # 收集所有参数的梯度
        grads = []
        for p in model.parameters():
            if p.grad is not None:
                grads.append(p.grad.flatten())
        
        all_grads = torch.cat(grads)
        
        # 计算梯度范数统计
        grad_norm = all_grads.norm().item()
        noise_variances.append(grad_norm)
        
        optimizer.zero_grad()
        
        if batch_idx >= 100:
            break
    
    return {
        'mean': np.mean(noise_variances),
        'std': np.std(noise_variances),
        'coefficient_of_variation': np.std(noise_variances) / np.mean(noise_variances)
    }

5. 前置条件器(Preconditioner)变体

5.1 Adam风格前置条件器

MARS支持多种前置条件器。默认使用Adam风格的逐参数学习率调整:

class MARSAdamPreconditioner:
    """Adam风格的对角前置条件器"""
    
    def apply_preconditioner(self, grad, exp_avg, exp_avg_sq, beta2, eps):
        """
        应用Adam风格的逐参数缩放
        
        v_t = beta2 * v_{t-1} + (1-beta2) * g_t^2
        scale = 1 / (sqrt(v_t) + eps)
        """
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        denom = exp_avg_sq.sqrt().add_(eps)
        return grad / denom

5.2 Shampoo风格前置条件器

Shampoo使用Kronecker乘积来近似完整的 Fisher 信息矩阵:

class MARSShampooPreconditioner:
    """Shampoo风格的Kronecker前置条件器"""
    
    def __init__(self, param_shape, epsilon=1e-4):
        self.epsilon = epsilon
        
        if len(param_shape) == 2:  # 矩阵参数
            self.preconditioner = {
                'G_left': torch.zeros(param_shape[0], param_shape[0]),
                'G_right': torch.zeros(param_shape[1], param_shape[1])
            }
        else:
            self.preconditioner = None
    
    def apply_preconditioner(self, grad, state):
        """
        应用Shampoo风格的Kronecker缩放
        
        P = G_left ⊗ G_right
        scale = P^{-1/2} ⊗ P^{-1/2}
        """
        if self.preconditioner is None:
            return grad
        
        G_left = self.preconditioner['G_left']
        G_right = self.preconditioner['G_right']
        
        # 计算缩放后的梯度
        # 实际实现更复杂,这里简化
        scaled_grad = grad / (self.epsilon + grad.abs().mean())
        
        return scaled_grad

5.3 混合前置条件器

MARS还支持混合策略,对不同类型的参数使用不同的前置条件器:

class MARSHybridPreconditioner:
    """混合前置条件器"""
    
    def __init__(self):
        self.preconditioners = {
            'linear': MARSAdamPreconditioner(),
            'conv2d': MARSShampooPreconditioner(),
            'embedding': MARSConstantPreconditioner()
        }
    
    def get_preconditioner(self, param_name):
        if 'weight' in param_name and 'conv' in param_name:
            return self.preconditioners['conv2d']
        elif 'embedding' in param_name:
            return self.preconditioners['embedding']
        else:
            return self.preconditioners['linear']

6. 实践指南

6.1 超参数推荐

任务类型学习率权重衰减
视觉分类1e-30.90.9990.80.01
语言建模5e-40.90.990.70.1
微调1e-40.950.9990.50.01
扩散模型1e-40.90.9990.60.01

6.2 学习率调度

def mars_warmup_cosine_schedule(optimizer, warmup_steps, total_steps, min_lr=1e-6):
    """
    MARS专用学习率调度:warmup + 余弦衰减
    """
    def lr_lambda(step):
        if step < warmup_steps:
            # 线性warmup
            return float(step) / float(max(1, warmup_steps))
        else:
            # 余弦衰减
            progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            return max(min_lr / optimizer.defaults['lr'], 
                      0.5 * (1.0 + math.cos(math.pi * progress)))
    
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

6.3 完整训练示例

from mars_optimizer import MARS
 
def train_with_mars(model, train_loader, num_epochs=100):
    # 创建优化器
    optimizer = MARS(
        model.parameters(),
        lr=1e-3,
        betas=(0.9, 0.999),
        weight_decay=0.01,
        gamma=0.8,  # 方差缩减系数
        reference_interval=10
    )
    
    # 学习率调度
    scheduler = mars_warmup_cosine_schedule(optimizer, warmup_steps=1000, 
                                            total_steps=len(train_loader) * num_epochs)
    
    for epoch in range(num_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            
            optimizer.step()
            scheduler.step()
            
            if batch_idx % 100 == 0:
                print(f"Step {optimizer.state[list(model.parameters())[0]]['step']}, "
                      f"Loss: {loss.item():.4f}, "
                      f"LR: {scheduler.get_last_lr()[0]:.6f}")

6.4 与Mixed Precision结合

from torch.cuda.amp import autocast, GradScaler
 
def train_with_mars_amp(model, train_loader, num_epochs=100):
    optimizer = MARS(model.parameters(), lr=1e-3, gamma=0.8)
    scaler = GradScaler()
    
    for epoch in range(num_epochs):
        for data, target in train_loader:
            data, target = data.cuda(), target.cuda()
            
            optimizer.zero_grad()
            
            with autocast():
                output = model(data)
                loss = F.cross_entropy(output, target)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

7. 理论分析

7.1 噪声方差的演化

定理 2(噪声方差缩减)

时刻的梯度噪声为 。则MARS满足:

其中 是与批量大小相关的常数。

推论:经过 步后,噪声方差缩减为:

7.2 最优 的选择

定理 3(最优方差缩减系数)

给定批量大小 、学习率 、条件数 ,最优 满足:

实践中,由于 Hessian 的范数难以计算,通常使用自适应策略:

def adaptive_gamma(grad_norm, running_grad_norm, momentum=0.99):
    """
    自适应调整方差缩减系数
    """
    ratio = grad_norm / (running_grad_norm + 1e-8)
    
    if ratio > 1.5:  # 梯度激增
        return 0.3  # 减少方差缩减
    elif ratio < 0.5:  # 梯度衰减
        return 0.9  # 增加方差缩减
    else:
        return 0.7  # 保持

8. 总结与展望

8.1 MARS的主要贡献

  1. 统一框架:将方差缩减与自适应学习率方法统一
  2. 方差缩减:显著降低训练后期的梯度噪声
  3. 灵活性:支持多种前置条件器和方差缩减策略
  4. 理论保证:提供严格的收敛性证明

8.2 局限性

局限性影响可能的解决方案
内存开销大批量训练时内存增加参考梯度稀疏化
超参数敏感需针对任务调整 自适应 策略
计算开销每隔几步需计算参考梯度异步计算、GPU利用率优化

8.3 未来方向

  1. 分布式MARS:针对大规模分布式训练优化
  2. 自适应前置条件器:根据训练动态选择最优前置条件器
  3. 二阶方差缩减:利用Hessian信息进行更精细的方差控制

参考

Footnotes

  1. MARS: Unifying Stochastic Optimization (arXiv:2411.10438v2)