概述

BYOL (Bootstrap Your Own Latent) 是一种自监督表示学习方法,由DeepMind团队在NeurIPS 2020提出1。其核心创新在于无需负样本对即可学习高质量表示,打破了对比学习对大批量和负样本的依赖。

特性BYOLSimCLR
负样本不需要必须
批量大小较小(256-4096)较大(4096+)
网络结构非对称孪生网络对称孪生网络
崩溃缓解预测器+EMA对比损失

1. 核心思想

1.1 为什么不需要负样本?

传统对比学习(如SimCLR)需要负样本来防止表示崩溃(collapse)。负样本通过推开不同样本的表示来促使模型学习有意义的特征。

BYOL的核心洞察:如果能预测好的表示,自然就不需要负样本

“We aim to learn a representation from images, such that a classifier trained on yields good classification performance.”

BYOL通过非对称网络结构在线预测器实现这一点。

1.2 非对称孪生网络

                    预测器 (MLP)
                       ↑
                       │
    ┌──────────────────┴──────────────────┐
    │                                     │
    ↓                                     ↓
┌─────────┐                          ┌─────────┐
│ 在线    │                          │ 目标    │
│ 网络 θ  │  ←── EMA更新 ──→       │ 网络 ξ  │
│ f_θ     │                          │ f_ξ     │
└─────────┘                          └─────────┘
    │                                     │
    ↓                                     ↓
┌─────────┐                          ┌─────────┐
│ 投影头  │                          │ 投影头  │
│ g_θ     │                          │ g_ξ     │
└─────────┘                          └─────────┘
    │                                     │
    ↓                                     ↓
  y_θ                                   y_ξ

2. 数学形式化

2.1 网络架构

BYOL包含两个网络:在线网络(online)和目标网络(target)。

在线网络

在线网络由三部分组成:

  • 骨干网络 (如ResNet)
  • 投影头 (MLP)
  • 预测器 (MLP)

目标网络

目标网络是在线网络的”老师”,结构相同但参数不同:

  • 骨干网络
  • 投影头

2.2 损失函数

BYOL使用均方误差作为损失函数:

由于我们希望在线网络预测目标网络,所以损失函数为:

等价于:

2.3 目标网络更新

目标网络通过**指数移动平均(EMA)**更新在线网络参数:

其中 是动量系数,通常设置为0.996或使用余弦调度。

2.4 停止梯度

关键创新:目标网络的输出 不参与梯度计算。只有 的参数 被更新。


3. PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50
 
 
class MLPHead(nn.Module):
    """投影头和预测器"""
    def __init__(self, in_dim, hidden_dim=4096, out_dim=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )
    
    def forward(self, x):
        return F.normalize(self.mlp(x), dim=-1)
 
 
class Predictor(nn.Module):
    """预测器 - 与投影头结构相同"""
    def __init__(self, in_dim, hidden_dim=4096, out_dim=256):
        super().__init__()
        self.predictor = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )
    
    def forward(self, x):
        return self.predictor(x)
 
 
class BYOL(nn.Module):
    """BYOL模型"""
    def __init__(self, backbone, hidden_dim=4096, proj_dim=256, tau=0.996):
        super().__init__()
        self.tau = tau
        
        # 在线网络
        self.online_backbone = backbone
        self.online_projector = MLPHead(backbone.output_dim, hidden_dim, proj_dim)
        self.predictor = Predictor(proj_dim, hidden_dim, proj_dim)
        
        # 目标网络(初始与在线网络相同)
        self.target_backbone = self._copy_weights(backbone)
        self.target_projector = self._copy_weights(self.online_projector)
        
        # 冻结目标网络参数
        for param in self.target_backbone.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False
    
    def _copy_weights(self, module):
        new_module = type(module)(*module.args, **module.kwargs)
        new_module.load_state_dict(module.state_dict())
        return new_module
    
    @torch.no_grad()
    def _update_target_network(self):
        """EMA更新目标网络"""
        for (name, online_p), target_p in zip(
            self.online_backbone.named_parameters() + self.online_projector.named_parameters(),
            self.target_backbone.parameters() + self.target_projector.parameters()
        ):
            target_p.data = self.tau * target_p.data + (1 - self.tau) * online_p.data
    
    def forward(self, x1, x2, update_target=True):
        """
        Args:
            x1: 视图1
            x2: 视图2
            update_target: 是否更新目标网络
        """
        # 在线网络处理两个视图
        z1 = self.predictor(self.online_projector(self.online_backbone(x1)))
        z2 = self.predictor(self.online_projector(self.online_backbone(x2)))
        
        # 目标网络处理两个视图(不计算梯度)
        with torch.no_grad():
            y1 = self.target_projector(self.target_backbone(x1))
            y2 = self.target_projector(self.target_backbone(x2))
        
        # 计算损失:z1预测y2,z2预测y1
        loss = 2 - 2 * (
            F.cosine_similarity(z1, y2.detach(), dim=-1).mean() +
            F.cosine_similarity(z2, y1.detach(), dim=-1).mean()
        ) / 2
        
        # EMA更新目标网络
        if update_target:
            self._update_target_network()
        
        return loss
 
 
