1. 研究背景

1.1 Muon优化器简介

Muon是由Nicholas Roy等人提出的新型优化器1,其核心思想是利用神经网络参数的张量结构进行优化:

# Muon的核心思想
def muon_step(params, grads, lr):
    for p, g in zip(params, grads):
        # 将梯度正交化到当前参数的列空间
        g_ortho = g - p @ (p.T @ g)
        
        # 归一化
        g_ortho = g_ortho / (g_ortho.norm() + 1e-8)
        
        # 更新
        p.data = p.data - lr * g_ortho

1.2 为什么需要收敛性理论

尽管Muon在实验中表现出色,但缺乏严格的理论支撑:

问题现状
收敛速度经验观察良好
收敛保证缺乏形式化证明
最坏情况未被分析

近期多个团队对Muon的理论性质进行了深入分析234

2. 非凸优化的收敛速率分析

2.1 问题设置

考虑经验风险最小化问题:

其中 是非凸函数,满足以下假设:

假设1(Hölder光滑) 上是 -Hölder光滑的:

其中

假设2(Heavy-Tailed噪声):梯度噪声满足:

2.2 收敛速率定理

定理1(非凸收敛)2:设步长 ,其中 ,则在适当条件下:

推论:当 (Lipschitz光滑)且 (固定步长)时:

2.3 与AdamW的对比

优化器收敛速率复杂度内存开销
SGD
AdamW
Muon

3. Heavy-Tailed噪声下的收敛

3.1 问题的实际意义

实际训练中,梯度噪声往往呈现重尾分布

# 观察梯度噪声的尾分布
def analyze_noise_tail(gradients):
    log_magnitudes = np.log(np.abs(gradients).mean(axis=0))
    
    # 拟合幂律分布
    from scipy import stats
    params = stats.powerlaw.fit(-log_magnitudes)
    
    return params  # 返回尾指数

3.2 收敛性定理

定理2(Heavy-Tailed收敛)3:假设梯度噪声是 -阶矩有限的(而非高斯假设),则:

其中:

  • 是噪声的矩阶数
  • 是与分布相关的常数

3.3 实践启示

推论:对于重尾噪声( 较小),需要:

  1. 更大的批量大小:减少噪声方差
  2. 更小的学习率:抑制噪声放大
  3. 梯度裁剪:控制噪声极端值
class MuonWithHeavyTail:
    def __init__(self, params, lr=1e-3, clip_value=1.0):
        self.params = params
        self.lr = lr
        self.clip_value = clip_value
        
    def step(self):
        for p in self.params:
            # 梯度裁剪
            grad = torch.clamp(p.grad, -self.clip_value, self.clip_value)
            
            # Muon更新
            g_ortho = grad - p @ (p.T @ grad)
            g_ortho = g_ortho / (g_ortho.norm() + 1e-8)
            p.data = p.data - self.lr * g_ortho

4. 收敛界与临界批量大小

4.1 临界批量的定义

定义(临界批量大小):使随机梯度与全批梯度同等重要的最小批量大小:

4.2 Muon的临界批量

定理3(临界批量)4:对于Muon,临界批量大小满足:

其中 是Hessian矩阵, 是其最大/最小特征值。

关键洞察:Muon的临界批量受条件数(特征值比例)影响,而非单纯由问题难度决定。

4.3 批量大小与收敛速度

批量大小收敛行为建议场景
噪声主导小规模训练
最优生产训练
确定性极大模型
def estimate_critical_batch(model, dataloader, num_batches=100):
    """估计临界批量大小"""
    grads = []
    for i, batch in enumerate(dataloader):
        if i >= num_batches:
            break
        loss = model(batch)
        loss.backward()
        
        # 收集梯度
        g_full = torch.cat([p.grad.flatten() for p in model.parameters()])
        grads.append(g_full)
        
        model.zero_grad()
    
    grads = torch.stack(grads)  # [num_batches, d]
    
    # 估计方差
    var = grads.var(dim=0).mean()
    grad_norm_sq = grads.mean(dim=0).norm() ** 2
    
    # 临界批量
    B_crit = var / grad_norm_sq
    
    return B_crit.item()

5. 与其他分析的关系

5.1 与Spectral Flattening的联系

近期研究5表明,Muon的正交化操作等价于谱平坦化(Spectral Flattening)

其中 的列空间由正交矩阵 张成。

5.2 收敛性对比

优化器收敛保证条件数依赖噪声敏感性
SGD
AdamW中等
Muon
Newton最强

6. 代码实现:带收敛保证的Muon

import torch
import torch.nn as nn
import math
 
