概述

联合嵌入预测架构(Joint Embedding Predictive Architecture, JEPA)是由Meta AI研究团队提出的一种新型自监督学习框架,旨在弥合生成式AI(如重建式自编码器)和判别式AI(如对比学习)之间的差距。1 JEPA通过学习预测潜在空间中的表示来实现高效的自监督学习,避免了像素空间的重建复杂性。


为什么需要JEPA?

自监督学习的范式对比

┌─────────────────────────────────────────────────────────────────┐
│                    自监督学习范式对比                               │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌───────────────┐     ┌───────────────┐     ┌───────────────┐ │
│  │  生成式       │     │  对比式       │     │  JEPA        │ │
│  │  (重建式)     │     │  (对比式)     │     │  (预测式)     │ │
│  └───────┬───────┘     └───────┬───────┘     └───────┬───────┘ │
│          │                     │                     │           │
│          ↓                     ↓                     ↓           │
│  ┌───────────────┐     ┌───────────────┐     ┌───────────────┐ │
│  │ 解码器重建    │     │ 正负样本对比  │     │ 潜在空间预测  │ │
│  │ 像素/Token   │     │ InfoNCE损失  │     │ 预测表示     │ │
│  └───────────────┘     └───────────────┘     └───────────────┘ │
│          │                     │                     │           │
│          ↓                     ↓                     ↓           │
│  ┌───────────────┐     ┌───────────────┐     ┌───────────────┐ │
│  │ 高计算成本    │     │ 需要负样本    │     │ 高效表示学习  │ │
│  │ 细节重建     │     │ 潜在崩溃     │     │ 无需重建/对比 │ │
│  └───────────────┘     └───────────────┘     └───────────────┘ │
└─────────────────────────────────────────────────────────────────┘

现有方法的局限性

方法优点缺点
MAE/DALL-E生成能力强计算成本高、像素级细节不必要
SimCLR/BYOL判别能力强需要负样本、可能陷入表示崩溃
JEPA高效、无需负样本需要精心设计预测头

JEPA核心架构

基本框架

JEPA的核心思想是:在潜在空间预测表示,而非像素空间重建

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class JEPA(nn.Module):
    """
    联合嵌入预测架构(基础版本)
    
    核心组件:
    1. 编码器 E: 提取输入x的表示
    2. 预测器 P: 从x的表示预测y的表示
    3. 目标编码器 E': 编码器E的EMA版本,用于生成目标
    """
    def __init__(self, encoder, predictor, predictor_output_dim):
        super().__init__()
        
        self.encoder = encoder
        self.predictor = predictor
        
        # 目标编码器(EMA)
        self.target_encoder = copy.deepcopy(encoder)
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        
        # 预测器输出维度
        self.predictor_output_dim = predictor_output_dim
    
    def forward(self, x, y, block_mask=None):
        """
        前向传播
        
        Args:
            x: 上下文视图(如部分遮挡的图像)
            y: 目标视图(如完整图像)
            block_mask: 用于遮蔽的掩码
        """
        # 编码x(可能有掩码)
        z_x = self.encoder(x, mask=block_mask)
        
        # 预测y的表示
        z_y_pred = self.predictor(z_x)
        
        # 目标表示(通过EMA编码器)
        with torch.no_grad():
            z_y_target = self.target_encoder(y)
        
        return z_y_pred, z_y_target
    
    def loss(self, x, y, block_mask=None):
        """
        JEPA损失:L2距离(或其他距离度量)
        """
        z_pred, z_target = self.forward(x, y, block_mask)
        
        # L2距离
        loss = F.mse_loss(z_pred, z_target)
        
        return loss
    
    def update_target_encoder(self, momentum=0.996):
        """
        更新目标编码器(EMA)
        """
        for param_q, param_k in zip(
            self.encoder.parameters(), 
            self.target_encoder.parameters()
        ):
            param_k.data.mul_(momentum).add_(
                param_q.data, alpha=1 - momentum
            )

关键设计决策

1. 目标编码器(Target Encoder)

目标编码器是编码器的**指数移动平均(EMA)**版本:

class EMATargetEncoder:
    """
    EMA目标编码器
    """
    def __init__(self, model, momentum=0.996):
        self.model = model
        self.momentum = momentum
        self.shadow = {}
        self.backup = {}
        
        # 初始化影子参数
        self.register()
    
    def register(self):
        """注册模型参数为影子"""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()
    
    @torch.no_grad()
    def update(self):
        """更新影子参数"""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                new_average = (
                    self.momentum * self.shadow[name] + 
                    (1.0 - self.momentum) * param.data
                )
                self.shadow[name] = new_average.clone()
    
    def __call__(self, x):
        """使用影子参数前向传播"""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data
                param.data = self.shadow[name]
        
        output = self.model(x)
        
        # 恢复原始参数
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]
        
        return output

