MAE 掩码自编码器
Masked Autoencoder(MAE)是FAIR提出的一种用于视觉自监督学习的革命性方法。其核心思想源自NLP领域的BERT,通过随机掩码图像块并重建缺失像素来学习视觉表示。本章详细介绍MAE的设计原理、架构实现和实验结果。
一、MAE的设计思想
1.1 核心洞察
MAE基于一个关键洞察:非对称编码器-解码器架构
┌─────────────────────────────────────────────────────────────┐
│ MAE 核心思想 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 输入图像 │
│ │ │
│ ▼ │
│ ┌───────────────────┐ │
│ │ Patch Embedding │ 将图像切分为非重叠patch │
│ └─────────┬─────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────┐ │
│ │ 随机掩码 (75%) │ 大部分patch被掩码 │
│ └─────────┬─────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────┐ ┌───────────────────┐ │
│ │ 编码器 (只处理 │────▶│ 解码器 (处理所有 │ │
│ │ 可见patch) │ │ patch + 位置编码 │ │
│ └───────────────────┘ └─────────┬─────────┘ │
│ │ │
│ ▼ │
│ 重建像素值 │
│ │
└─────────────────────────────────────────────────────────────┘
1.2 为什么要掩码?
传统自监督方法的局限:
| 方法 | 预文本任务 | 问题 |
|---|---|---|
| 对比学习 | 区分正负样本 | 需要大量负样本、依赖数据增强 |
| 自动编码器 | 重建像素 | 全部编码导致效率低 |
| MAE | 掩码重建 | 仅编码可见patch、高效 |
1.3 掩码策略
import torch
import random
def random_masking(x, mask_ratio=0.75):
"""
随机掩码策略
Args:
x: 输入序列 [B, N, D] N = H*W / patch_size^2
mask_ratio: 掩码比例
"""
B, N, D = x.shape
len_keep = int(N * (1 - mask_ratio))
# 随机采样保留的索引
noise = torch.rand(B, N)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
# 保留前len_keep个patch
ids_keep = ids_shuffle[:, :len_keep]
# 创建掩码:1=掩码,0=可见
mask = torch.ones(B, N)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
return ids_keep, ids_restore, mask二、MAE架构详解
2.1 编码器(Encoder)
MAE编码器是一个标准ViT,仅处理可见(未被掩码)的patch:
class MAEEncoder(nn.Module):
"""MAE编码器:标准ViT,仅处理可见patch"""
def __init__(self, img_size=224, patch_size=16, embed_dim=768,
depth=12, num_heads=12, mlp_ratio=4.0):
super().__init__()
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# Patch Embedding
self.patch_embed = nn.Linear(patch_size * patch_size * 3, embed_dim)
# 位置编码
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
# Transformer Blocks
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, mlp_ratio)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
self._init_weights()
def forward(self, x, ids_keep):
"""
Args:
x: 所有patch嵌入 [B, N, D]
ids_keep: 保留patch的索引 [B, len_keep]
"""
B = x.shape[0]
# 仅保留可见patch
x = gather_tokens(x, ids_keep)
# 添加位置编码
pos_embed_visible = torch.gather(
self.pos_embed.expand(B, -1, -1),
dim=1,
index=ids_keep.unsqueeze(-1).expand(-1, -1, x.shape[-1])
)
x = x + pos_embed_visible
# 通过Transformer
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x2.2 解码器(Decoder)
MAE解码器是一个更轻量的Transformer,处理所有patch(包括掩码的位置):
class MAEDecoder(nn.Module):
"""MAE解码器:重建被掩码的patch"""
def __init__(self, embed_dim=512, decoder_embed_dim=512,
decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4.0, patch_size=16, num_patches=196):
super().__init__()
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
# 掩码token(可学习)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
# 解码器位置编码(可学习)
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim))
# 解码器Transformer
self.decoder_blocks = nn.ModuleList([
TransformerBlock(decoder_embed_dim, decoder_num_heads, mlp_ratio)
for _ in range(decoder_depth)
])
self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
# 预测头:解码器维度 → 像素值
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * 3)
def forward(self, x, ids_restore, ids_keep):
"""
Args:
x: 编码器输出 [B, len_keep, embed_dim]
ids_restore: 恢复原始顺序的索引
"""
B = x.shape[0]
# 将编码器输出投影到解码器维度
x = self.decoder_embed(x)
# 构建完整序列
# 1. 创建掩码token序列
mask_tokens = self.mask_token.expand(B, ids_restore.shape[1] - x.shape[1], -1)
# 2. 按原始顺序拼接可见token和掩码token
x = torch.cat([x, mask_tokens], dim=1)
# 3. 恢复原始顺序
x = gather_tokens(x, ids_restore)
# 4. 添加位置编码
x = x + self.decoder_pos_embed
# 5. 通过解码器
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# 6. 预测像素值
pred = self.decoder_pred(x)
return pred
def gather_tokens(x, ids):
"""根据索引收集token"""
B, L, D = x.shape
ids = ids.unsqueeze(-1).expand(-1, -1, D)
return torch.gather(x, dim=1, index=ids)2.3 完整MAE模型
class MAE(nn.Module):
"""MAE完整模型"""
def __init__(self, img_size=224, patch_size=16,
embed_dim=768, depth=12, num_heads=12,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4.0):
super().__init__()
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# 编码器
self.encoder = MAEEncoder(
img_size=img_size, patch_size=patch_size,
embed_dim=embed_dim, depth=depth,
num_heads=num_heads, mlp_ratio=mlp_ratio
)
# 解码器
self.decoder = MAEDecoder(
embed_dim=embed_dim,
decoder_embed_dim=decoder_embed_dim,
decoder_depth=decoder_depth,
decoder_num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
patch_size=patch_size,
num_patches=self.num_patches
)
def forward(self, x, mask_ratio=0.75):
"""
Args:
x: 输入图像 [B, C, H, W]
mask_ratio: 掩码比例
"""
# Patch嵌入
x = self.patchify(x) # [B, N, D_patch]
x = self.encoder.patch_embed(x)
# 添加位置编码
B, N, D = x.shape
x = x + self.encoder.pos_embed
# 随机掩码
ids_keep, ids_restore, mask = random_masking(x, mask_ratio)
# 编码(仅可见patch)
x_encoded = self.encoder(x, ids_keep)
# 解码(所有patch)
pred = self.decoder(x_encoded, ids_restore, ids_keep)
# 仅返回被掩码patch的预测
mask = mask.unsqueeze(-1)
pred_masked = pred[mask.expand(-1, -1, pred.shape[-1])].reshape(-1, self.patch_size**2 * 3)
target = self.patchify(x) # 所有patch
target_masked = target[mask.expand(-1, -1, target.shape[-1])].reshape(-1, self.patch_size**2 * 3)
return pred_masked, target_masked, mask
def patchify(self, imgs):
"""将图像切分为patch"""
B, C, H, W = imgs.shape
p = self.patch_size
assert H % p == 0 and W % p == 0
x = imgs.reshape(B, C, H // p, p, W // p, p)
x = x.permute(0, 2, 4, 3, 5, 1) # [B, H/p, W/p, p, p, C]
x = x.reshape(B, (H // p) * (W // p), p * p * C)
return x三、训练策略
3.1 损失函数
MAE使用均方误差(MSE)损失,仅在掩码patch上计算:
def mae_loss(pred, target, mask):
"""
MAE损失函数
Args:
pred: 预测值 [N_masked, D_patch]
target: 目标值 [N_masked, D_patch]
mask: 掩码 [B, N, 1]
"""
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # 每patch平均
# 归一化
num_masked = mask.sum()
loss = loss.sum() / num_masked
return loss3.2 训练配置
| 配置 | MAE-B | MAE-L | MAE-H |
|---|---|---|---|
| 编码器 | ViT-B | ViT-L | ViT-H |
| 编码器深度 | 12 | 24 | 32 |
| 编码器维度 | 768 | 1024 | 1280 |
| 编码器头数 | 12 | 16 | 16 |
| 解码器深度 | 8 | 8 | 8 |
| 解码器维度 | 512 | 512 | 512 |
| 掩码比例 | 75% | 75% | 75% |
| 批大小 | 4096 | 4096 | 2048 |
| 训练轮数 | 800 | 800 | 1600 |
3.3 数据增强
MAE的独特之处在于不需要strong augmentation:
# MAE vs 对比学习数据增强对比
mae_augment = Compose([
Resize(224),
RandomResizedCrop(224, scale=(0.2, 1.0)),
HorizontalFlip(),
ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
ToTensor(),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
contrastive_augment = Compose([
# ... 更多增强
MultiCrop(output_size=224, scales=(0.14, 1.0)),
RandomGaussianBlur(sigma=(0.1, 2.0)),
Solarize(threshold=0.5),
# ...
])四、实验结果
4.1 ImageNet分类
| 方法 | 预训练Epochs | 预训练数据 | Linear | Finetune |
|---|---|---|---|---|
| From scratch | 300 | - | - | 68.3% |
| Supervised | 300 | IN-1K | 79.9% | 83.6% |
| MoCo v3 | 300 | IN-1K | 76.4% | 83.2% |
| BEiT | 300 | IN-1K | 86.3% | 83.1% |
| MAE | 1600 | IN-1K | 87.8% | 83.6% |
| MAE-L | 1600 | IN-1K | 88.6% | 84.9% |
4.2 掩码比例的影响
┌─────────────────────────────────────────────────────────────┐
│ 掩码比例 vs Linear Probe 准确率 │
│ │
│ 95% ┤ │
│ 90% ┤ ┌───┐ │
│ 85% ┤ ┌───┐ │MAE│ │
│ 87% ┤ ┌───┐ │MAE│ └───┘ │
│ 88% ┤ ┌───┐ │MAE│ └───┘ │
│ 87% ┤ ┌───┐│MAE│ └───┘ │
│ 80% ┤┌───┐└───┘ │
│ └┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┤
│ 25% 40% 55% 70% 75% 80% 85% 90% 95% │
│ 掩码比例 │
│ │
│ 最优掩码比例: 75% │
└─────────────────────────────────────────────────────────────┘
4.3 下游任务迁移
| 任务 | 数据集 | MAE-B | MAE-L |
|---|---|---|---|
| 分类 | ImageNet | 83.6% | 84.9% |
| 检测 | COCO val | 50.3 AP | 52.5 AP |
| 分割 | ADE20K | 47.4 mIoU | 49.1 mIoU |
| 视频分类 | K400 | 78.2% | 81.0% |
五、关键洞察
5.1 掩码重建的优势
- ** pretext任务的自然性**:掩码是一种自然的pretext任务,与语言模型中的掩码语言建模一致
- 高掩码率的合理性:75%的掩码率强制编码器学习更好的表示,因为需要从少量可见patch中推断更多信息
- 非对称设计的效率:编码器只处理可见patch,减少了75%的计算量
5.2 重建粒度的选择
| 重建目标 | 描述 | 效果 |
|---|---|---|
| 像素值 | 直接预测RGB值 | ✅ 简单有效 |
| 特征匹配 | 匹配预训练特征 | ❌ 需要额外模型 |
| 离散token | 预测VAE离散编码 | BEiT使用 |
5.3 与BERT的类比
| 方面 | BERT | MAE |
|---|---|---|
| 输入 | 文本token | 图像patch |
| 掩码比例 | 15% | 75% |
| 编码器输入 | 可见token | 可见patch |
| 解码器输入 | 全部token | 全部patch |
| 预测目标 | Masked token | Masked patch |
| 损失函数 | Cross-Entropy | MSE |
六、局限性
- 重建任务简单:像素级重建相对简单,可能限制了表示的质量
- 解码器训练:解码器仅用于预训练,无法迁移使用
- 高掩码率的训练稳定性:需要仔细的初始化和学习率调度