概述
联合嵌入预测架构(Joint Embedding Predictive Architecture, JEPA)是由Meta AI研究团队提出的一种新型自监督学习框架,旨在弥合生成式AI(如重建式自编码器)和判别式AI(如对比学习)之间的差距。1 JEPA通过学习预测潜在空间中的表示来实现高效的自监督学习,避免了像素空间的重建复杂性。
为什么需要JEPA?
自监督学习的范式对比
┌─────────────────────────────────────────────────────────────────┐
│ 自监督学习范式对比 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │
│ │ 生成式 │ │ 对比式 │ │ JEPA │ │
│ │ (重建式) │ │ (对比式) │ │ (预测式) │ │
│ └───────┬───────┘ └───────┬───────┘ └───────┬───────┘ │
│ │ │ │ │
│ ↓ ↓ ↓ │
│ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │
│ │ 解码器重建 │ │ 正负样本对比 │ │ 潜在空间预测 │ │
│ │ 像素/Token │ │ InfoNCE损失 │ │ 预测表示 │ │
│ └───────────────┘ └───────────────┘ └───────────────┘ │
│ │ │ │ │
│ ↓ ↓ ↓ │
│ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │
│ │ 高计算成本 │ │ 需要负样本 │ │ 高效表示学习 │ │
│ │ 细节重建 │ │ 潜在崩溃 │ │ 无需重建/对比 │ │
│ └───────────────┘ └───────────────┘ └───────────────┘ │
└─────────────────────────────────────────────────────────────────┘
现有方法的局限性
| 方法 | 优点 | 缺点 |
|---|---|---|
| MAE/DALL-E | 生成能力强 | 计算成本高、像素级细节不必要 |
| SimCLR/BYOL | 判别能力强 | 需要负样本、可能陷入表示崩溃 |
| JEPA | 高效、无需负样本 | 需要精心设计预测头 |
JEPA核心架构
基本框架
JEPA的核心思想是:在潜在空间预测表示,而非像素空间重建。
import torch
import torch.nn as nn
import torch.nn.functional as F
class JEPA(nn.Module):
"""
联合嵌入预测架构(基础版本)
核心组件:
1. 编码器 E: 提取输入x的表示
2. 预测器 P: 从x的表示预测y的表示
3. 目标编码器 E': 编码器E的EMA版本,用于生成目标
"""
def __init__(self, encoder, predictor, predictor_output_dim):
super().__init__()
self.encoder = encoder
self.predictor = predictor
# 目标编码器(EMA)
self.target_encoder = copy.deepcopy(encoder)
for param in self.target_encoder.parameters():
param.requires_grad = False
# 预测器输出维度
self.predictor_output_dim = predictor_output_dim
def forward(self, x, y, block_mask=None):
"""
前向传播
Args:
x: 上下文视图(如部分遮挡的图像)
y: 目标视图(如完整图像)
block_mask: 用于遮蔽的掩码
"""
# 编码x(可能有掩码)
z_x = self.encoder(x, mask=block_mask)
# 预测y的表示
z_y_pred = self.predictor(z_x)
# 目标表示(通过EMA编码器)
with torch.no_grad():
z_y_target = self.target_encoder(y)
return z_y_pred, z_y_target
def loss(self, x, y, block_mask=None):
"""
JEPA损失:L2距离(或其他距离度量)
"""
z_pred, z_target = self.forward(x, y, block_mask)
# L2距离
loss = F.mse_loss(z_pred, z_target)
return loss
def update_target_encoder(self, momentum=0.996):
"""
更新目标编码器(EMA)
"""
for param_q, param_k in zip(
self.encoder.parameters(),
self.target_encoder.parameters()
):
param_k.data.mul_(momentum).add_(
param_q.data, alpha=1 - momentum
)关键设计决策
1. 目标编码器(Target Encoder)
目标编码器是编码器的**指数移动平均(EMA)**版本:
class EMATargetEncoder:
"""
EMA目标编码器
"""
def __init__(self, model, momentum=0.996):
self.model = model
self.momentum = momentum
self.shadow = {}
self.backup = {}
# 初始化影子参数
self.register()
def register(self):
"""注册模型参数为影子"""
for name, param in self.model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
@torch.no_grad()
def update(self):
"""更新影子参数"""
for name, param in self.model.named_parameters():
if param.requires_grad:
new_average = (
self.momentum * self.shadow[name] +
(1.0 - self.momentum) * param.data
)
self.shadow[name] = new_average.clone()
def __call__(self, x):
"""使用影子参数前向传播"""
for name, param in self.model.named_parameters():
if param.requires_grad:
self.backup[name] = param.data
param.data = self.shadow[name]
output = self.model(x)
# 恢复原始参数
for name, param in self.model.named_parameters():
if param.requires_grad:
param.data = self.backup[name]
return output2. 预测器(Predictor)
预测器将上下文表示映射到目标表示空间:
class Predictor(nn.Module):
"""
JEPA预测器
将上下文表示预测为目标表示
"""
def __init__(self, input_dim, output_dim, hidden_dim=None, num_layers=4):
super().__init__()
if hidden_dim is None:
hidden_dim = input_dim
# 多层感知机预测器
layers = []
layers.append(nn.Linear(input_dim, hidden_dim))
layers.append(nn.GELU())
for _ in range(num_layers - 2):
layers.append(nn.Linear(hidden_dim, hidden_dim))
layers.append(nn.GELU())
layers.append(nn.Linear(hidden_dim, output_dim))
self.predictor = nn.Sequential(*layers)
# 可选的mask token(用于预测特定部分)
self.mask_token = nn.Parameter(torch.randn(1, 1, input_dim))
def forward(self, z_x, mask=None):
"""
Args:
z_x: 编码器输出的上下文表示
mask: 可选掩码,指示要预测的位置
"""
if mask is not None:
# 在掩码位置使用mask token
z_x = z_x * (1 - mask) + self.mask_token * mask
return self.predictor(z_x)I-JEPA:图像JEPA
架构详解
I-JEPA(Image JEPA)是JEPA在图像领域的应用,核心思想是从部分遮挡的图像块预测完整图像的表示。2
import torch
import torch.nn as nn
from torchvision.models import vision_transformer
class IJEPA(nn.Module):
"""
Image JEPA
论文: "Self-supervised Visual Representation Learning with JEPA" (ICML 2023)
"""
def __init__(
self,
image_size=224,
patch_size=16,
num_layers=12,
num_heads=12,
embedding_dim=768,
predictor_depth=4,
momentum=0.996
):
super().__init__()
# ViT编码器
self.encoder = vision_transformer(
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=embedding_dim,
mlp_dim=embedding_dim * 4
)
# 目标编码器(EMA)
self.target_encoder = copy.deepcopy(self.encoder)
for param in self.target_encoder.parameters():
param.requires_grad = False
# 预测器
self.predictor = nn.Sequential(
nn.Linear(embedding_dim, embedding_dim * 4),
nn.GELU(),
nn.Linear(embedding_dim * 4, embedding_dim),
# 输出与目标编码器相同维度
)
# 用于预测的位置编码
self.predictor_pos_embed = nn.Parameter(
torch.randn(1, 64, embedding_dim) * 0.02
)
# Mask生成器
self.mask_generator = MaskGenerator(
num_patches=image_size // patch_size,
mask_ratio=0.75
)
def forward(self, x):
"""
前向传播
"""
# 生成mask(上下文和目标mask)
context_mask, target_mask = self.mask_generator()
# 上下文编码
z_x = self.encoder(x, mask=context_mask)
# 预测目标表示
# 只在目标位置进行预测
z_pred = self.predictor(z_x + self.predictor_pos_embed[target_mask])
# 目标表示(EMA编码器)
with torch.no_grad():
z_target = self.target_encoder(x)
# 只取目标位置的表示
z_target = z_target[target_mask]
return z_pred, z_target
def loss(self, x):
"""L2损失"""
z_pred, z_target = self.forward(x)
return F.mse_loss(z_pred, z_target)
def update_target(self, momentum=0.996):
"""更新EMA目标编码器"""
with torch.no_grad():
for param_q, param_k in zip(
self.encoder.parameters(),
self.target_encoder.parameters()
):
param_k.data.mul_(momentum).add_(
param_q.data, alpha=1 - momentum
)
class MaskGenerator:
"""
Mask生成器
生成上下文mask和目标mask
"""
def __init__(self, num_patches, mask_ratio=0.75):
self.num_patches = num_patches * num_patches
self.mask_ratio = mask_ratio
def __call__(self):
"""
生成mask
Returns:
context_mask: 上下文位置的mask(False表示保留)
target_mask: 目标位置的mask(True表示预测)
"""
# 随机采样要遮蔽的位置
num_keep = int(self.num_patches * (1 - self.mask_ratio))
# 随机打乱
perm = torch.randperm(self.num_patches)
# 目标:要被遮蔽的位置
target_mask = torch.zeros(self.num_patches, dtype=torch.bool)
target_mask[perm[:num_keep]] = True
# 上下文:保留的位置
context_mask = ~target_mask
return context_mask, target_maskI-JEPA的训练流程
def train_ijepa(model, dataloader, epochs=100, lr=1e-3, momentum=0.996):
"""
I-JEPA训练流程
"""
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
model.train()
for epoch in range(epochs):
total_loss = 0
for batch_idx, (images,) in enumerate(dataloader):
optimizer.zero_grad()
# 计算损失
loss = model.loss(images)
# 反向传播
loss.backward()
optimizer.step()
# 更新EMA目标编码器
model.update_target(momentum)
total_loss += loss.item()
print(f"Epoch {epoch}, Loss: {total_loss / len(dataloader):.4f}")V-JEPA:视频JEPA
架构详解
V-JEPA(Video JEPA)是JEPA在视频领域的扩展,关注时间维度的表示学习。3
class VJEPA(nn.Module):
"""
Video JEPA
用于视频表示学习
"""
def __init__(
self,
spatial_patch_size=16,
temporal_patch_size=2,
num_frames=16,
embedding_dim=768,
num_layers=12,
predictor_depth=4
):
super().__init__()
# 时空ViT编码器
self.encoder = SpatioTemporalViT(
spatial_patch_size=spatial_patch_size,
temporal_patch_size=temporal_patch_size,
num_frames=num_frames,
embedding_dim=embedding_dim,
depth=num_layers
)
# EMA目标编码器
self.target_encoder = copy.deepcopy(self.encoder)
for param in self.target_encoder.parameters():
param.requires_grad = False
# 预测器
self.predictor = VideoPredictor(
embedding_dim=embedding_dim,
predictor_depth=predictor_depth
)
def forward(self, video, context_mask, target_mask):
"""
前向传播
Args:
video: 视频张量 (B, C, T, H, W)
context_mask: 上下文时空mask
target_mask: 目标时空mask
"""
# 编码上下文
z_x = self.encoder(video, mask=context_mask)
# 预测目标表示
z_pred = self.predictor(z_x, target_mask)
# 目标表示
with torch.no_grad():
z_target = self.target_encoder(video, mask=target_mask)
return z_pred, z_target
def loss(self, video, context_mask, target_mask):
"""L2损失"""
z_pred, z_target = self.forward(video, context_mask, target_mask)
return F.mse_loss(z_pred, z_target)
class SpatioTemporalViT(nn.Module):
"""
时空ViT编码器
"""
def __init__(self, spatial_patch_size, temporal_patch_size,
num_frames, embedding_dim, depth):
super().__init__()
# 时空patch嵌入
self.patch_embed = SpatioTemporalPatchEmbed(
spatial_patch_size=spatial_patch_size,
temporal_patch_size=temporal_patch_size,
in_channels=3,
embed_dim=embedding_dim
)
# 时空位置编码
num_patches_per_frame = (224 // spatial_patch_size) ** 2
self.pos_embed = nn.Parameter(
torch.randn(1, num_frames, num_patches_per_frame, embedding_dim) * 0.02
)
# Transformer块
self.blocks = nn.ModuleList([
SpatioTemporalTransformerBlock(embedding_dim)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embedding_dim)
def forward(self, x, mask=None):
"""
Args:
x: (B, C, T, H, W)
mask: 可选mask
"""
B, C, T, H, W = x.shape
# Patch嵌入
x = self.patch_embed(x) # (B, T, num_patches, embed_dim)
# 添加位置编码
x = x + self.pos_embed[:, :T]
# Flatten空间和时间
x = x.flatten(1, 2) # (B, T*num_patches, embed_dim)
# 应用mask
if mask is not None:
x = x * mask.unsqueeze(-1)
# Transformer块
for block in self.blocks:
x = block(x)
x = self.norm(x)
return x
class VideoPredictor(nn.Module):
"""
视频预测器
从上下文时空patch预测目标时空patch
"""
def __init__(self, embedding_dim, predictor_depth):
super().__init__()
# Mask token
self.mask_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
# 预测器MLP
self.predictor = nn.ModuleList([
nn.Sequential(
nn.Linear(embedding_dim, embedding_dim * 4),
nn.GELU(),
nn.Linear(embedding_dim * 4, embedding_dim)
)
for _ in range(predictor_depth)
])
self.norm = nn.LayerNorm(embedding_dim)
def forward(self, z_context, target_mask):
"""
Args:
z_context: 上下文patch的表示
target_mask: 目标patch的掩码
"""
# 在目标位置插入mask token
z = z_context.clone()
num_patches = z.shape[1]
# 生成目标位置的索引
target_indices = torch.where(target_mask.flatten())[0]
# 添加mask token
for idx in target_indices:
if idx < z.shape[1]:
z[:, idx] = self.mask_token.squeeze(0)
# 预测
for layer in self.predictor:
z = layer(z)
z = self.norm(z)
# 只返回目标位置
return z[:, target_indices]JEPA与其他方法的对比
理论对比
| 方面 | MAE/重建式 | SimCLR/对比式 | JEPA/预测式 |
|---|---|---|---|
| 学习目标 | 像素重建 | 实例判别 | 表示预测 |
| 计算成本 | 高(解码器) | 中(对比) | 低(无解码器) |
| 负样本需求 | 否 | 是 | 否 |
| 语义级别 | 低 | 高 | 高 |
| 崩溃风险 | 低 | 中 | 低 |
实际性能对比
在ImageNet分类任务上的对比:
| 方法 | Top-1 Acc | 预训练 Epochs | 所需显存 |
|---|---|---|---|
| MoCo v3 | 83.2% | 300 | 16GB |
| DINO | 82.8% | 300 | 16GB |
| MAE | 83.6% | 1600 | 16GB |
| I-JEPA | 84.1% | 200 | 8GB |
实践指南
超参数设置
# I-JEPA推荐配置
config = {
'image_size': 224,
'patch_size': 16,
'embedding_dim': 768,
'num_layers': 12,
'num_heads': 12,
'predictor_depth': 4,
'mask_ratio': 0.75, # 遮蔽75%的patch
'momentum': 0.996, # EMA动量
'lr': 1e-3,
'weight_decay': 0.04,
'batch_size': 2048,
'warmup_epochs': 10,
}下游任务微调
class IJEPAForDownstream(nn.Module):
"""
I-JEPA用于下游任务
"""
def __init__(self, ijepa_model, num_classes):
super().__init__()
# 冻结编码器
self.encoder = ijepa_model.encoder
for param in self.encoder.parameters():
param.requires_grad = False
# 添加分类头
self.classifier = nn.Linear(ijepa_model.encoder.embedding_dim, num_classes)
def forward(self, x):
with torch.no_grad():
z = self.encoder(x)
# 使用[CLS] token或平均池化
z = z[:, 0] # [CLS] token
return self.classifier(z)扩展与应用
1. 多模态JEPA
class MultimodalJEPA(nn.Module):
"""
多模态JEPA
支持图像-文本配对学习
"""
def __init__(self, image_encoder, text_encoder, embedding_dim):
super().__init__()
self.image_encoder = image_encoder
self.text_encoder = text_encoder
# 图像预测器
self.image_predictor = Predictor(embedding_dim, embedding_dim)
# 文本预测器
self.text_predictor = Predictor(embedding_dim, embedding_dim)
# EMA目标编码器
self.image_target = copy.deepcopy(image_encoder)
self.text_target = copy.deepcopy(text_encoder)
def forward(self, images, texts):
"""
交叉预测:图像预测文本,文本预测图像
"""
# 图像 -> 文本
z_img = self.image_encoder(images)
text_pred = self.text_predictor(z_img)
# 文本 -> 图像
z_text = self.text_encoder(texts)
img_pred = self.image_predictor(z_text)
# 目标
with torch.no_grad():
text_target = self.text_target(texts)
img_target = self.image_target(images)
return {
'img2text': (text_pred, text_target),
'text2img': (img_pred, img_target)
}2. 图JEPA
class GraphJEPA(nn.Module):
"""
图JEPA
用于图表示学习
"""
def __init__(self, gnn_encoder, embedding_dim):
super().__init__()
self.encoder = gnn_encoder
self.target_encoder = copy.deepcopy(gnn_encoder)
self.predictor = nn.Sequential(
nn.Linear(embedding_dim, embedding_dim * 4),
nn.GELU(),
nn.Linear(embedding_dim * 4, embedding_dim)
)
def forward(self, x, edge_index, mask_ratio=0.5):
"""
Args:
x: 节点特征
edge_index: 边索引
mask_ratio: 节点遮蔽比例
"""
# 随机遮蔽节点
num_nodes = x.shape[0]
mask = torch.rand(num_nodes) < mask_ratio
# 编码非遮蔽节点
z_x = self.encoder(x, edge_index)
# 预测遮蔽节点
z_pred = self.predictor(z_x)
# 目标
with torch.no_grad():
z_target = self.target_encoder(x, edge_index)
return z_pred[mask], z_target[mask]总结
JEPA的核心要点:
| 特性 | 说明 |
|---|---|
| 核心思想 | 在潜在空间预测表示,而非像素空间重建 |
| 关键组件 | 编码器、预测器、EMA目标编码器 |
| 优势 | 高效、无需负样本、语义级别表示 |
| 应用 | I-JEPA(图像)、V-JEPA(视频)、多模态 |
参考
相关文章
- contrastive-learning-theory — 对比学习理论
- masked-autoencoder — MAE详解
- self-supervised-learning — 自监督学习综述
- dino-self-supervised-vision-transformer — DINO方法