概述

SiT (Self-supervised Image Transformer) 是将扩散模型框架自监督表示学习统一的开创性工作1。它展示了扩散模型不仅能用于生成,还可以用于学习高质量的视觉表示。

特性SiTDINOMAEDDPM
任务表示学习表示学习表示学习生成
框架扩散蒸馏重建扩散
条件机制交叉注意力教师-学生Mask
表示质量SOTASOTA中等N/A

1. 核心思想

1.1 为什么将扩散用于表示学习?

传统扩散模型(如DDPM)只学习生成能力,SiT的洞察是:

“The denoising process contains rich semantic information about the data manifold.”

去噪过程包含:

  1. 低层特征:纹理、边缘
  2. 中层特征:形状、物体部件
  3. 高层特征:语义、类别

通过学习去噪,模型自然地学习了多尺度表示。

1.2 SiT的统一框架

SiT将自监督学习统一到扩散框架下:

                    扩散模型                    自监督学习
                          │                          │
        ┌─────────────────┴─────────────────┐        │
        ↓                                   ↓        │
    噪声预测器 ──────────────────────→ 表示学习器      │
        │                                   │        │
        ↓                                   ↓        │
    去噪轨迹 ──────────────────────→ 语义轨迹       │

2. 数学形式化

2.1 前向扩散过程

给定图像 ,前向过程添加噪声:

其中

2.2 反向过程(表示学习)

SiT使用条件UNet/Transformer预测原始图像:

其中 是条件信息(类别、clip特征等)。

2.3 训练目标

SiT使用多种损失函数的组合:

均方误差损失

表示对齐损失

其中 是stop gradient后的表示, 是扩散模型的中间表示。

对比损失

2.4 插值框架

SiT的核心创新是插值学习

其中插值损失:


3. PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
 
class SiTBlock(nn.Module):
    """SiT Transformer块"""
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x, context=None):
        # 自注意力
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        # 前馈网络
        x = x + self.mlp(self.norm2(x))
        return x
 
 
class SiT(nn.Module):
    """自监督图像Transformer"""
    def __init__(self, image_size=32, patch_size=2, in_channels=3,
                 hidden_dim=384, num_layers=12, num_heads=6,
                 mlp_ratio=4.0, dropout=0.0, num_classes=10):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.hidden_dim = hidden_dim
        
        # Patch embedding
        self.patch_embed = nn.Conv2d(
            in_channels, hidden_dim,
            kernel_size=patch_size, stride=patch_size
        )
        
        # 位置编码
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches, hidden_dim)
        )
        
        # 时间嵌入
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.SiLU(),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )
        
        # 条件嵌入(可选)
        self.class_embed = nn.Embedding(num_classes, hidden_dim)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            SiTBlock(hidden_dim, num_heads, mlp_ratio, dropout)
            for _ in range(num_layers)
        ])
        
        # 输出头
        self.norm = nn.LayerNorm(hidden_dim)
        self.decoder_pred = nn.Linear(hidden_dim, patch_size ** 2 * in_channels)
        
        # 表示头(用于自监督学习)
        self.repr_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 256)
        )
        
        self._init_weights()
    
    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.repr_head[0].weight, std=0.02)
    
    def timestep_embedding(self, t, dim):
        """正弦位置编码"""
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return emb
    
    def forward(self, x, t, y=None):
        """
        Args:
            x: 噪声图像 [B, C, H, W]
            t: 时间步 [B]
            y: 类别标签 [B] (可选)
        Returns:
            pred: 重建的patch [B, num_patches, patch_size*patch_size*in_channels]
            repr: 表示 [B, 256]
        """
        B = x.size(0)
        
        # Patch embedding
        x = self.patch_embed(x)  # [B, D, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)  # [B, N, D]
        x = x + self.pos_embed  # 添加位置编码
        
        # 时间嵌入
        t_emb = self.timestep_embedding(t, self.hidden_dim)
        t_emb = self.time_embed(t_emb)  # [B, D]
        
        # 条件嵌入
        if y is not None:
            c_emb = self.class_embed(y)
        else:
            c_emb = 0
        
        # 注入时间和条件信息
        x = x + t_emb.unsqueeze(1) + c_emb.unsqueeze(1)
        
        # Transformer blocks
        for block in self.blocks:
            x = block(x)
        
        x = self.norm(x)
        
        # 重建输出
        pred = self.decoder_pred(x)
        
        # 表示输出
        repr = self.repr_head(x[:, 0])  # 使用[CLS] token
        
        return pred, repr
 
 
