MAE 掩码自编码器

Masked Autoencoder(MAE)是FAIR提出的一种用于视觉自监督学习的革命性方法。其核心思想源自NLP领域的BERT,通过随机掩码图像块并重建缺失像素来学习视觉表示。本章详细介绍MAE的设计原理、架构实现和实验结果。

一、MAE的设计思想

1.1 核心洞察

MAE基于一个关键洞察:非对称编码器-解码器架构

┌─────────────────────────────────────────────────────────────┐
│                    MAE 核心思想                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   输入图像                                                   │
│      │                                                     │
│      ▼                                                     │
│   ┌───────────────────┐                                    │
│   │ Patch Embedding   │ 将图像切分为非重叠patch              │
│   └─────────┬─────────┘                                    │
│             │                                              │
│             ▼                                              │
│   ┌───────────────────┐                                    │
│   │ 随机掩码 (75%)    │ 大部分patch被掩码                    │
│   └─────────┬─────────┘                                    │
│             │                                              │
│             ▼                                              │
│   ┌───────────────────┐     ┌───────────────────┐          │
│   │   编码器 (只处理   │────▶│   解码器 (处理所有 │          │
│   │   可见patch)       │     │   patch + 位置编码 │          │
│   └───────────────────┘     └─────────┬─────────┘          │
│                                       │                    │
│                                       ▼                    │
│                                 重建像素值                   │
│                                                             │
└─────────────────────────────────────────────────────────────┘

1.2 为什么要掩码?

传统自监督方法的局限:

方法预文本任务问题
对比学习区分正负样本需要大量负样本、依赖数据增强
自动编码器重建像素全部编码导致效率低
MAE掩码重建仅编码可见patch、高效

1.3 掩码策略

import torch
import random
 
def random_masking(x, mask_ratio=0.75):
    """
    随机掩码策略
    
    Args:
        x: 输入序列 [B, N, D] N = H*W / patch_size^2
        mask_ratio: 掩码比例
    """
    B, N, D = x.shape
    len_keep = int(N * (1 - mask_ratio))
    
    # 随机采样保留的索引
    noise = torch.rand(B, N)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    
    # 保留前len_keep个patch
    ids_keep = ids_shuffle[:, :len_keep]
    
    # 创建掩码:1=掩码,0=可见
    mask = torch.ones(B, N)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    
    return ids_keep, ids_restore, mask

二、MAE架构详解

2.1 编码器(Encoder)

MAE编码器是一个标准ViT,仅处理可见(未被掩码)的patch:

class MAEEncoder(nn.Module):
    """MAE编码器:标准ViT,仅处理可见patch"""
    
    def __init__(self, img_size=224, patch_size=16, embed_dim=768, 
                 depth=12, num_heads=12, mlp_ratio=4.0):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # Patch Embedding
        self.patch_embed = nn.Linear(patch_size * patch_size * 3, embed_dim)
        
        # 位置编码
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        
        # Transformer Blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        
        self._init_weights()
        
    def forward(self, x, ids_keep):
        """
        Args:
            x: 所有patch嵌入 [B, N, D]
            ids_keep: 保留patch的索引 [B, len_keep]
        """
        B = x.shape[0]
        
        # 仅保留可见patch
        x = gather_tokens(x, ids_keep)
        
        # 添加位置编码
        pos_embed_visible = torch.gather(
            self.pos_embed.expand(B, -1, -1),
            dim=1,
            index=ids_keep.unsqueeze(-1).expand(-1, -1, x.shape[-1])
        )
        x = x + pos_embed_visible
        
        # 通过Transformer
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        
        return x

2.2 解码器(Decoder)

MAE解码器是一个更轻量的Transformer,处理所有patch(包括掩码的位置):

class MAEDecoder(nn.Module):
    """MAE解码器:重建被掩码的patch"""
    
    def __init__(self, embed_dim=512, decoder_embed_dim=512, 
                 decoder_depth=8, decoder_num_heads=16, 
                 mlp_ratio=4.0, patch_size=16, num_patches=196):
        super().__init__()
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
        
        # 掩码token(可学习)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        
        # 解码器位置编码(可学习)
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim))
        
        # 解码器Transformer
        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(decoder_embed_dim, decoder_num_heads, mlp_ratio)
            for _ in range(decoder_depth)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
        
        # 预测头:解码器维度 → 像素值
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * 3)
        
    def forward(self, x, ids_restore, ids_keep):
        """
        Args:
            x: 编码器输出 [B, len_keep, embed_dim]
            ids_restore: 恢复原始顺序的索引
        """
        B = x.shape[0]
        
        # 将编码器输出投影到解码器维度
        x = self.decoder_embed(x)
        
        # 构建完整序列
        # 1. 创建掩码token序列
        mask_tokens = self.mask_token.expand(B, ids_restore.shape[1] - x.shape[1], -1)
        
        # 2. 按原始顺序拼接可见token和掩码token
        x = torch.cat([x, mask_tokens], dim=1)
        
        # 3. 恢复原始顺序
        x = gather_tokens(x, ids_restore)
        
        # 4. 添加位置编码
        x = x + self.decoder_pos_embed
        
        # 5. 通过解码器
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        
        # 6. 预测像素值
        pred = self.decoder_pred(x)
        
        return pred
 
 
