概述
SiT (Self-supervised Image Transformer) 是将扩散模型框架与自监督表示学习统一的开创性工作1。它展示了扩散模型不仅能用于生成,还可以用于学习高质量的视觉表示。
| 特性 | SiT | DINO | MAE | DDPM |
|---|---|---|---|---|
| 任务 | 表示学习 | 表示学习 | 表示学习 | 生成 |
| 框架 | 扩散 | 蒸馏 | 重建 | 扩散 |
| 条件机制 | 交叉注意力 | 教师-学生 | Mask | 无 |
| 表示质量 | SOTA | SOTA | 中等 | N/A |
1. 核心思想
1.1 为什么将扩散用于表示学习?
传统扩散模型(如DDPM)只学习生成能力,SiT的洞察是:
“The denoising process contains rich semantic information about the data manifold.”
去噪过程包含:
- 低层特征:纹理、边缘
- 中层特征:形状、物体部件
- 高层特征:语义、类别
通过学习去噪,模型自然地学习了多尺度表示。
1.2 SiT的统一框架
SiT将自监督学习统一到扩散框架下:
扩散模型 自监督学习
│ │
┌─────────────────┴─────────────────┐ │
↓ ↓ │
噪声预测器 ──────────────────────→ 表示学习器 │
│ │ │
↓ ↓ │
去噪轨迹 ──────────────────────→ 语义轨迹 │
2. 数学形式化
2.1 前向扩散过程
给定图像 ,前向过程添加噪声:
其中 ,。
2.2 反向过程(表示学习)
SiT使用条件UNet/Transformer预测原始图像:
其中 是条件信息(类别、clip特征等)。
2.3 训练目标
SiT使用多种损失函数的组合:
均方误差损失
表示对齐损失
其中 是stop gradient后的表示, 是扩散模型的中间表示。
对比损失
2.4 插值框架
SiT的核心创新是插值学习:
其中插值损失:
3. PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SiTBlock(nn.Module):
"""SiT Transformer块"""
def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, context=None):
# 自注意力
x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
# 前馈网络
x = x + self.mlp(self.norm2(x))
return x
class SiT(nn.Module):
"""自监督图像Transformer"""
def __init__(self, image_size=32, patch_size=2, in_channels=3,
hidden_dim=384, num_layers=12, num_heads=6,
mlp_ratio=4.0, dropout=0.0, num_classes=10):
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size) ** 2
self.hidden_dim = hidden_dim
# Patch embedding
self.patch_embed = nn.Conv2d(
in_channels, hidden_dim,
kernel_size=patch_size, stride=patch_size
)
# 位置编码
self.pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches, hidden_dim)
)
# 时间嵌入
self.time_embed = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 4),
nn.SiLU(),
nn.Linear(hidden_dim * 4, hidden_dim)
)
# 条件嵌入(可选)
self.class_embed = nn.Embedding(num_classes, hidden_dim)
# Transformer blocks
self.blocks = nn.ModuleList([
SiTBlock(hidden_dim, num_heads, mlp_ratio, dropout)
for _ in range(num_layers)
])
# 输出头
self.norm = nn.LayerNorm(hidden_dim)
self.decoder_pred = nn.Linear(hidden_dim, patch_size ** 2 * in_channels)
# 表示头(用于自监督学习)
self.repr_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, 256)
)
self._init_weights()
def _init_weights(self):
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.repr_head[0].weight, std=0.02)
def timestep_embedding(self, t, dim):
"""正弦位置编码"""
half_dim = dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
emb = t[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
return emb
def forward(self, x, t, y=None):
"""
Args:
x: 噪声图像 [B, C, H, W]
t: 时间步 [B]
y: 类别标签 [B] (可选)
Returns:
pred: 重建的patch [B, num_patches, patch_size*patch_size*in_channels]
repr: 表示 [B, 256]
"""
B = x.size(0)
# Patch embedding
x = self.patch_embed(x) # [B, D, H/P, W/P]
x = x.flatten(2).transpose(1, 2) # [B, N, D]
x = x + self.pos_embed # 添加位置编码
# 时间嵌入
t_emb = self.timestep_embedding(t, self.hidden_dim)
t_emb = self.time_embed(t_emb) # [B, D]
# 条件嵌入
if y is not None:
c_emb = self.class_embed(y)
else:
c_emb = 0
# 注入时间和条件信息
x = x + t_emb.unsqueeze(1) + c_emb.unsqueeze(1)
# Transformer blocks
for block in self.blocks:
x = block(x)
x = self.norm(x)
# 重建输出
pred = self.decoder_pred(x)
# 表示输出
repr = self.repr_head(x[:, 0]) # 使用[CLS] token
return pred, repr
class SiTLoss(nn.Module):
"""SiT损失函数"""
def __init__(self, lambda_align=0.5, lambda_contrast=0.1):
super().__init__()
self.lambda_align = lambda_align
self.lambda_contrast = lambda_contrast
self.repr_loss = nn.MSELoss()
def forward(self, model, x0, x_noisy, t, repr_stopgrad=None):
"""
Args:
model: SiT模型
x0: 原始图像
x_noisy: 噪声图像
t: 时间步
repr_stopgrad: 用于对齐的目标表示
"""
# 前向传播
pred, repr = model(x_noisy, t)
# MSE重建损失
loss_mse = F.mse_loss(pred, x0)
# 表示对齐损失
loss_align = 0
if repr_stopgrad is not None:
loss_align = self.repr_loss(repr, repr_stopgrad.detach())
# 总损失
loss = loss_mse + self.lambda_align * loss_align
return loss, loss_mse, loss_align
def train_sit(model, train_loader, optimizer, epochs=100, device='cuda'):
"""SiT训练循环"""
model = model.to(device)
model.train()
loss_fn = SiTLoss()
for epoch in range(epochs):
total_loss = 0
total_mse = 0
total_align = 0
for batch_idx, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
# 随机时间步
t = torch.randint(0, 1000, (images.size(0),), device=device)
# 添加噪声
alpha_bar = 0.5 # 简化的噪声调度
noise = torch.randn_like(images)
x_noisy = torch.sqrt(alpha_bar) * images + \
torch.sqrt(1 - alpha_bar) * noise
# 计算表示(用于对齐)
with torch.no_grad():
# 可以使用预训练的encoder
repr_stopgrad = images.mean(dim=[2, 3]) # 简化的表示
# 前向传播
loss, mse_loss, align_loss = loss_fn(
model, images.view(images.size(0), -1),
x_noisy, t, repr_stopgrad
)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
total_mse += mse_loss.item()
total_align += align_loss.item()
print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, "
f"MSE={total_mse/len(train_loader):.4f}, "
f"Align={total_align/len(train_loader):.4f}")
def linear_evaluation(model, train_loader, test_loader, num_epochs=100, device='cuda'):
"""线性评估协议"""
model = model.to(device)
model.eval()
# 冻结表示头,训练线性分类器
linear_clf = nn.Linear(256, 10).to(device)
optimizer = torch.optim.Adam(linear_clf.parameters(), lr=0.001)
# 提取训练特征
train_features = []
train_labels = []
with torch.no_grad():
for images, labels in train_loader:
images = images.to(device)
_, repr = model(images, t=torch.zeros(images.size(0), device=device))
train_features.append(repr)
train_labels.append(labels)
train_features = torch.cat(train_features)
train_labels = torch.cat(train_labels)
# 训练分类器
for epoch in range(num_epochs):
features = linear_clf(train_features)
loss = F.cross_entropy(features, train_labels.to(device))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 评估
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images = images.to(device)
_, repr = model(images, t=torch.zeros(images.size(0), device=device))
outputs = linear_clf(repr)
_, predicted = outputs.max(1)
correct += predicted.eq(labels.to(device)).sum().item()
total += labels.size(0)
return correct / total4. 与MAE和DINO的对比
4.1 方法论对比
| 特性 | MAE | DINO | SiT |
|---|---|---|---|
| 核心机制 | Mask重建 | 教师-学生蒸馏 | 扩散去噪 |
| 预测目标 | 像素值 | 软概率 | 原始图像 |
| 噪声调度 | 无 | 无 | 有 |
| 条件信息 | Mask token | 无 | 可选 |
| 表示类型 | 解码器特征 | 教师特征 | 多尺度特征 |
4.2 架构对比
MAE:
[CLS] [Mask] [Mask] ... [Mask]
│ │ │
↓ ↓ ↓
Encoder Decoder → 重建
DINO:
全局视图 ──────→ 教师网络 ──→ 软概率
局部视图 ──────→ 学生网络 ←── 预测
SiT:
噪声图像 ──→ Transformer ──→ 重建
↑
│
时间步嵌入
4.3 表示质量对比
| 任务 | MAE | DINO | SiT |
|---|---|---|---|
| 线性探测 | 中等 | 高 | 高 |
| 密集预测 | 高 | 中等 | 高 |
| 生成质量 | 中等 | N/A | 高 |
5. SiT的创新点
5.1 插值学习
SiT的核心创新是插值学习范式:
其中 是两个图像的插值。
这使得模型学习:
- 时间感知表示:理解去噪过程的动态
- 语义插值:学习语义空间的平滑性
- 多尺度特征:从粗到细的重建
5.2 条件机制
SiT支持多种条件机制:
# 无条件
x_t ──→ Transformer ──→ x_0
# 类别条件
x_t + class_emb ──→ Transformer ──→ x_0
# CLIP条件
x_t + clip_emb ──→ Transformer ──→ x_0
# 混合条件
x_t + class_emb + clip_emb ──→ Transformer ──→ x_05.3 表示提取
SiT可以从多个位置提取表示:
# [CLS] token
repr_cls = x[:, 0]
# 所有patch的均值
repr_mean = x[:, 1:].mean(dim=1)
# 加权平均
repr_weighted = (x[:, 1:] * attention_weights).sum(dim=1)6. 实验结果
6.1 CIFAR-10/CIFAR-100
| 方法 | CIFAR-10 | CIFAR-100 |
|---|---|---|
| MAE | 90.8% | 65.3% |
| DINO | 93.4% | 72.1% |
| SiT | 94.2% | 73.8% |
6.2 ImageNet线性评估
| 方法 | ViT-S | ViT-B |
|---|---|---|
| MAE | 74.6% | 76.8% |
| DINO | 79.3% | 81.4% |
| SiT | 80.1% | 82.3% |
6.3 密集预测
| 任务 | MAE | DINO | SiT |
|---|---|---|---|
| 语义分割 | 47.2% | 45.8% | 48.5% |
| 目标检测 | 49.3% | 47.8% | 50.1% |
7. 实践技巧
7.1 噪声调度
# 线性噪声调度
betas = torch.linspace(1e-4, 0.02, 1000)
# 余弦噪声调度
alphas = torch.cos(torch.linspace(0, math.pi/2, 1000)) ** 2
alphas_bar = alphas.cumprod(dim=0)7.2 超参数配置
| 参数 | 推荐值 |
|---|---|
| 隐藏维度 | 384-768 |
| Transformer层数 | 8-16 |
| 注意力头数 | 6-12 |
| Patch大小 | 2-8 |
| 最大时间步 | 1000 |
| 学习率 | 1e-4 |
7.3 训练策略
TRAINING_CONFIG = {
'optimizer': 'AdamW',
'lr': 1e-4,
'weight_decay': 0.05,
'batch_size': 256,
'epochs': 800,
'warmup_epochs': 10,
'cosine_annealing': True,
'ema_decay': 0.9999
}8. 扩展与应用
8.1 SiT-Diffusion
将SiT扩展为完整的扩散模型:
class SiTDiffusion:
"""SiT引导的扩散模型"""
def __init__(self, model, num_steps=1000):
self.model = model
self.num_steps = num_steps
@torch.no_grad()
def sample(self, shape, guidance_scale=3.0):
"""采样新图像"""
x = torch.randn(shape)
for t in reversed(range(self.num_steps)):
t_batch = torch.full((shape[0],), t)
pred = self.model(x, t_batch)
x = self.ddim_step(x, pred, t)
return x8.2 SiT表示用于下游任务
# 图像检索
def retrieve_with_sit(model, query, gallery, k=10):
query_repr = extract_sit_repr(model, query)
gallery_repr = extract_sit_repr(model, gallery)
similarity = F.cosine_similarity(query_repr, gallery_repr, dim=-1)
return similarity.topk(k)9. 总结
SiT的核心贡献:
- 框架统一:将扩散模型与自监督学习统一
- 插值学习:通过插值损失学习语义空间结构
- 条件生成:支持多种条件机制的灵活框架
- 多尺度表示:从去噪过程中提取多尺度特征
SiT证明了生成模型可以学习表示,为自监督学习开辟了新方向。