概述

SimSiam (Simple Siamese Network) 是陈皓(Chen & He)在CVPR 2021提出的自监督学习方法1。其核心发现是:孪生网络可以在没有负样本、没有动量编码器、没有大批量的情况下学习到有意义的表示

特性SimSiamSimCLRMoCoBYOL
负样本不需要必须必须不需要
动量编码器不需要不需要需要需要
批量大小2564096256256
停止梯度

1. 核心发现

1.1 简化对比学习

SimSiam之前,自监督学习方法普遍需要以下组件之一:

  • 负样本:SimCLR通过对比正负样本对避免崩溃
  • 动量编码器:MoCo通过维护负样本队列避免崩溃
  • 大批量:SimCLR需要4096+的批量来提供足够的负样本

SimSiam的发现:孪生网络+停止梯度就足够了。

1.2 实验动机

论文进行了系统的消融实验,发现:

  1. 仅用孪生网络 + MSE损失 → 崩溃
  2. 添加预测器 改善
  3. 添加停止梯度 (stop-grad) → 避免崩溃
  4. 添加对称损失 → 加速收敛

2. 数学形式化

2.1 网络架构

SimSiam的网络架构:

                    预测器 h
                       ↑
                       │
    ┌──────────────────┴──────────────────┐
    │                                     │
    ↓                                     ↓
┌─────────┐                          ┌─────────┐
│ 编码器  │  ←─── 共享权重 ───→   │ 编码器  │
│  (共享) │                          │  (共享) │
└─────────┘                          └─────────┘
    │                                     │
    ↓                                     ↓
┌─────────┐                          ┌─────────┐
│ 投影头  │                          │ 投影头  │
│  (共享) │                          │  (共享) │
└─────────┘                          └─────────┘
    │                                     │
    ↓                                     ↓
   p1                                    p2
    │                                     │
    ↓(stop grad)                          ↓(stop grad)
   z1                                    z2

关键组件:

  • 编码器 :骨干网络 + 投影头
  • 预测器 :预测器MLP
  • 所有参数共享(不向MoCo/BYOL那样维护两个网络)

2.2 损失函数

SimSiam使用负余弦相似度作为损失:

其中 是负余弦相似度。

对称形式:

2.3 停止梯度机制

**停止梯度(Stop Gradient)**是SimSiam的核心创新:

# SimSiam伪代码
z1 = encoder(x1)        # 视图1的表示
z2 = encoder(x2)        # 视图2的表示 (stop grad)
p1 = predictor(z1)       # 预测器
p2 = predictor(z2)      # 预测器
 
# 损失:p1预测z2,但z2不参与梯度计算
loss = -cosine(p1, stopgrad(z2)) / 2

2.4 预测器结构

预测器是一个MLP,结构与BYOL类似:

其中 是ReLU激活函数。预测器输出通常与投影表示维度相同。


3. PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50
 
 
class SimSiam(nn.Module):
    """SimSiam模型"""
    def __init__(self, backbone, proj_dim=2048, pred_dim=512):
        super().__init__()
        
        # 编码器:骨干网络 + 投影头
        self.backbone = backbone
        self.projector = nn.Sequential(
            nn.Linear(backbone.output_dim, proj_dim, bias=False),
            nn.BatchNorm1d(proj_dim),
            nn.ReLU(inplace=True),
            nn.Linear(proj_dim, proj_dim, bias=False),
            nn.BatchNorm1d(proj_dim, affine=False)  # 预测器输入不学习
        )
        
        # 预测器
        self.predictor = nn.Sequential(
            nn.Linear(proj_dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
            nn.Linear(pred_dim, proj_dim)
        )
        
        # 权重初始化
        self._init_weights()
    
    def _init_weights(self):
        """权重初始化"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                if m.affine:
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
    
    def forward(self, x1, x2):
        """
        Args:
            x1: 视图1
            x2: 视图2
        Returns:
            loss: SimSiam损失
            z1, z2: 投影表示(用于可视化)
        """
        # 编码
        z1 = self.projector(self.backbone(x1))
        z2 = self.projector(self.backbone(x2))
        
        # 预测
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        
        # 计算损失(stop gradient在计算图中自动处理)
        loss = self.sim_siam_loss(p1, z2.detach()) + \
               self.sim_siam_loss(p2, z1.detach())
        
        return loss / 2, z1, z2
    
    def sim_siam_loss(self, p, z):
        """SimSiam损失:负余弦相似度"""
        return - F.cosine_similarity(p, z, dim=-1).mean()
    
    @torch.no_grad()
    def eval_encode(self, x):
        """评估时使用的编码函数"""
        z = self.projector(self.backbone(x))
        return F.normalize(z, dim=-1)
 
 
class SimSiamLoss(nn.Module):
    """独立的SimSiam损失"""
    def __init__(self):
        super().__init__()
    
    def forward(self, p1, z1, p2, z2):
        """
        Args:
            p1, p2: 预测器输出
            z1, z2: 投影表示(z1需要detach,z2在外部detach)
        """
        loss = - F.cosine_similarity(p1, z1.detach(), dim=-1).mean() - \
               F.cosine_similarity(p2, z2.detach(), dim=-1).mean()
        return loss / 2
 
 
def train_simsiam(model, train_loader, optimizer, scheduler=None, 
                  epochs=100, device='cuda'):
    """SimSiam训练循环"""
    model = model.to(device)
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (images, _) in enumerate(train_loader):
            images = images.to(device)
            
            # 创建两个视图
            x1 = augment(images)
            x2 = augment(images)
            
            # 前向传播
            loss, z1, z2 = model(x1, x2)
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            # 梯度裁剪(防止爆炸)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            if scheduler is not None:
                scheduler.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")
 
 
def augment(x):
    """数据增强"""
    from torchvision import transforms
    
    return transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=11, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])(x)

解耦版本(便于理解)

def simsiam_loss_decomposed(p, z):
    """
    分解的SimSiam损失
    p: 预测器输出
    z: 投影表示(需要detach)
    """
    # 归一化
    p = F.normalize(p, dim=-1)
    z = F.normalize(z, dim=-1)
    
    # 负余弦相似度
    return -(p * z).sum(dim=-1).mean()
 
 
def forward_step(x1, x2, encoder, predictor):
    """SimSiam前向步骤"""
    # 编码(两个分支共享权重)
    z1 = encoder(x1)  # stop grad
    z2 = encoder(x2)  # stop grad
    
    # 预测
    p1 = predictor(z1)
    p2 = predictor(z2)
    
    # 损失
    loss = simsiam_loss_decomposed(p1, z2.detach()) + \
           simsiam_loss_decomposed(p2, z1.detach())
    
    return loss / 2

4. 训练动态分析

4.1 崩溃避免机制

SimSiam为什么不会崩溃?直觉解释:

  1. 非对称性:预测器 打破了完美对称性
  2. 停止梯度:z2不参与反向传播,形成”伪目标”
  3. 孪生网络:两个视图的表示应该相似

4.2 训练阶段

SimSiam训练可分为两个阶段:

阶段描述
崩溃阶段模型快速收敛到常数表示,损失高
解决阶段停止梯度开始生效,模型逐渐学习
平衡阶段预测器和编码器达到平衡

4.3 对称损失的影响

配置效果
无对称收敛慢,可能不稳定
有对称收敛快,更稳定

5. 与其他方法的对比

5.1 架构对比

方法编码器数量预测器停止梯度
SimCLR1 (共享)
MoCo2
BYOL2 (EMA)
SwAV1 (共享)有(聚类)
SimSiam1 (共享)

5.2 组件消融

配置Top-1准确率
基线(无预测器)0.0% (崩溃)
+ 预测器53.2%
+ 停止梯度68.1%
+ 对称损失71.3%
完整SimSiam71.3%

5.3 批量大小敏感性

批量大小SimSiamSimCLRMoCo
6467.6%-58.7%
12869.2%-63.2%
25670.4%66.0%66.1%
51271.0%68.3%67.4%
204871.3%70.9%-

6. 实验结果

6.1 ImageNet线性评估

方法Top-1准确率
SimSiam71.3%
SimCLR71.7%
MoCo v271.1%
BYOL74.3%
SwAV72.5%
Supervised76.5%

6.2 迁移学习

数据集SimSiamSupervised
CIFAR-1091.3%95.0%
CIFAR-10064.3%78.3%
STL-1085.9%89.4%
ImageNet检测44.8%46.9%

6.3 下游微调

SimSiam的表示适合在线性探测和微调两种设置:

# 线性探测
with torch.no_grad():
    features = encoder(images)
classifier = nn.Linear(2048, num_classes)
# 只训练classifier
 
# 微调
# 训练整个encoder + classifier

7. 实践技巧

7.1 超参数配置

参数推荐值说明
预测器维度512中间层
投影维度2048输出维度
学习率0.05配合余弦调度
权重衰减1e-4
BatchNormmomentum=0.1比标准更高
梯度裁剪max_norm=1.0防止爆炸

7.2 学习率调度

SimSiam推荐使用余弦退火调度:

epochs = 800
lr_decay = 0.05
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, 
                             momentum=0.9, weight_decay=1e-4)
 
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs, eta_min=lr_decay
)

7.3 BatchNorm位置

projector = nn.Sequential(
    nn.Linear(backbone.output_dim, proj_dim, bias=False),
    nn.BatchNorm1d(proj_dim),  # 在Linear后
    nn.ReLU(inplace=True),
    nn.Linear(proj_dim, proj_dim, bias=False),
    nn.BatchNorm1d(proj_dim, affine=False)  # 最后一层无affine
)

8. 理论分析

8.1 梯度流分析

SimSiam的梯度流:

z2 ←────── 反向传播
 │
 ↓ stop grad
p1 ←────── 反向传播
 │
 ↓
h
 │
 ↓
z1 ←────── 无梯度

8.2 等效优化

SimSiam可以解释为在执行:

  1. EM风格的优化:预测器类EM中的E步,编码器类M步
  2. 伪目标学习:z2作为p1的伪目标
  3. 知识蒸馏:预测器蒸馏编码器

8.3 崩溃的数学条件

当模型崩溃时:

损失变为:


9. 总结

SimSiam的核心贡献:

  1. 最小主义设计:仅用孪生网络 + 停止梯度实现表示学习
  2. 理论洞察:揭示了停止梯度的关键作用
  3. 工程简洁:无需负样本、无需动量编码器、无需大批量
  4. 训练稳定:梯度裁剪即可稳定训练

SimSiam证明了非对称预测器 + 停止梯度是避免崩溃的充分条件,为理解自监督学习提供了重要的理论框架。


参考

Footnotes

  1. Chen, X., & He, K. (2021). “Exploring Simple Siamese Representation Learning”. CVPR 2021. arXiv:2011.10566