BEiT BERT风格的视觉预训练
BEiT(Bidirectional Encoder representation from Image Transformers)是微软提出的一种将BERT预训练范式引入视觉领域的自监督学习方法。与MAE直接重建像素值不同,BEiT通过预训练的离散VAE tokenizer将图像转换为”视觉单词”,然后预测这些离散表示。本章详细介绍BEiT的设计原理、架构实现和与MAE的对比分析。
一、BEiT的核心思想
1.1 从BERT到视觉
BERT在NLP中取得巨大成功,其核心是掩码语言建模(MLM):
NLP BERT:
┌─────────────────────────────────────────────────────────────┐
│ │
│ "The cat sat on the [MASK]" │
│ │ │
│ ▼ │
│ 预测 [MASK] = "mat" │
│ │
└─────────────────────────────────────────────────────────────┘
Vision BEiT:
┌─────────────────────────────────────────────────────────────┐
│ │
│ 图像 patch: [🍎] [🍐] [🍊] [MASK] [🍋] │
│ │ │
│ ▼ │
│ 预测 [MASK] = 🍇 (离散token) │
│ │
└─────────────────────────────────────────────────────────────┘
1.2 为什么要预测离散token?
MAE直接预测像素值存在以下问题:
| 问题 | 描述 |
|---|---|
| 信息密度低 | 每个patch有数百个像素值,大部分是高频噪声 |
| 语义信息少 | 像素级重建鼓励学习低层特征而非高层语义 |
| 语义鸿沟 | 像素值与语义概念之间存在巨大差距 |
BEiT的解决方案:使用离散VAE将图像编码为语义相关的token
┌─────────────────────────────────────────────────────────────┐
│ BEiT vs MAE 对比 │
├─────────────────────────────────────────────────────────────┤
│ │
│ MAE: │
│ 图像 ──▶ Patch ──▶ 像素值 ──▶ 重建像素值 │
│ (768维) (768维MSE) │
│ │
│ BEiT: │
│ 图像 ──▶ Patch ──▶ 离散token ──▶ 预测token │
│ (1维) (1维CE) │
│ ↓ │
│ 离散VAE tokenizer │
│ (学习语义表示) │
│ │
└─────────────────────────────────────────────────────────────┘
二、离散VAE Tokenizer
2.1 dVAE架构
BEiT使用离散变分自编码器(dVAE)将图像patch转换为离散token:
class dVAE(nn.Module):
"""
离散VAE:编码器 + Gumbel-Softmax + 解码器
核心思想:将连续特征向量映射到固定大小的离散codebook
"""
def __init__(self, vocab_size=8192, img_size=224, patch_size=16,
encoder_dim=256, embed_dim=256):
super().__init__()
self.vocab_size = vocab_size
self.patch_size = patch_size
# 编码器:将patch映射到连续表示
self.encoder = nn.Sequential(
nn.Linear(patch_size * patch_size * 3, encoder_dim),
nn.ReLU(),
nn.Linear(encoder_dim, encoder_dim),
nn.ReLU(),
nn.Linear(encoder_dim, embed_dim),
)
# 码本(Codebook):8192个可学习向量
self.codebook = nn.Embedding(vocab_size, embed_dim)
# 解码器:从code重建像素
self.decoder = nn.Sequential(
nn.Linear(embed_dim, encoder_dim),
nn.ReLU(),
nn.Linear(encoder_dim, encoder_dim),
nn.ReLU(),
nn.Linear(encoder_dim, patch_size * patch_size * 3),
)
def encode(self, x):
"""
编码:图像patch → 连续向量 → 离散token
"""
z = self.encoder(x) # [B*N, embed_dim]
# Gumbel-Softmax采样
logits = z @ self.codebook.weight.T # [B*N, vocab_size]
# Gumbel-Softmax重参数化
gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
hard_one_hot = F.one_hot(logits.argmax(dim=-1), self.vocab_size).float()
z_q = hard_one_hot @ self.codebook.weight - z.detach() + z
return z_q, logits
def decode(self, z_q):
"""解码:token → 像素"""
return self.decoder(z_q)
def forward(self, x):
z_q, logits = self.encode(x)
x_recon = self.decode(z_q)
return x_recon, logits, z_q2.2 Gumbel-Softmax重参数化
def gumbel_softmax(logits, temperature=1.0, hard=True):
"""
Gumbel-Softmax重参数化
使得离散的采样操作可导
"""
# 采样Gumbel噪声
gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
# 加到logits上
y = (logits + gumbel_noise) / temperature
# Softmax
y = F.softmax(y, dim=-1)
if hard:
# 硬采样:one-hot
y_hard = torch.zeros_like(y)
y_hard[torch.arange(y.shape[0]), y.argmax(dim=-1)] = 1
y = (y_hard - y).detach() + y
return y2.3 Tokenizer训练策略
def train_dvae(dvae, image_loader, epochs=10):
"""训练离散VAE tokenizer"""
optimizer = torch.optim.Adam(dvae.parameters(), lr=1e-4)
for epoch in range(epochs):
for images in image_loader:
patches = patchify(images) # [B, N, D]
# 重建损失
recon, logits, z_q = dvae(patches)
recon_loss = F.mse_loss(recon, patches)
# VAE损失
# KL散度(先验 = 均匀分布)
q = F.softmax(logits, dim=-1)
kl_loss = q * (torch.log(q + 1e-20) - torch.log(1.0 / 8192))
kl_loss = kl_loss.sum(dim=-1).mean()
# 损失 = 重建 + β * KL
loss = recon_loss + 0.1 * kl_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()三、BEiT预训练
3.1 预训练目标
BEiT有两种预训练目标:
变体1:预测离散token(BEiT v1)
class BEiTPretraining(nn.Module):
"""BEiT v1: 预测离散token"""
def __init__(self, vocab_size=8192, embed_dim=768, depth=12, num_heads=12):
super().__init__()
# Transformer编码器
self.transformer = TransformerEncoder(depth, embed_dim, num_heads)
# 预测头:预测被掩码patch的token
self.decoder = nn.Linear(embed_dim, vocab_size)
# Tokenizer(冻结)
self.dvae = load_pretrained_dvae()
def forward(self, images, mask_ratio=0.4):
# Patch化
patches = patchify(images) # [B, N, D]
# 获取token
with torch.no_grad():
_, tokens = self.dvae.encode(patches) # [B, N]
# 掩码
masked_patches, mask, ids_restore = random_masking(patches, mask_ratio)
# 编码
features = self.transformer(masked_patches)
# 预测token
pred_tokens = self.decoder(features)
# 损失:交叉熵
# 仅在被掩码位置计算
loss = F.cross_entropy(pred_tokens[mask], tokens[mask])
return loss变体2:预测像素值 + token(BEiT v2)
class BEiTv2Pretraining(nn.Module):
"""BEiT v2: 多任务学习"""
def forward(self, images, mask_ratio=0.4):
patches = patchify(images)
masked_patches, mask, _ = random_masking(patches, mask_ratio)
features = self.transformer(masked_patches)
# 任务1: 预测token
pred_tokens = self.token_head(features)
tokens = self.dvae.encode(patches) # ground truth tokens
# 任务2: 预测像素值
pred_pixels = self.pixel_head(features)
# 联合损失
loss = F.cross_entropy(pred_tokens[mask], tokens[mask]) + \
F.mse_loss(pred_pixels[mask], patches[mask])
return loss3.2 掩码策略
BEiT使用与BERT类似的块级掩码策略:
def block_masking(x, mask_ratio=0.4, block_size=4):
"""
块级掩码:将连续的patch块一起掩码
比随机掩码更难,提供更好的预训练效果
"""
B, N, D = x.shape
num_masked = int(N * mask_ratio)
# 构建块掩码
num_blocks = N // (block_size * block_size)
# 随机选择要掩码的块
noise = torch.rand(B, num_blocks)
block_rank = torch.argsort(noise, dim=-1)
# 前num_masked块被掩码
blocks_to_mask = block_rank[:, :num_masked // (block_size * block_size)]
mask = torch.zeros(B, N)
for b in range(B):
for blk in blocks_to_mask[b]:
i = blk // block_size
j = blk % block_size
mask[b, i*block_size:(i+1)*block_size,
j*block_size:(j+1)*block_size] = 1
return mask.bool()四、实验结果
4.1 ImageNet分类
| 方法 | 预训练 | Linear | Finetune |
|---|---|---|---|
| From scratch | - | - | 68.3% |
| Supervised | IN-1K | 79.9% | 83.6% |
| MoCo v3 | 300ep | 76.4% | 83.2% |
| BEiT | 800ep | 86.3% | 83.1% |
| BEiT v2 | 800ep | 87.3% | 85.2% |
| MAE | 1600ep | 87.8% | 83.6% |
4.2 消融实验
Tokenizer Vocab Size
| Vocab Size | Linear | Finetune |
|---|---|---|
| 512 | 83.2% | 82.4% |
| 2048 | 85.1% | 82.8% |
| 8192 | 86.3% | 83.1% |
| 32768 | 86.1% | 83.0% |
结论:Vocab Size=8192是最优选择
掩码比例
| 掩码比例 | Linear |
|---|---|
| 15% | 82.1% |
| 30% | 84.8% |
| 40% | 86.3% |
| 50% | 86.0% |
| 75% | 85.2% |
结论:BEiT的最优掩码比例约为40%,远低于MAE的75%
4.3 下游任务迁移
| 任务 | 数据集 | BEiT | BEiT v2 | MAE |
|---|---|---|---|---|
| 分类 | ImageNet | 83.1% | 85.2% | 83.6% |
| 检测 | COCO | 48.5 AP | 50.3 AP | 50.3 AP |
| 分割 | ADE20K | 48.1 mIoU | 50.1 mIoU | 47.4 mIoU |
五、BEiT vs MAE 对比分析
5.1 核心区别
| 方面 | BEiT | MAE |
|---|---|---|
| 重建目标 | 离散token | 像素值 |
| 预测难度 | 分类(8192类) | 回归(256维) |
| Tokenizer | 需要预训练dVAE | 无需 |
| 掩码比例 | 40% | 75% |
| 编码器 | 处理所有patch | 仅处理可见patch |
| 解码器 | 轻量(BEiT v1无解码器) | 重量(8层Transformer) |
5.2 训练效率
| 指标 | BEiT | MAE |
|---|---|---|
| GPU Memory | 较高 | 较低 |
| 训练时间 | 相近 | 相近 |
| 收敛速度 | 快 | 慢 |
5.3 表示特性
┌─────────────────────────────────────────────────────────────┐
│ 表示特性对比 │
├─────────────────────────────────────────────────────────────┤
│ │
│ BEiT: │
│ - 学习更多语义特征 │
│ - 离散token提供更好的语义监督 │
│ - 适合分类任务 │
│ │
│ MAE: │
│ - 学习更多低层视觉特征 │
│ - 像素级重建提供更丰富的监督信号 │
│ - 适合密集预测任务 │
│ │
└─────────────────────────────────────────────────────────────┘
六、代码实现
6.1 完整BEiT预训练
class BEiTForPretraining(nn.Module):
"""BEiT预训练模型"""
def __init__(self, config):
super().__init__()
self.vocab_size = config.vocab_size # 8192
self.embed_dim = config.embed_dim # 768
self.depth = config.depth # 12
self.num_heads = config.num_heads # 12
# Patch嵌入
self.patch_embed = nn.Linear(16*16*3, self.embed_dim)
# 位置编码
self.pos_embed = nn.Parameter(torch.zeros(1, 196, self.embed_dim))
# cls token
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
# Transformer编码器
self.blocks = nn.ModuleList([
TransformerBlock(self.embed_dim, self.num_heads)
for _ in range(self.depth)
])
self.norm = nn.LayerNorm(self.embed_dim)
# 预测头
self.head = nn.Linear(self.embed_dim, self.vocab_size)
def forward(self, images, masked_pos=None):
# Patch嵌入
x = self.patchify(images) # [B, N, D_patch]
x = self.patch_embed(x)
# 添加位置编码
x = x + self.pos_embed
# 掩码
if masked_pos is not None:
x[masked_pos] = 0 # 或使用可学习的mask token
# Transformer编码
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
# 预测
pred = self.head(x[masked_pos])
return pred
def patchify(self, imgs):
p = 16
assert imgs.shape[2] % p == 0 and imgs.shape[3] % p == 0
x = imgs.reshape(imgs.shape[0], 3, imgs.shape[2] // p, p, imgs.shape[3] // p, p)
x = x.permute(0, 2, 4, 3, 5, 1).reshape(imgs.shape[0], -1, p*p*3)
return x七、关键洞察
7.1 语义 vs 低层特征
- BEiT通过离散token提供语义级别的监督
- MAE通过像素值提供低层特征的监督
- 两者可以互补,BEiT v2结合两者
7.2 掩码策略的重要性
- 块级掩码比随机掩码更有效
- 掩码比例需要针对任务调优
7.3 Tokenizer的作用
- dVAE tokenizer将像素映射到语义空间
- codebook大小影响表示质量