2. 预测器(Predictor)

预测器将上下文表示映射到目标表示空间:

class Predictor(nn.Module):
    """
    JEPA预测器
    
    将上下文表示预测为目标表示
    """
    def __init__(self, input_dim, output_dim, hidden_dim=None, num_layers=4):
        super().__init__()
        
        if hidden_dim is None:
            hidden_dim = input_dim
        
        # 多层感知机预测器
        layers = []
        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(nn.GELU())
        
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.GELU())
        
        layers.append(nn.Linear(hidden_dim, output_dim))
        
        self.predictor = nn.Sequential(*layers)
        
        # 可选的mask token(用于预测特定部分)
        self.mask_token = nn.Parameter(torch.randn(1, 1, input_dim))
    
    def forward(self, z_x, mask=None):
        """
        Args:
            z_x: 编码器输出的上下文表示
            mask: 可选掩码,指示要预测的位置
        """
        if mask is not None:
            # 在掩码位置使用mask token
            z_x = z_x * (1 - mask) + self.mask_token * mask
        
        return self.predictor(z_x)

I-JEPA:图像JEPA

架构详解

I-JEPA(Image JEPA)是JEPA在图像领域的应用,核心思想是从部分遮挡的图像块预测完整图像的表示2

import torch
import torch.nn as nn
from torchvision.models import vision_transformer
 
class IJEPA(nn.Module):
    """
    Image JEPA
    
    论文: "Self-supervised Visual Representation Learning with JEPA" (ICML 2023)
    """
    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_layers=12,
        num_heads=12,
        embedding_dim=768,
        predictor_depth=4,
        momentum=0.996
    ):
        super().__init__()
        
        # ViT编码器
        self.encoder = vision_transformer(
            image_size=image_size,
            patch_size=patch_size,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=embedding_dim,
            mlp_dim=embedding_dim * 4
        )
        
        # 目标编码器(EMA)
        self.target_encoder = copy.deepcopy(self.encoder)
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        
        # 预测器
        self.predictor = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 4),
            nn.GELU(),
            nn.Linear(embedding_dim * 4, embedding_dim),
            # 输出与目标编码器相同维度
        )
        
        # 用于预测的位置编码
        self.predictor_pos_embed = nn.Parameter(
            torch.randn(1, 64, embedding_dim) * 0.02
        )
        
        # Mask生成器
        self.mask_generator = MaskGenerator(
            num_patches=image_size // patch_size,
            mask_ratio=0.75
        )
    
    def forward(self, x):
        """
        前向传播
        """
        # 生成mask(上下文和目标mask)
        context_mask, target_mask = self.mask_generator()
        
        # 上下文编码
        z_x = self.encoder(x, mask=context_mask)
        
        # 预测目标表示
        # 只在目标位置进行预测
        z_pred = self.predictor(z_x + self.predictor_pos_embed[target_mask])
        
        # 目标表示(EMA编码器)
        with torch.no_grad():
            z_target = self.target_encoder(x)
            # 只取目标位置的表示
            z_target = z_target[target_mask]
        
        return z_pred, z_target
    
    def loss(self, x):
        """L2损失"""
        z_pred, z_target = self.forward(x)
        return F.mse_loss(z_pred, z_target)
    
    def update_target(self, momentum=0.996):
        """更新EMA目标编码器"""
        with torch.no_grad():
            for param_q, param_k in zip(
                self.encoder.parameters(),
                self.target_encoder.parameters()
            ):
                param_k.data.mul_(momentum).add_(
                    param_q.data, alpha=1 - momentum
                )
 
 
class MaskGenerator:
    """
    Mask生成器
    
    生成上下文mask和目标mask
    """
    def __init__(self, num_patches, mask_ratio=0.75):
        self.num_patches = num_patches * num_patches
        self.mask_ratio = mask_ratio
    
    def __call__(self):
        """
        生成mask
        
        Returns:
            context_mask: 上下文位置的mask(False表示保留)
            target_mask: 目标位置的mask(True表示预测)
        """
        # 随机采样要遮蔽的位置
        num_keep = int(self.num_patches * (1 - self.mask_ratio))
        
        # 随机打乱
        perm = torch.randperm(self.num_patches)
        
        # 目标:要被遮蔽的位置
        target_mask = torch.zeros(self.num_patches, dtype=torch.bool)
        target_mask[perm[:num_keep]] = True
        
        # 上下文:保留的位置
        context_mask = ~target_mask
        
        return context_mask, target_mask

I-JEPA的训练流程