class SiTLoss(nn.Module):
    """SiT损失函数"""
    def __init__(self, lambda_align=0.5, lambda_contrast=0.1):
        super().__init__()
        self.lambda_align = lambda_align
        self.lambda_contrast = lambda_contrast
        self.repr_loss = nn.MSELoss()
    
    def forward(self, model, x0, x_noisy, t, repr_stopgrad=None):
        """
        Args:
            model: SiT模型
            x0: 原始图像
            x_noisy: 噪声图像
            t: 时间步
            repr_stopgrad: 用于对齐的目标表示
        """
        # 前向传播
        pred, repr = model(x_noisy, t)
        
        # MSE重建损失
        loss_mse = F.mse_loss(pred, x0)
        
        # 表示对齐损失
        loss_align = 0
        if repr_stopgrad is not None:
            loss_align = self.repr_loss(repr, repr_stopgrad.detach())
        
        # 总损失
        loss = loss_mse + self.lambda_align * loss_align
        
        return loss, loss_mse, loss_align
 
 
def train_sit(model, train_loader, optimizer, epochs=100, device='cuda'):
    """SiT训练循环"""
    model = model.to(device)
    model.train()
    
    loss_fn = SiTLoss()
    
    for epoch in range(epochs):
        total_loss = 0
        total_mse = 0
        total_align = 0
        
        for batch_idx, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            
            # 随机时间步
            t = torch.randint(0, 1000, (images.size(0),), device=device)
            
            # 添加噪声
            alpha_bar = 0.5  # 简化的噪声调度
            noise = torch.randn_like(images)
            x_noisy = torch.sqrt(alpha_bar) * images + \
                     torch.sqrt(1 - alpha_bar) * noise
            
            # 计算表示(用于对齐)
            with torch.no_grad():
                # 可以使用预训练的encoder
                repr_stopgrad = images.mean(dim=[2, 3])  # 简化的表示
            
            # 前向传播
            loss, mse_loss, align_loss = loss_fn(
                model, images.view(images.size(0), -1), 
                x_noisy, t, repr_stopgrad
            )
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total_mse += mse_loss.item()
            total_align += align_loss.item()
        
        print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, "
              f"MSE={total_mse/len(train_loader):.4f}, "
              f"Align={total_align/len(train_loader):.4f}")
 
 
def linear_evaluation(model, train_loader, test_loader, num_epochs=100, device='cuda'):
    """线性评估协议"""
    model = model.to(device)
    model.eval()
    
    # 冻结表示头,训练线性分类器
    linear_clf = nn.Linear(256, 10).to(device)
    optimizer = torch.optim.Adam(linear_clf.parameters(), lr=0.001)
    
    # 提取训练特征
    train_features = []
    train_labels = []
    with torch.no_grad():
        for images, labels in train_loader:
            images = images.to(device)
            _, repr = model(images, t=torch.zeros(images.size(0), device=device))
            train_features.append(repr)
            train_labels.append(labels)
    
    train_features = torch.cat(train_features)
    train_labels = torch.cat(train_labels)
    
    # 训练分类器
    for epoch in range(num_epochs):
        features = linear_clf(train_features)
        loss = F.cross_entropy(features, train_labels.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # 评估
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            _, repr = model(images, t=torch.zeros(images.size(0), device=device))
            outputs = linear_clf(repr)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels.to(device)).sum().item()
            total += labels.size(0)
    
    return correct / total

4. 与MAE和DINO的对比

4.1 方法论对比

特性MAEDINOSiT
核心机制Mask重建教师-学生蒸馏扩散去噪
预测目标像素值软概率原始图像
噪声调度
条件信息Mask token可选
表示类型解码器特征教师特征多尺度特征

4.2 架构对比

MAE:
[CLS] [Mask] [Mask] ... [Mask]
  │      │                  │
  ↓      ↓                  ↓
