BEiT BERT风格的视觉预训练

BEiT(Bidirectional Encoder representation from Image Transformers)是微软提出的一种将BERT预训练范式引入视觉领域的自监督学习方法。与MAE直接重建像素值不同,BEiT通过预训练的离散VAE tokenizer将图像转换为”视觉单词”,然后预测这些离散表示。本章详细介绍BEiT的设计原理、架构实现和与MAE的对比分析。

一、BEiT的核心思想

1.1 从BERT到视觉

BERT在NLP中取得巨大成功,其核心是掩码语言建模(MLM):

NLP BERT:
┌─────────────────────────────────────────────────────────────┐
│                                                             │
│   "The cat sat on the [MASK]"                              │
│                     │                                       │
│                     ▼                                       │
│            预测 [MASK] = "mat"                             │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Vision BEiT:
┌─────────────────────────────────────────────────────────────┐
│                                                             │
│   图像 patch: [🍎] [🍐] [🍊] [MASK] [🍋]                    │
│                      │                                      │
│                      ▼                                      │
│            预测 [MASK] = 🍇 (离散token)                     │
│                                                             │
└─────────────────────────────────────────────────────────────┘

1.2 为什么要预测离散token?

MAE直接预测像素值存在以下问题:

问题描述
信息密度低每个patch有数百个像素值,大部分是高频噪声
语义信息少像素级重建鼓励学习低层特征而非高层语义
语义鸿沟像素值与语义概念之间存在巨大差距

BEiT的解决方案:使用离散VAE将图像编码为语义相关的token

┌─────────────────────────────────────────────────────────────┐
│                    BEiT vs MAE 对比                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   MAE:                                                      │
│   图像 ──▶ Patch ──▶ 像素值 ──▶ 重建像素值                  │
│                    (768维)         (768维MSE)               │
│                                                             │
│   BEiT:                                                     │
│   图像 ──▶ Patch ──▶ 离散token ──▶ 预测token                │
│                    (1维)             (1维CE)                │
│                    ↓                                        │
│              离散VAE tokenizer                              │
│              (学习语义表示)                                  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

二、离散VAE Tokenizer

2.1 dVAE架构

BEiT使用离散变分自编码器(dVAE)将图像patch转换为离散token:

class dVAE(nn.Module):
    """
    离散VAE:编码器 + Gumbel-Softmax + 解码器
    
    核心思想:将连续特征向量映射到固定大小的离散codebook
    """
    
    def __init__(self, vocab_size=8192, img_size=224, patch_size=16, 
                 encoder_dim=256, embed_dim=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.patch_size = patch_size
        
        # 编码器:将patch映射到连续表示
        self.encoder = nn.Sequential(
            nn.Linear(patch_size * patch_size * 3, encoder_dim),
            nn.ReLU(),
            nn.Linear(encoder_dim, encoder_dim),
            nn.ReLU(),
            nn.Linear(encoder_dim, embed_dim),
        )
        
        # 码本(Codebook):8192个可学习向量
        self.codebook = nn.Embedding(vocab_size, embed_dim)
        
        # 解码器:从code重建像素
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, encoder_dim),
            nn.ReLU(),
            nn.Linear(encoder_dim, encoder_dim),
            nn.ReLU(),
            nn.Linear(encoder_dim, patch_size * patch_size * 3),
        )
        
    def encode(self, x):
        """
        编码:图像patch → 连续向量 → 离散token
        """
        z = self.encoder(x)  # [B*N, embed_dim]
        
        # Gumbel-Softmax采样
        logits = z @ self.codebook.weight.T  # [B*N, vocab_size]
        
        # Gumbel-Softmax重参数化
        gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
        hard_one_hot = F.one_hot(logits.argmax(dim=-1), self.vocab_size).float()
        z_q = hard_one_hot @ self.codebook.weight - z.detach() + z
        
        return z_q, logits
    
    def decode(self, z_q):
        """解码:token → 像素"""
        return self.decoder(z_q)
    
    def forward(self, x):
        z_q, logits = self.encode(x)
        x_recon = self.decode(z_q)
        return x_recon, logits, z_q

2.2 Gumbel-Softmax重参数化