def train_ijepa(model, dataloader, epochs=100, lr=1e-3, momentum=0.996):
    """
    I-JEPA训练流程
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        
        for batch_idx, (images,) in enumerate(dataloader):
            optimizer.zero_grad()
            
            # 计算损失
            loss = model.loss(images)
            
            # 反向传播
            loss.backward()
            optimizer.step()
            
            # 更新EMA目标编码器
            model.update_target(momentum)
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch}, Loss: {total_loss / len(dataloader):.4f}")

V-JEPA:视频JEPA

架构详解

V-JEPA(Video JEPA)是JEPA在视频领域的扩展,关注时间维度的表示学习3

class VJEPA(nn.Module):
    """
    Video JEPA
    
    用于视频表示学习
    """
    def __init__(
        self,
        spatial_patch_size=16,
        temporal_patch_size=2,
        num_frames=16,
        embedding_dim=768,
        num_layers=12,
        predictor_depth=4
    ):
        super().__init__()
        
        # 时空ViT编码器
        self.encoder = SpatioTemporalViT(
            spatial_patch_size=spatial_patch_size,
            temporal_patch_size=temporal_patch_size,
            num_frames=num_frames,
            embedding_dim=embedding_dim,
            depth=num_layers
        )
        
        # EMA目标编码器
        self.target_encoder = copy.deepcopy(self.encoder)
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        
        # 预测器
        self.predictor = VideoPredictor(
            embedding_dim=embedding_dim,
            predictor_depth=predictor_depth
        )
    
    def forward(self, video, context_mask, target_mask):
        """
        前向传播
        
        Args:
            video: 视频张量 (B, C, T, H, W)
            context_mask: 上下文时空mask
            target_mask: 目标时空mask
        """
        # 编码上下文
        z_x = self.encoder(video, mask=context_mask)
        
        # 预测目标表示
        z_pred = self.predictor(z_x, target_mask)
        
        # 目标表示
        with torch.no_grad():
            z_target = self.target_encoder(video, mask=target_mask)
        
        return z_pred, z_target
    
    def loss(self, video, context_mask, target_mask):
        """L2损失"""
        z_pred, z_target = self.forward(video, context_mask, target_mask)
        return F.mse_loss(z_pred, z_target)
 
 
class SpatioTemporalViT(nn.Module):
    """
    时空ViT编码器
    """
    def __init__(self, spatial_patch_size, temporal_patch_size, 
                 num_frames, embedding_dim, depth):
        super().__init__()
        
        # 时空patch嵌入
        self.patch_embed = SpatioTemporalPatchEmbed(
            spatial_patch_size=spatial_patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=3,
            embed_dim=embedding_dim
        )
        
        # 时空位置编码
        num_patches_per_frame = (224 // spatial_patch_size) ** 2
        self.pos_embed = nn.Parameter(
            torch.randn(1, num_frames, num_patches_per_frame, embedding_dim) * 0.02
        )
        
        # Transformer块
        self.blocks = nn.ModuleList([
            SpatioTemporalTransformerBlock(embedding_dim)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embedding_dim)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: (B, C, T, H, W)
            mask: 可选mask
        """
        B, C, T, H, W = x.shape
        
        # Patch嵌入
        x = self.patch_embed(x)  # (B, T, num_patches, embed_dim)
        
        # 添加位置编码
        x = x + self.pos_embed[:, :T]
        
        # Flatten空间和时间
        x = x.flatten(1, 2)  # (B, T*num_patches, embed_dim)
        
        # 应用mask
        if mask is not None:
            x = x * mask.unsqueeze(-1)
        
        # Transformer块
        for block in self.blocks:
            x = block(x)
        
        x = self.norm(x)
        
        return x
 
 
class VideoPredictor(nn.Module):
    """
    视频预测器
    
    从上下文时空patch预测目标时空patch
    """
    def __init__(self, embedding_dim, predictor_depth):
        super().__init__()
        
        # Mask token
        self.mask_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        
        # 预测器MLP
        self.predictor = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embedding_dim, embedding_dim * 4),
                nn.GELU(),
                nn.Linear(embedding_dim * 4, embedding_dim)
            )
            for _ in range(predictor_depth)
        ])
        
        self.norm = nn.LayerNorm(embedding_dim)
    
    def forward(self, z_context, target_mask):
        """
        Args:
            z_context: 上下文patch的表示
            target_mask: 目标patch的掩码
        """
        # 在目标位置插入mask token
        z = z_context.clone()
        num_patches = z.shape[1]
        
        # 生成目标位置的索引
        target_indices = torch.where(target_mask.flatten())[0]
        
        # 添加mask token
        for idx in target_indices:
            if idx < z.shape[1]:
                z[:, idx] = self.mask_token.squeeze(0)
        
        # 预测
        for layer in self.predictor:
            z = layer(z)
        
        z = self.norm(z)
        
        # 只返回目标位置
        return z[:, target_indices]