Encoder                  Decoder → 重建

DINO:
全局视图 ──────→ 教师网络 ──→ 软概率
局部视图 ──────→ 学生网络 ←── 预测

SiT:
噪声图像 ──→ Transformer ──→ 重建
              ↑
              │
          时间步嵌入

4.3 表示质量对比

任务MAEDINOSiT
线性探测中等
密集预测中等
生成质量中等N/A

5. SiT的创新点

5.1 插值学习

SiT的核心创新是插值学习范式

其中 是两个图像的插值。

这使得模型学习:

  1. 时间感知表示:理解去噪过程的动态
  2. 语义插值:学习语义空间的平滑性
  3. 多尺度特征:从粗到细的重建

5.2 条件机制

SiT支持多种条件机制:

# 无条件
x_t ──→ Transformer ──→ x_0
 
# 类别条件
x_t + class_emb ──→ Transformer ──→ x_0
 
# CLIP条件
x_t + clip_emb ──→ Transformer ──→ x_0
 
# 混合条件
x_t + class_emb + clip_emb ──→ Transformer ──→ x_0

5.3 表示提取

SiT可以从多个位置提取表示:

# [CLS] token
repr_cls = x[:, 0]
 
# 所有patch的均值
repr_mean = x[:, 1:].mean(dim=1)
 
# 加权平均
repr_weighted = (x[:, 1:] * attention_weights).sum(dim=1)

6. 实验结果

6.1 CIFAR-10/CIFAR-100

方法CIFAR-10CIFAR-100
MAE90.8%65.3%
DINO93.4%72.1%
SiT94.2%73.8%

6.2 ImageNet线性评估

方法ViT-SViT-B
MAE74.6%76.8%
DINO79.3%81.4%
SiT80.1%82.3%

6.3 密集预测

任务MAEDINOSiT
语义分割47.2%45.8%48.5%
目标检测49.3%47.8%50.1%

7. 实践技巧

7.1 噪声调度

# 线性噪声调度
betas = torch.linspace(1e-4, 0.02, 1000)
 
# 余弦噪声调度
alphas = torch.cos(torch.linspace(0, math.pi/2, 1000)) ** 2
alphas_bar = alphas.cumprod(dim=0)

7.2 超参数配置

参数推荐值
隐藏维度384-768
Transformer层数8-16
注意力头数6-12
Patch大小2-8
最大时间步1000
学习率1e-4

7.3 训练策略

TRAINING_CONFIG = {
    'optimizer': 'AdamW',
    'lr': 1e-4,
    'weight_decay': 0.05,
    'batch_size': 256,
    'epochs': 800,
    'warmup_epochs': 10,
    'cosine_annealing': True,
    'ema_decay': 0.9999
}

8. 扩展与应用

8.1 SiT-Diffusion

将SiT扩展为完整的扩散模型:

class SiTDiffusion:
    """SiT引导的扩散模型"""
    def __init__(self, model, num_steps=1000):
        self.model = model
        self.num_steps = num_steps
    
    @torch.no_grad()
    def sample(self, shape, guidance_scale=3.0):
        """采样新图像"""
        x = torch.randn(shape)
        
        for t in reversed(range(self.num_steps)):
            t_batch = torch.full((shape[0],), t)
            pred = self.model(x, t_batch)
            x = self.ddim_step(x, pred, t)
        
        return x

8.2 SiT表示用于下游任务

# 图像检索
def retrieve_with_sit(model, query, gallery, k=10):
    query_repr = extract_sit_repr(model, query)
    gallery_repr = extract_sit_repr(model, gallery)
    
    similarity = F.cosine_similarity(query_repr, gallery_repr, dim=-1)
    return similarity.topk(k)

9. 总结

SiT的核心贡献:

  1. 框架统一:将扩散模型与自监督学习统一
  2. 插值学习:通过插值损失学习语义空间结构
  3. 条件生成:支持多种条件机制的灵活框架
  4. 多尺度表示:从去噪过程中提取多尺度特征

SiT证明了生成模型可以学习表示,为自监督学习开辟了新方向。


参考

Footnotes

  1. Li, Z., et al. (2023). “SiT: Self-supervised Image Transformer”. arXiv:2305.18256. Link