def gumbel_softmax(logits, temperature=1.0, hard=True):
    """
    Gumbel-Softmax重参数化
    
    使得离散的采样操作可导
    """
    # 采样Gumbel噪声
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    
    # 加到logits上
    y = (logits + gumbel_noise) / temperature
    
    # Softmax
    y = F.softmax(y, dim=-1)
    
    if hard:
        # 硬采样:one-hot
        y_hard = torch.zeros_like(y)
        y_hard[torch.arange(y.shape[0]), y.argmax(dim=-1)] = 1
        y = (y_hard - y).detach() + y
    
    return y

2.3 Tokenizer训练策略

def train_dvae(dvae, image_loader, epochs=10):
    """训练离散VAE tokenizer"""
    
    optimizer = torch.optim.Adam(dvae.parameters(), lr=1e-4)
    
    for epoch in range(epochs):
        for images in image_loader:
            patches = patchify(images)  # [B, N, D]
            
            # 重建损失
            recon, logits, z_q = dvae(patches)
            recon_loss = F.mse_loss(recon, patches)
            
            # VAE损失
            # KL散度(先验 = 均匀分布)
            q = F.softmax(logits, dim=-1)
            kl_loss = q * (torch.log(q + 1e-20) - torch.log(1.0 / 8192))
            kl_loss = kl_loss.sum(dim=-1).mean()
            
            # 损失 = 重建 + β * KL
            loss = recon_loss + 0.1 * kl_loss
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

三、BEiT预训练

3.1 预训练目标

BEiT有两种预训练目标:

变体1:预测离散token(BEiT v1)

