概述
BYOL (Bootstrap Your Own Latent) 是一种自监督表示学习方法,由DeepMind团队在NeurIPS 2020提出1。其核心创新在于无需负样本对即可学习高质量表示,打破了对比学习对大批量和负样本的依赖。
| 特性 | BYOL | SimCLR |
|---|---|---|
| 负样本 | 不需要 | 必须 |
| 批量大小 | 较小(256-4096) | 较大(4096+) |
| 网络结构 | 非对称孪生网络 | 对称孪生网络 |
| 崩溃缓解 | 预测器+EMA | 对比损失 |
1. 核心思想
1.1 为什么不需要负样本?
传统对比学习(如SimCLR)需要负样本来防止表示崩溃(collapse)。负样本通过推开不同样本的表示来促使模型学习有意义的特征。
BYOL的核心洞察:如果能预测好的表示,自然就不需要负样本。
“We aim to learn a representation from images, such that a classifier trained on yields good classification performance.”
BYOL通过非对称网络结构和在线预测器实现这一点。
1.2 非对称孪生网络
预测器 (MLP)
↑
│
┌──────────────────┴──────────────────┐
│ │
↓ ↓
┌─────────┐ ┌─────────┐
│ 在线 │ │ 目标 │
│ 网络 θ │ ←── EMA更新 ──→ │ 网络 ξ │
│ f_θ │ │ f_ξ │
└─────────┘ └─────────┘
│ │
↓ ↓
┌─────────┐ ┌─────────┐
│ 投影头 │ │ 投影头 │
│ g_θ │ │ g_ξ │
└─────────┘ └─────────┘
│ │
↓ ↓
y_θ y_ξ
2. 数学形式化
2.1 网络架构
BYOL包含两个网络:在线网络(online)和目标网络(target)。
在线网络
在线网络由三部分组成:
- 骨干网络 (如ResNet)
- 投影头 (MLP)
- 预测器 (MLP)
目标网络
目标网络是在线网络的”老师”,结构相同但参数不同:
- 骨干网络
- 投影头
2.2 损失函数
BYOL使用均方误差作为损失函数:
由于我们希望在线网络预测目标网络,所以损失函数为:
等价于:
2.3 目标网络更新
目标网络通过**指数移动平均(EMA)**更新在线网络参数:
其中 是动量系数,通常设置为0.996或使用余弦调度。
2.4 停止梯度
关键创新:目标网络的输出 不参与梯度计算。只有 的参数 被更新。
3. PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50
class MLPHead(nn.Module):
"""投影头和预测器"""
def __init__(self, in_dim, hidden_dim=4096, out_dim=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, out_dim)
)
def forward(self, x):
return F.normalize(self.mlp(x), dim=-1)
class Predictor(nn.Module):
"""预测器 - 与投影头结构相同"""
def __init__(self, in_dim, hidden_dim=4096, out_dim=256):
super().__init__()
self.predictor = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, out_dim)
)
def forward(self, x):
return self.predictor(x)
class BYOL(nn.Module):
"""BYOL模型"""
def __init__(self, backbone, hidden_dim=4096, proj_dim=256, tau=0.996):
super().__init__()
self.tau = tau
# 在线网络
self.online_backbone = backbone
self.online_projector = MLPHead(backbone.output_dim, hidden_dim, proj_dim)
self.predictor = Predictor(proj_dim, hidden_dim, proj_dim)
# 目标网络(初始与在线网络相同)
self.target_backbone = self._copy_weights(backbone)
self.target_projector = self._copy_weights(self.online_projector)
# 冻结目标网络参数
for param in self.target_backbone.parameters():
param.requires_grad = False
for param in self.target_projector.parameters():
param.requires_grad = False
def _copy_weights(self, module):
new_module = type(module)(*module.args, **module.kwargs)
new_module.load_state_dict(module.state_dict())
return new_module
@torch.no_grad()
def _update_target_network(self):
"""EMA更新目标网络"""
for (name, online_p), target_p in zip(
self.online_backbone.named_parameters() + self.online_projector.named_parameters(),
self.target_backbone.parameters() + self.target_projector.parameters()
):
target_p.data = self.tau * target_p.data + (1 - self.tau) * online_p.data
def forward(self, x1, x2, update_target=True):
"""
Args:
x1: 视图1
x2: 视图2
update_target: 是否更新目标网络
"""
# 在线网络处理两个视图
z1 = self.predictor(self.online_projector(self.online_backbone(x1)))
z2 = self.predictor(self.online_projector(self.online_backbone(x2)))
# 目标网络处理两个视图(不计算梯度)
with torch.no_grad():
y1 = self.target_projector(self.target_backbone(x1))
y2 = self.target_projector(self.target_backbone(x2))
# 计算损失:z1预测y2,z2预测y1
loss = 2 - 2 * (
F.cosine_similarity(z1, y2.detach(), dim=-1).mean() +
F.cosine_similarity(z2, y1.detach(), dim=-1).mean()
) / 2
# EMA更新目标网络
if update_target:
self._update_target_network()
return loss
def byol_loss(z1, y2):
"""简化的BYOL损失"""
return 2 - 2 * F.cosine_similarity(z1, y2.detach(), dim=-1).mean()训练循环
def train_byol(model, train_loader, optimizer, epochs=100, device='cuda'):
model = model.to(device)
model.train()
for epoch in range(epochs):
total_loss = 0
for batch_idx, (images, _) in enumerate(train_loader):
images = images.to(device)
# 通过数据增强获取两个视图
x1 = augment(images, view=1)
x2 = augment(images, view=2)
# 前向传播
loss = model(x1, x2)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")
def augment(x, view=1):
"""BYOL数据增强"""
from torchvision import transforms
if view == 1:
return transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])(x)
else:
return transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])(x)4. 训练动态分析
4.1 为什么BYOL不会崩溃?
BYOL避免崩溃的关键机制:
- 非对称性:预测器 只存在于在线网络中
- EMA更新:目标网络缓慢跟随在线网络
- 停止梯度:目标表示不参与反向传播
数学直觉:BYOL实际上在优化一个对比目标,但通过在线/目标网络的分离隐式实现。
4.2 预测器的作用
预测器 是防止崩溃的关键:
| 组件 | 作用 |
|---|---|
| 无预测器 | 模型学习常数表示(崩溃) |
| 随机预测器 | 部分缓解崩溃 |
| 可学习预测器 | 完全避免崩溃 |
4.3 动量系数调度
BYOL的原始论文使用固定的 ,但后续研究表明:
其中 是当前步数, 是总步数, 通常设为0.99或1.0。
5. 与其他方法的对比
5.1 SimCLR vs BYOL
| 特性 | SimCLR | BYOL |
|---|---|---|
| 损失函数 | InfoNCE (对比) | MSE (预测) |
| 负样本 | 需要 | 不需要 |
| 批量大小 | 4096 | 256-4096 |
| 内存占用 | 高 | 中等 |
| 投影头 | 2层MLP | 3层MLP |
| 预测器 | 无 | 有 |
5.2 SwAV vs BYOL
SwAV通过聚类实现无负样本学习,与BYOL的方法论不同但目标相同。
6. 实验结果
6.1 ImageNet线性评估
| 方法 | Top-1准确率 |
|---|---|
| BYOL | 74.3% |
| SimCLR | 70.7% |
| MoCo | 60.6% |
| Supervised | 76.5% |
6.2 迁移学习结果
BYOL在多个下游任务上展现出优秀的迁移性能:
| 数据集 | BYOL vs 监督 |
|---|---|
| Fine-grained分类 | 相当或更好 |
| 检测/分割 | 相当 |
| 医学影像 | 通常更好 |
7. 实践技巧
7.1 增强策略
BYOL的增强组合与SimCLR类似,但高斯模糊是BYOL的重要组成部分:
BYOL_AUGMENTATION = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=23), # 关键增强
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])7.2 超参数建议
| 参数 | 推荐值 |
|---|---|
| 批量大小 | 256-4096 |
| 投影维度 | 256-4096 |
| 预测器隐藏层 | 4096 |
| EMA动量 | 0.996-0.999 |
| 学习率 | 0.05 (batch_size=256) |
| 权重衰减 | 1e-6 |
7.3 LARS优化器
BYOL原论文使用LARS优化器:
optimizer = LARS(
[model.online_backbone, model.online_projector, model.predictor],
lr=0.05 * batch_size / 256,
weight_decay=1e-6,
momentum=0.9
)8. 变体与扩展
8.1 BYOL-S
BYOL-S使用孪生网络架构的变体,通过权重共享简化模型。
8.2 BYOL-A
BYOL-A针对移动设备优化,使用更轻量的骨干网络(如MobileNet)。
8.3 联合训练
BYOL可以与其他目标联合训练:
9. 总结
BYOL的核心贡献:
- 无需负样本:通过非对称网络结构避免崩溃
- EMA目标网络:稳定的训练动态
- 预测器机制:引导在线网络学习更好的表示
- 小批量友好:降低对大批量的依赖
BYOL证明了预测另一个表示(而不只是区分不同样本)足以学习好的表示,为自监督学习开辟了新方向。
参考
Footnotes
-
Grill, J. B., et al. (2020). “Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning”. NeurIPS 2020. arXiv:2006.07733 ↩