class ConvergentMuon(torch.optim.Optimizer):
    """
    带收敛保证的Muon优化器
    实现非凸优化的理论收敛速率
    """
    def __init__(
        self, 
        params, 
        lr=1e-3, 
        weight_decay=0.0,
        momentum=0.95,
        clip_value=1.0
    ):
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            clip_value=clip_value
        )
        super().__init__(params, defaults)
        
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
        
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                
                grad = p.grad.data
                
                # 梯度裁剪(处理Heavy-Tail)
                if group['clip_value'] > 0:
                    grad = torch.clamp(grad, -group['clip_value'], group['clip_value'])
                
                # 获取状态
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    state['momentum'] = torch.zeros_like(p.data)
                
                state['step'] += 1
                
                # 动量更新
                momentum = group['momentum']
                exp_avg = state['exp_avg_sq']
                buf = state['momentum']
                buf.mul_(momentum).add_(grad)
                exp_avg.mul_(momentum).addcmul_(grad, grad, value=1 - momentum)
                
                # Muon正交化
                # 将梯度正交化到参数的列空间
                if p.dim() >= 2:
                    # 对于矩阵参数
                    grad_ortho = grad - p.data @ (p.data.T @ grad)
                    # 谱归一化
                    grad_norm = grad_ortho.norm()
                    if grad_norm > 1e-8:
                        grad_ortho = grad_ortho / grad_norm
                else:
                    grad_ortho = grad / (grad.norm() + 1e-8)
                
                # 融合动量
                grad_update = (1 - momentum) * grad + momentum * buf
                
                # 学习率调度(基于理论)
                t = state['step']
                lr = group['lr']
                
                # 理论建议:使用衰减学习率
                lr_t = lr * (1 / math.sqrt(t) if t > 100 else 1.0)
                
                # 更新
                if group['weight_decay'] > 0:
                    p.data = p.data - lr_t * (grad_update + group['weight_decay'] * p.data)
                else:
                    p.data = p.data - lr_t * grad_update
        
        return loss
 
 
class MuonWithCriticalBatch:
    """
    自适应批量大小的Muon
    根据临界批量动态调整
    """
    def __init__(self, model, base_lr=1e-3, target_batch_based_lr=True):
        self.model = model
        self.base_lr = base_lr
        self.target_batch_based_lr = target_batch_based_lr
        self.grad_history = []
        
    def estimate_critical_batch(self, dataloader, num_batches=50):
        """在线估计临界批量"""
        if len(self.grad_history) < num_batches:
            return 32  # 默认值
        
        grads = torch.stack(self.grad_history[-num_batches:])
        var = grads.var(dim=0).mean().item()
        grad_norm_sq = grads.mean(dim=0).norm().item() ** 2
        
        if grad_norm_sq < 1e-8:
            return 32
        
        return var / grad_norm_sq
    
    def get_adaptive_lr(self, batch_size):
        """根据批量大小自适应学习率"""
        if not self.target_batch_based_lr:
            return self.base_lr
        
        # 线性缩放规则
        # 假设基准批量为32
        base_batch = 32
        scale = batch_size / base_batch
        
        # Muon的经验公式
        lr = self.base_lr * math.sqrt(min(scale, 1.0))
        
        return lr

7. 实践建议

7.1 何时使用Muon

场景推荐程度原因
大批量训练⭐⭐⭐⭐⭐收敛速度快
Transformer训练⭐⭐⭐⭐稳定性好
小批量训练⭐⭐噪声敏感
非凸问题⭐⭐⭐⭐理论保证

7.2 超参数设置

# 推荐的Muon超参数
config = {
    'lr': 1e-3,           # 学习率
    'weight_decay': 0.1,  # 权重衰减
    'momentum': 0.95,     # 动量
    'clip_value': 1.0,    # 梯度裁剪
    'warmup_steps': 1000, # 预热步数
}

7.3 常见问题

Q1: Muon不收敛怎么办?
A: 检查是否正确应用了梯度裁剪;尝试减小学习率。

Q2: 内存占用大?
A: 对于超大模型,考虑分块正交化或使用L-Muon变体。

Q3: 与AdamW相比哪个好?
A: 大批量场景Muon通常更好;小批量场景推荐AdamW。

8. 总结与展望

8.1 主要贡献

  1. 理论框架:建立了Muon在非凸优化下的收敛理论
  2. 噪声分析:处理Heavy-Tailed梯度噪声
  3. 批量指导:提供临界批量大小的估计方法
  4. 实践指导:给出超参数设置建议

8.2 开放问题

  1. 二阶动量:是否需要类似Adam的二阶动量?
  2. 分布式训练:多GPU场景下的收敛性
  3. 自适应变体:自动调整正交化强度

参考文献

Footnotes

  1. Muon: Nicholas Roy等人的原始论文,提出了利用参数矩阵结构的正交化优化器

  2. Nagashima et al. (2026): “Improved Convergence Rates of Muon Optimizer for Nonconvex Optimization”, arXiv:2601.19400 2

  3. Iiduka (2026): “Muon Converges under Heavy-Tailed Noise”, arXiv:2603.15059 2

  4. Sakai et al. (2025): “Convergence Bound and Critical Batch Size of Muon Optimizer”, arXiv:2507.01598 2

  5. Nguyen et al. (2026): “Spectral Flattening Is All Muon Needs”, arXiv:2605.13079