def byol_loss(z1, y2):
    """简化的BYOL损失"""
    return 2 - 2 * F.cosine_similarity(z1, y2.detach(), dim=-1).mean()

训练循环

def train_byol(model, train_loader, optimizer, epochs=100, device='cuda'):
    model = model.to(device)
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (images, _) in enumerate(train_loader):
            images = images.to(device)
            
            # 通过数据增强获取两个视图
            x1 = augment(images, view=1)
            x2 = augment(images, view=2)
            
            # 前向传播
            loss = model(x1, x2)
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")
 
 
def augment(x, view=1):
    """BYOL数据增强"""
    from torchvision import transforms
    
    if view == 1:
        return transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])(x)
    else:
        return transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])(x)

4. 训练动态分析

4.1 为什么BYOL不会崩溃?

BYOL避免崩溃的关键机制:

  1. 非对称性:预测器 只存在于在线网络中
  2. EMA更新:目标网络缓慢跟随在线网络
  3. 停止梯度:目标表示不参与反向传播

数学直觉:BYOL实际上在优化一个对比目标,但通过在线/目标网络的分离隐式实现。

4.2 预测器的作用

预测器 是防止崩溃的关键:

组件作用
无预测器模型学习常数表示(崩溃)
随机预测器部分缓解崩溃
可学习预测器完全避免崩溃

4.3 动量系数调度

BYOL的原始论文使用固定的 ,但后续研究表明:

其中 是当前步数, 是总步数, 通常设为0.99或1.0。


5. 与其他方法的对比

5.1 SimCLR vs BYOL

特性SimCLRBYOL
损失函数InfoNCE (对比)MSE (预测)
负样本需要不需要
批量大小4096256-4096
内存占用中等
投影头2层MLP3层MLP
预测器

5.2 SwAV vs BYOL

SwAV通过聚类实现无负样本学习,与BYOL的方法论不同但目标相同。


6. 实验结果

6.1 ImageNet线性评估

方法Top-1准确率
BYOL74.3%
SimCLR70.7%
MoCo60.6%
Supervised76.5%

6.2 迁移学习结果

BYOL在多个下游任务上展现出优秀的迁移性能:

数据集BYOL vs 监督
Fine-grained分类相当或更好
检测/分割相当
医学影像通常更好

7. 实践技巧

7.1 增强策略

BYOL的增强组合与SimCLR类似,但高斯模糊是BYOL的重要组成部分:

BYOL_AUGMENTATION = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=23),  # 关键增强
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
])

7.2 超参数建议

参数推荐值
批量大小256-4096
投影维度256-4096
预测器隐藏层4096
EMA动量0.996-0.999
学习率0.05 (batch_size=256)
权重衰减1e-6

7.3 LARS优化器

BYOL原论文使用LARS优化器:

optimizer = LARS(
    [model.online_backbone, model.online_projector, model.predictor],
    lr=0.05 * batch_size / 256,
    weight_decay=1e-6,
    momentum=0.9
)

8. 变体与扩展

8.1 BYOL-S

BYOL-S使用孪生网络架构的变体,通过权重共享简化模型。

8.2 BYOL-A

BYOL-A针对移动设备优化,使用更轻量的骨干网络(如MobileNet)。

8.3 联合训练

BYOL可以与其他目标联合训练:


9. 总结

BYOL的核心贡献:

  1. 无需负样本:通过非对称网络结构避免崩溃
  2. EMA目标网络:稳定的训练动态
  3. 预测器机制:引导在线网络学习更好的表示
  4. 小批量友好:降低对大批量的依赖

BYOL证明了预测另一个表示(而不只是区分不同样本)足以学习好的表示,为自监督学习开辟了新方向。


参考

Footnotes

  1. Grill, J. B., et al. (2020). “Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning”. NeurIPS 2020. arXiv:2006.07733