class BEiTPretraining(nn.Module):
    """BEiT v1: 预测离散token"""
    
    def __init__(self, vocab_size=8192, embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        
        # Transformer编码器
        self.transformer = TransformerEncoder(depth, embed_dim, num_heads)
        
        # 预测头:预测被掩码patch的token
        self.decoder = nn.Linear(embed_dim, vocab_size)
        
        # Tokenizer(冻结)
        self.dvae = load_pretrained_dvae()
        
    def forward(self, images, mask_ratio=0.4):
        # Patch化
        patches = patchify(images)  # [B, N, D]
        
        # 获取token
        with torch.no_grad():
            _, tokens = self.dvae.encode(patches)  # [B, N]
        
        # 掩码
        masked_patches, mask, ids_restore = random_masking(patches, mask_ratio)
        
        # 编码
        features = self.transformer(masked_patches)
        
        # 预测token
        pred_tokens = self.decoder(features)
        
        # 损失:交叉熵
        # 仅在被掩码位置计算
        loss = F.cross_entropy(pred_tokens[mask], tokens[mask])
        
        return loss

变体2:预测像素值 + token(BEiT v2)

class BEiTv2Pretraining(nn.Module):
    """BEiT v2: 多任务学习"""
    
    def forward(self, images, mask_ratio=0.4):
        patches = patchify(images)
        masked_patches, mask, _ = random_masking(patches, mask_ratio)
        
        features = self.transformer(masked_patches)
        
        # 任务1: 预测token
        pred_tokens = self.token_head(features)
        tokens = self.dvae.encode(patches)  # ground truth tokens
        
        # 任务2: 预测像素值
        pred_pixels = self.pixel_head(features)
        
        # 联合损失
        loss = F.cross_entropy(pred_tokens[mask], tokens[mask]) + \
               F.mse_loss(pred_pixels[mask], patches[mask])
        
        return loss

3.2 掩码策略

BEiT使用与BERT类似的块级掩码策略:

def block_masking(x, mask_ratio=0.4, block_size=4):
    """
    块级掩码:将连续的patch块一起掩码
    
    比随机掩码更难,提供更好的预训练效果
    """
    B, N, D = x.shape
    num_masked = int(N * mask_ratio)
    
    # 构建块掩码
    num_blocks = N // (block_size * block_size)
    
    # 随机选择要掩码的块
    noise = torch.rand(B, num_blocks)
    block_rank = torch.argsort(noise, dim=-1)
    
    # 前num_masked块被掩码
    blocks_to_mask = block_rank[:, :num_masked // (block_size * block_size)]
    
    mask = torch.zeros(B, N)
    for b in range(B):
        for blk in blocks_to_mask[b]:
            i = blk // block_size
            j = blk % block_size
            mask[b, i*block_size:(i+1)*block_size, 
                    j*block_size:(j+1)*block_size] = 1
    
    return mask.bool()

四、实验结果

4.1 ImageNet分类

方法预训练LinearFinetune
From scratch--68.3%
SupervisedIN-1K79.9%83.6%
MoCo v3300ep76.4%83.2%
BEiT800ep86.3%83.1%
BEiT v2800ep87.3%85.2%
MAE1600ep87.8%83.6%

4.2 消融实验

Tokenizer Vocab Size

Vocab SizeLinearFinetune
51283.2%82.4%
204885.1%82.8%
819286.3%83.1%
3276886.1%83.0%

结论:Vocab Size=8192是最优选择

掩码比例

掩码比例Linear
15%82.1%
30%84.8%
40%86.3%
50%86.0%
75%85.2%

结论:BEiT的最优掩码比例约为40%,远低于MAE的75%

4.3 下游任务迁移

任务数据集BEiTBEiT v2MAE
分类ImageNet83.1%85.2%83.6%
检测COCO48.5 AP50.3 AP50.3 AP
分割ADE20K48.1 mIoU50.1 mIoU47.4 mIoU

五、BEiT vs MAE 对比分析

5.1 核心区别

方面BEiTMAE
重建目标离散token像素值
预测难度分类(8192类)回归(256维)
Tokenizer需要预训练dVAE无需
掩码比例40%75%
编码器处理所有patch仅处理可见patch
解码器轻量(BEiT v1无解码器)重量(8层Transformer)

5.2 训练效率

指标BEiTMAE
GPU Memory较高较低
训练时间相近相近
收敛速度

5.3 表示特性

┌─────────────────────────────────────────────────────────────┐
│                    表示特性对比                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   BEiT:                                                     │
│   - 学习更多语义特征                                        │
│   - 离散token提供更好的语义监督                            │
│   - 适合分类任务                                           │
│                                                             │
│   MAE:                                                      │
│   - 学习更多低层视觉特征                                    │
│   - 像素级重建提供更丰富的监督信号                          │
│   - 适合密集预测任务                                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

六、代码实现

6.1 完整BEiT预训练

class BEiTForPretraining(nn.Module):
    """BEiT预训练模型"""
    
    def __init__(self, config):
        super().__init__()
        self.vocab_size = config.vocab_size  # 8192
        self.embed_dim = config.embed_dim    # 768
        self.depth = config.depth            # 12
        self.num_heads = config.num_heads    # 12
        
        # Patch嵌入
        self.patch_embed = nn.Linear(16*16*3, self.embed_dim)
        
        # 位置编码
        self.pos_embed = nn.Parameter(torch.zeros(1, 196, self.embed_dim))
        
        # cls token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        
        # Transformer编码器
        self.blocks = nn.ModuleList([
            TransformerBlock(self.embed_dim, self.num_heads)
            for _ in range(self.depth)
        ])
        self.norm = nn.LayerNorm(self.embed_dim)
        
        # 预测头
        self.head = nn.Linear(self.embed_dim, self.vocab_size)
        
    def forward(self, images, masked_pos=None):
        # Patch嵌入
        x = self.patchify(images)  # [B, N, D_patch]
        x = self.patch_embed(x)
        
        # 添加位置编码
        x = x + self.pos_embed
        
        # 掩码
        if masked_pos is not None:
            x[masked_pos] = 0  # 或使用可学习的mask token
        
        # Transformer编码
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        
        # 预测
        pred = self.head(x[masked_pos])
        
        return pred
    
    def patchify(self, imgs):
        p = 16
        assert imgs.shape[2] % p == 0 and imgs.shape[3] % p == 0
        x = imgs.reshape(imgs.shape[0], 3, imgs.shape[2] // p, p, imgs.shape[3] // p, p)
        x = x.permute(0, 2, 4, 3, 5, 1).reshape(imgs.shape[0], -1, p*p*3)
        return x

七、关键洞察

7.1 语义 vs 低层特征

  • BEiT通过离散token提供语义级别的监督
  • MAE通过像素值提供低层特征的监督
  • 两者可以互补,BEiT v2结合两者

7.2 掩码策略的重要性

  • 块级掩码比随机掩码更有效
  • 掩码比例需要针对任务调优

7.3 Tokenizer的作用

  • dVAE tokenizer将像素映射到语义空间
  • codebook大小影响表示质量

八、参考论文