JEPA与其他方法的对比

理论对比

方面MAE/重建式SimCLR/对比式JEPA/预测式
学习目标像素重建实例判别表示预测
计算成本高(解码器)中(对比)低(无解码器)
负样本需求
语义级别
崩溃风险

实际性能对比

在ImageNet分类任务上的对比:

方法Top-1 Acc预训练 Epochs所需显存
MoCo v383.2%30016GB
DINO82.8%30016GB
MAE83.6%160016GB
I-JEPA84.1%2008GB

实践指南

超参数设置

# I-JEPA推荐配置
config = {
    'image_size': 224,
    'patch_size': 16,
    'embedding_dim': 768,
    'num_layers': 12,
    'num_heads': 12,
    'predictor_depth': 4,
    'mask_ratio': 0.75,  # 遮蔽75%的patch
    'momentum': 0.996,    # EMA动量
    'lr': 1e-3,
    'weight_decay': 0.04,
    'batch_size': 2048,
    'warmup_epochs': 10,
}

下游任务微调

class IJEPAForDownstream(nn.Module):
    """
    I-JEPA用于下游任务
    """
    def __init__(self, ijepa_model, num_classes):
        super().__init__()
        
        # 冻结编码器
        self.encoder = ijepa_model.encoder
        for param in self.encoder.parameters():
            param.requires_grad = False
        
        # 添加分类头
        self.classifier = nn.Linear(ijepa_model.encoder.embedding_dim, num_classes)
    
    def forward(self, x):
        with torch.no_grad():
            z = self.encoder(x)
            # 使用[CLS] token或平均池化
            z = z[:, 0]  # [CLS] token
        
        return self.classifier(z)

扩展与应用

1. 多模态JEPA

class MultimodalJEPA(nn.Module):
    """
    多模态JEPA
    
    支持图像-文本配对学习
    """
    def __init__(self, image_encoder, text_encoder, embedding_dim):
        super().__init__()
        
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        
        # 图像预测器
        self.image_predictor = Predictor(embedding_dim, embedding_dim)
        
        # 文本预测器
        self.text_predictor = Predictor(embedding_dim, embedding_dim)
        
        # EMA目标编码器
        self.image_target = copy.deepcopy(image_encoder)
        self.text_target = copy.deepcopy(text_encoder)
    
    def forward(self, images, texts):
        """
        交叉预测:图像预测文本,文本预测图像
        """
        # 图像 -> 文本
        z_img = self.image_encoder(images)
        text_pred = self.text_predictor(z_img)
        
        # 文本 -> 图像
        z_text = self.text_encoder(texts)
        img_pred = self.image_predictor(z_text)
        
        # 目标
        with torch.no_grad():
            text_target = self.text_target(texts)
            img_target = self.image_target(images)
        
        return {
            'img2text': (text_pred, text_target),
            'text2img': (img_pred, img_target)
        }

2. 图JEPA

class GraphJEPA(nn.Module):
    """
    图JEPA
    
    用于图表示学习
    """
    def __init__(self, gnn_encoder, embedding_dim):
        super().__init__()
        
        self.encoder = gnn_encoder
        self.target_encoder = copy.deepcopy(gnn_encoder)
        
        self.predictor = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 4),
            nn.GELU(),
            nn.Linear(embedding_dim * 4, embedding_dim)
        )
    
    def forward(self, x, edge_index, mask_ratio=0.5):
        """
        Args:
            x: 节点特征
            edge_index: 边索引
            mask_ratio: 节点遮蔽比例
        """
        # 随机遮蔽节点
        num_nodes = x.shape[0]
        mask = torch.rand(num_nodes) < mask_ratio
        
        # 编码非遮蔽节点
        z_x = self.encoder(x, edge_index)
        
        # 预测遮蔽节点
        z_pred = self.predictor(z_x)
        
        # 目标
        with torch.no_grad():
            z_target = self.target_encoder(x, edge_index)
        
        return z_pred[mask], z_target[mask]

总结

JEPA的核心要点:

特性说明
核心思想在潜在空间预测表示,而非像素空间重建
关键组件编码器、预测器、EMA目标编码器
优势高效、无需负样本、语义级别表示
应用I-JEPA(图像)、V-JEPA(视频)、多模态

参考


相关文章

Footnotes

  1. JEPA: Joint Embedding Predictive Architecture (Meta AI, 2023)

  2. I-JEPA: Self-supervised Visual Representation Learning with JEPA (ICML 2023)

  3. V-JEPA: Video Representation Learning with JEPA (2024)