概述
SimSiam (Simple Siamese Network) 是陈皓(Chen & He)在CVPR 2021提出的自监督学习方法1。其核心发现是:孪生网络可以在没有负样本、没有动量编码器、没有大批量的情况下学习到有意义的表示。
| 特性 | SimSiam | SimCLR | MoCo | BYOL |
|---|---|---|---|---|
| 负样本 | 不需要 | 必须 | 必须 | 不需要 |
| 动量编码器 | 不需要 | 不需要 | 需要 | 需要 |
| 批量大小 | 256 | 4096 | 256 | 256 |
| 停止梯度 | 有 | 无 | 无 | 无 |
1. 核心发现
1.1 简化对比学习
SimSiam之前,自监督学习方法普遍需要以下组件之一:
- 负样本:SimCLR通过对比正负样本对避免崩溃
- 动量编码器:MoCo通过维护负样本队列避免崩溃
- 大批量:SimCLR需要4096+的批量来提供足够的负样本
SimSiam的发现:孪生网络+停止梯度就足够了。
1.2 实验动机
论文进行了系统的消融实验,发现:
- 仅用孪生网络 + MSE损失 → 崩溃
- 添加预测器 → 改善
- 添加停止梯度 (stop-grad) → 避免崩溃
- 添加对称损失 → 加速收敛
2. 数学形式化
2.1 网络架构
SimSiam的网络架构:
预测器 h
↑
│
┌──────────────────┴──────────────────┐
│ │
↓ ↓
┌─────────┐ ┌─────────┐
│ 编码器 │ ←─── 共享权重 ───→ │ 编码器 │
│ (共享) │ │ (共享) │
└─────────┘ └─────────┘
│ │
↓ ↓
┌─────────┐ ┌─────────┐
│ 投影头 │ │ 投影头 │
│ (共享) │ │ (共享) │
└─────────┘ └─────────┘
│ │
↓ ↓
p1 p2
│ │
↓(stop grad) ↓(stop grad)
z1 z2
关键组件:
- 编码器 :骨干网络 + 投影头
- 预测器 :预测器MLP
- 所有参数共享(不向MoCo/BYOL那样维护两个网络)
2.2 损失函数
SimSiam使用负余弦相似度作为损失:
其中 是负余弦相似度。
对称形式:
2.3 停止梯度机制
**停止梯度(Stop Gradient)**是SimSiam的核心创新:
# SimSiam伪代码
z1 = encoder(x1) # 视图1的表示
z2 = encoder(x2) # 视图2的表示 (stop grad)
p1 = predictor(z1) # 预测器
p2 = predictor(z2) # 预测器
# 损失:p1预测z2,但z2不参与梯度计算
loss = -cosine(p1, stopgrad(z2)) / 22.4 预测器结构
预测器是一个MLP,结构与BYOL类似:
其中 是ReLU激活函数。预测器输出通常与投影表示维度相同。
3. PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50
class SimSiam(nn.Module):
"""SimSiam模型"""
def __init__(self, backbone, proj_dim=2048, pred_dim=512):
super().__init__()
# 编码器:骨干网络 + 投影头
self.backbone = backbone
self.projector = nn.Sequential(
nn.Linear(backbone.output_dim, proj_dim, bias=False),
nn.BatchNorm1d(proj_dim),
nn.ReLU(inplace=True),
nn.Linear(proj_dim, proj_dim, bias=False),
nn.BatchNorm1d(proj_dim, affine=False) # 预测器输入不学习
)
# 预测器
self.predictor = nn.Sequential(
nn.Linear(proj_dim, pred_dim, bias=False),
nn.BatchNorm1d(pred_dim),
nn.ReLU(inplace=True),
nn.Linear(pred_dim, proj_dim)
)
# 权重初始化
self._init_weights()
def _init_weights(self):
"""权重初始化"""
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
if m.affine:
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x1, x2):
"""
Args:
x1: 视图1
x2: 视图2
Returns:
loss: SimSiam损失
z1, z2: 投影表示(用于可视化)
"""
# 编码
z1 = self.projector(self.backbone(x1))
z2 = self.projector(self.backbone(x2))
# 预测
p1 = self.predictor(z1)
p2 = self.predictor(z2)
# 计算损失(stop gradient在计算图中自动处理)
loss = self.sim_siam_loss(p1, z2.detach()) + \
self.sim_siam_loss(p2, z1.detach())
return loss / 2, z1, z2
def sim_siam_loss(self, p, z):
"""SimSiam损失:负余弦相似度"""
return - F.cosine_similarity(p, z, dim=-1).mean()
@torch.no_grad()
def eval_encode(self, x):
"""评估时使用的编码函数"""
z = self.projector(self.backbone(x))
return F.normalize(z, dim=-1)
class SimSiamLoss(nn.Module):
"""独立的SimSiam损失"""
def __init__(self):
super().__init__()
def forward(self, p1, z1, p2, z2):
"""
Args:
p1, p2: 预测器输出
z1, z2: 投影表示(z1需要detach,z2在外部detach)
"""
loss = - F.cosine_similarity(p1, z1.detach(), dim=-1).mean() - \
F.cosine_similarity(p2, z2.detach(), dim=-1).mean()
return loss / 2
def train_simsiam(model, train_loader, optimizer, scheduler=None,
epochs=100, device='cuda'):
"""SimSiam训练循环"""
model = model.to(device)
model.train()
for epoch in range(epochs):
total_loss = 0
for batch_idx, (images, _) in enumerate(train_loader):
images = images.to(device)
# 创建两个视图
x1 = augment(images)
x2 = augment(images)
# 前向传播
loss, z1, z2 = model(x1, x2)
# 反向传播
optimizer.zero_grad()
loss.backward()
# 梯度裁剪(防止爆炸)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
if scheduler is not None:
scheduler.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")
def augment(x):
"""数据增强"""
from torchvision import transforms
return transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=11, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])(x)解耦版本(便于理解)
def simsiam_loss_decomposed(p, z):
"""
分解的SimSiam损失
p: 预测器输出
z: 投影表示(需要detach)
"""
# 归一化
p = F.normalize(p, dim=-1)
z = F.normalize(z, dim=-1)
# 负余弦相似度
return -(p * z).sum(dim=-1).mean()
def forward_step(x1, x2, encoder, predictor):
"""SimSiam前向步骤"""
# 编码(两个分支共享权重)
z1 = encoder(x1) # stop grad
z2 = encoder(x2) # stop grad
# 预测
p1 = predictor(z1)
p2 = predictor(z2)
# 损失
loss = simsiam_loss_decomposed(p1, z2.detach()) + \
simsiam_loss_decomposed(p2, z1.detach())
return loss / 24. 训练动态分析
4.1 崩溃避免机制
SimSiam为什么不会崩溃?直觉解释:
- 非对称性:预测器 打破了完美对称性
- 停止梯度:z2不参与反向传播,形成”伪目标”
- 孪生网络:两个视图的表示应该相似
4.2 训练阶段
SimSiam训练可分为两个阶段:
| 阶段 | 描述 |
|---|---|
| 崩溃阶段 | 模型快速收敛到常数表示,损失高 |
| 解决阶段 | 停止梯度开始生效,模型逐渐学习 |
| 平衡阶段 | 预测器和编码器达到平衡 |
4.3 对称损失的影响
| 配置 | 效果 |
|---|---|
| 无对称 | 收敛慢,可能不稳定 |
| 有对称 | 收敛快,更稳定 |
5. 与其他方法的对比
5.1 架构对比
| 方法 | 编码器数量 | 预测器 | 停止梯度 |
|---|---|---|---|
| SimCLR | 1 (共享) | 无 | 无 |
| MoCo | 2 | 无 | 无 |
| BYOL | 2 (EMA) | 有 | 无 |
| SwAV | 1 (共享) | 有(聚类) | 无 |
| SimSiam | 1 (共享) | 有 | 有 |
5.2 组件消融
| 配置 | Top-1准确率 |
|---|---|
| 基线(无预测器) | 0.0% (崩溃) |
| + 预测器 | 53.2% |
| + 停止梯度 | 68.1% |
| + 对称损失 | 71.3% |
| 完整SimSiam | 71.3% |
5.3 批量大小敏感性
| 批量大小 | SimSiam | SimCLR | MoCo |
|---|---|---|---|
| 64 | 67.6% | - | 58.7% |
| 128 | 69.2% | - | 63.2% |
| 256 | 70.4% | 66.0% | 66.1% |
| 512 | 71.0% | 68.3% | 67.4% |
| 2048 | 71.3% | 70.9% | - |
6. 实验结果
6.1 ImageNet线性评估
| 方法 | Top-1准确率 |
|---|---|
| SimSiam | 71.3% |
| SimCLR | 71.7% |
| MoCo v2 | 71.1% |
| BYOL | 74.3% |
| SwAV | 72.5% |
| Supervised | 76.5% |
6.2 迁移学习
| 数据集 | SimSiam | Supervised |
|---|---|---|
| CIFAR-10 | 91.3% | 95.0% |
| CIFAR-100 | 64.3% | 78.3% |
| STL-10 | 85.9% | 89.4% |
| ImageNet检测 | 44.8% | 46.9% |
6.3 下游微调
SimSiam的表示适合在线性探测和微调两种设置:
# 线性探测
with torch.no_grad():
features = encoder(images)
classifier = nn.Linear(2048, num_classes)
# 只训练classifier
# 微调
# 训练整个encoder + classifier7. 实践技巧
7.1 超参数配置
| 参数 | 推荐值 | 说明 |
|---|---|---|
| 预测器维度 | 512 | 中间层 |
| 投影维度 | 2048 | 输出维度 |
| 学习率 | 0.05 | 配合余弦调度 |
| 权重衰减 | 1e-4 | |
| BatchNorm | momentum=0.1 | 比标准更高 |
| 梯度裁剪 | max_norm=1.0 | 防止爆炸 |
7.2 学习率调度
SimSiam推荐使用余弦退火调度:
epochs = 800
lr_decay = 0.05
optimizer = torch.optim.SGD(model.parameters(), lr=0.05,
momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=epochs, eta_min=lr_decay
)7.3 BatchNorm位置
projector = nn.Sequential(
nn.Linear(backbone.output_dim, proj_dim, bias=False),
nn.BatchNorm1d(proj_dim), # 在Linear后
nn.ReLU(inplace=True),
nn.Linear(proj_dim, proj_dim, bias=False),
nn.BatchNorm1d(proj_dim, affine=False) # 最后一层无affine
)8. 理论分析
8.1 梯度流分析
SimSiam的梯度流:
z2 ←────── 反向传播
│
↓ stop grad
p1 ←────── 反向传播
│
↓
h
│
↓
z1 ←────── 无梯度
8.2 等效优化
SimSiam可以解释为在执行:
- EM风格的优化:预测器类EM中的E步,编码器类M步
- 伪目标学习:z2作为p1的伪目标
- 知识蒸馏:预测器蒸馏编码器
8.3 崩溃的数学条件
当模型崩溃时:
损失变为:
9. 总结
SimSiam的核心贡献:
- 最小主义设计:仅用孪生网络 + 停止梯度实现表示学习
- 理论洞察:揭示了停止梯度的关键作用
- 工程简洁:无需负样本、无需动量编码器、无需大批量
- 训练稳定:梯度裁剪即可稳定训练
SimSiam证明了非对称预测器 + 停止梯度是避免崩溃的充分条件,为理解自监督学习提供了重要的理论框架。
参考
Footnotes
-
Chen, X., & He, K. (2021). “Exploring Simple Siamese Representation Learning”. CVPR 2021. arXiv:2011.10566 ↩