def gather_tokens(x, ids):
    """根据索引收集token"""
    B, L, D = x.shape
    ids = ids.unsqueeze(-1).expand(-1, -1, D)
    return torch.gather(x, dim=1, index=ids)

2.3 完整MAE模型

class MAE(nn.Module):
    """MAE完整模型"""
    
    def __init__(self, img_size=224, patch_size=16, 
                 embed_dim=768, depth=12, num_heads=12,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4.0):
        super().__init__()
        
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # 编码器
        self.encoder = MAEEncoder(
            img_size=img_size, patch_size=patch_size,
            embed_dim=embed_dim, depth=depth, 
            num_heads=num_heads, mlp_ratio=mlp_ratio
        )
        
        # 解码器
        self.decoder = MAEDecoder(
            embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            patch_size=patch_size,
            num_patches=self.num_patches
        )
        
    def forward(self, x, mask_ratio=0.75):
        """
        Args:
            x: 输入图像 [B, C, H, W]
            mask_ratio: 掩码比例
        """
        # Patch嵌入
        x = self.patchify(x)  # [B, N, D_patch]
        x = self.encoder.patch_embed(x)
        
        # 添加位置编码
        B, N, D = x.shape
        x = x + self.encoder.pos_embed
        
        # 随机掩码
        ids_keep, ids_restore, mask = random_masking(x, mask_ratio)
        
        # 编码(仅可见patch)
        x_encoded = self.encoder(x, ids_keep)
        
        # 解码(所有patch)
        pred = self.decoder(x_encoded, ids_restore, ids_keep)
        
        # 仅返回被掩码patch的预测
        mask = mask.unsqueeze(-1)
        pred_masked = pred[mask.expand(-1, -1, pred.shape[-1])].reshape(-1, self.patch_size**2 * 3)
        target = self.patchify(x)  # 所有patch
        target_masked = target[mask.expand(-1, -1, target.shape[-1])].reshape(-1, self.patch_size**2 * 3)
        
        return pred_masked, target_masked, mask
    
    def patchify(self, imgs):
        """将图像切分为patch"""
        B, C, H, W = imgs.shape
        p = self.patch_size
        assert H % p == 0 and W % p == 0
        x = imgs.reshape(B, C, H // p, p, W // p, p)
        x = x.permute(0, 2, 4, 3, 5, 1)  # [B, H/p, W/p, p, p, C]
        x = x.reshape(B, (H // p) * (W // p), p * p * C)
        return x

三、训练策略

3.1 损失函数

MAE使用均方误差(MSE)损失,仅在掩码patch上计算:

def mae_loss(pred, target, mask):
    """
    MAE损失函数
    
    Args:
        pred: 预测值 [N_masked, D_patch]
        target: 目标值 [N_masked, D_patch]
        mask: 掩码 [B, N, 1]
    """
    loss = (pred - target) ** 2
    loss = loss.mean(dim=-1)  # 每patch平均
    
    # 归一化
    num_masked = mask.sum()
    loss = loss.sum() / num_masked
    
    return loss

3.2 训练配置

配置MAE-BMAE-LMAE-H
编码器ViT-BViT-LViT-H
编码器深度122432
编码器维度76810241280
编码器头数121616
解码器深度888
解码器维度512512512
掩码比例75%75%75%
批大小409640962048
训练轮数8008001600

3.3 数据增强

MAE的独特之处在于不需要strong augmentation

# MAE vs 对比学习数据增强对比
mae_augment = Compose([
    Resize(224),
    RandomResizedCrop(224, scale=(0.2, 1.0)),
    HorizontalFlip(),
    ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
 
contrastive_augment = Compose([
    # ... 更多增强
    MultiCrop(output_size=224, scales=(0.14, 1.0)),
    RandomGaussianBlur(sigma=(0.1, 2.0)),
    Solarize(threshold=0.5),
    # ...
])

四、实验结果

4.1 ImageNet分类

方法预训练Epochs预训练数据LinearFinetune
From scratch300--68.3%
Supervised300IN-1K79.9%83.6%
MoCo v3300IN-1K76.4%83.2%
BEiT300IN-1K86.3%83.1%
MAE1600IN-1K87.8%83.6%
MAE-L1600IN-1K88.6%84.9%

4.2 掩码比例的影响

┌─────────────────────────────────────────────────────────────┐
│          掩码比例 vs Linear Probe 准确率                    │
│                                                             │
│  95% ┤                                                      │
│  90% ┤                          ┌───┐                      │
│  85% ┤                    ┌───┐ │MAE│                      │
│  87% ┤              ┌───┐ │MAE│ └───┘                      │
│  88% ┤        ┌───┐ │MAE│ └───┘                            │
│  87% ┤  ┌───┐│MAE│ └───┘                                    │
│  80% ┤┌───┐└───┘                                           │
│      └┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┤
│      25%  40%  55%  70%  75%  80%  85%  90%  95%           │
│                        掩码比例                             │
│                                                             │
│  最优掩码比例: 75%                                          │
└─────────────────────────────────────────────────────────────┘

4.3 下游任务迁移

任务数据集MAE-BMAE-L
分类ImageNet83.6%84.9%
检测COCO val50.3 AP52.5 AP
分割ADE20K47.4 mIoU49.1 mIoU
视频分类K40078.2%81.0%

五、关键洞察

5.1 掩码重建的优势

  1. ** pretext任务的自然性**:掩码是一种自然的pretext任务,与语言模型中的掩码语言建模一致
  2. 高掩码率的合理性:75%的掩码率强制编码器学习更好的表示,因为需要从少量可见patch中推断更多信息
  3. 非对称设计的效率:编码器只处理可见patch,减少了75%的计算量

5.2 重建粒度的选择

重建目标描述效果
像素值直接预测RGB值✅ 简单有效
特征匹配匹配预训练特征❌ 需要额外模型
离散token预测VAE离散编码BEiT使用

5.3 与BERT的类比

方面BERTMAE
输入文本token图像patch
掩码比例15%75%
编码器输入可见token可见patch
解码器输入全部token全部patch
预测目标Masked tokenMasked patch
损失函数Cross-EntropyMSE

六、局限性

  1. 重建任务简单:像素级重建相对简单,可能限制了表示的质量
  2. 解码器训练:解码器仅用于预训练,无法迁移使用
  3. 高掩码率的训练稳定性:需要仔细的初始化和学习率调度

七、参考论文