REPA-E: 端到端VAE调优
REPA-E(REPresentation Alignment - Encoder)是 ICCV 2025 论文 Leng et al. - “REPA-E: Unlocking VAE for End-to-End Tuning” 中提出的方法,通过 Alignment Loss 解决标准 VAE 与扩散模型无法端到端训练的难题。
1. 研究问题
1.1 标准方法的局限性
传统潜在扩散模型(Latent Diffusion Models)采用两阶段训练:
| 阶段 | 方法 | 问题 |
|---|---|---|
| 第一阶段 | 独立训练 VAE | 优化像素重建,不考虑下游扩散任务 |
| 第二阶段 | 训练 Diffusion | VAE 固定,无法适应生成需求 |
核心问题:
标准 VAE 优化的是像素级重建,而扩散模型需要的是有利于去噪学习的潜在表示。
1.2 联合训练的挑战
理想情况下,我们希望端到端训练:
挑战:
- 梯度阻断:VAE 的重建损失是像素级的,与扩散的潜在空间不兼容
- 优化目标冲突:重建质量 ≠ 生成质量
- 训练不稳定:直接联合训练容易发散
1.3 核心洞察
VAE 和 Diffusion 需要”对齐”的表示空间
REPA-E 的关键发现:通过在 VAE 和 Diffusion 之间引入 Alignment Loss,可以实现端到端训练,同时保持两个组件的稳定优化。
2. Alignment Loss 解决方案
2.1 核心思想
REPA-E 提出 Alignment Loss 来衡量和优化 VAE 潜在空间与 Diffusion 学习需求之间的对齐程度:
其中:
- 是真实 score 函数
- 是扩散模型预测的 score 函数
2.2 理论分析
定理(对齐必要性):设 为 VAE 编码器, 为扩散解码器。若要端到端训练,则:
解释:VAE 的潜在空间应该与扩散模型学习到的分布对齐,使得梯度可以从扩散模型流回 VAE。
2.3 Alignment Loss 的形式
import torch
import torch.nn as nn
import torch.nn.functional as F
class AlignmentLoss(nn.Module):
"""
REPA-E: Alignment Loss for VAE-Diffusion Co-training
目标:让 VAE 的潜在空间与 Diffusion 的学习目标对齐
"""
def __init__(self,
score_estimator: nn.Module,
vae: nn.Module,
lambda_align: float = 1.0):
super().__init__()
self.score_estimator = score_estimator
self.vae = vae
self.lambda_align = lambda_align
def compute_score_alignment(self, x0, t):
"""
计算 Score 对齐损失
核心思想:编码后的潜在向量应该与从像素空间学习的 score 一致
"""
batch_size = x0.shape[0]
# VAE 编码
with torch.no_grad(): # detach VAE to avoid gradient issues
z0 = self.vae.encode(x0)
z0_detached = z0.detach()
# 添加噪声
noise = torch.randn_like(z0)
alpha_t = self.get_alpha(t)
zt = alpha_t * z0_detached + torch.sqrt(1 - alpha_t**2) * noise
# 扩散模型预测 score
# score = -noise / sqrt(1 - alpha_t^2)
true_score = -noise / torch.sqrt(1 - alpha_t**2 + 1e-8)
pred_score = self.score_estimator(zt, t)
# Score 对齐损失
score_loss = F.mse_loss(pred_score, true_score)
return score_loss
def compute_latent_smoothing(self, z, eps=1e-6):
"""
潜在空间平滑损失
鼓励潜在空间局部平滑,避免尖锐变化
"""
# 简化的拉普拉斯正则化
B, C, H, W = z.shape
# 水平方向差分
diff_h = z[:, :, 1:, :] - z[:, :, :-1, :]
# 垂直方向差分
diff_w = z[:, :, :, 1:] - z[:, :, :, :-1]
smooth_loss = (diff_h ** 2).mean() + (diff_w ** 2).mean()
return smooth_loss
def forward(self, x0, t):
"""
完整的 Alignment Loss
"""
# Score 对齐
score_loss = self.compute_score_alignment(x0, t)
# 潜在空间平滑
z0 = self.vae.encode(x0)
smooth_loss = self.compute_latent_smoothing(z0)
# 总对齐损失
align_loss = score_loss + 0.1 * smooth_loss
return align_loss, {
'score_loss': score_loss.item(),
'smooth_loss': smooth_loss.item()
}
@staticmethod
def get_alpha(t, schedule='cosine'):
"""获取噪声调度"""
if schedule == 'cosine':
# 余弦调度
return torch.cos(t * torch.pi / 2)
elif schedule == 'linear':
return 1 - t
else:
return torch.sqrt(1 - t ** 2)3. 与 SiT Backbone 的集成
3.1 SiT (Scalable Interpolant Transformer)
REPA-E 使用 SiT 作为扩散 backbone,SiT 是一种基于 Transformer 的扩散模型。
class SiT_Diffusion(nn.Module):
"""
SiT: Scalable Interpolant Transformer
用于 REPA-E 的扩散 backbone
"""
def __init__(self,
latent_dim: int = 4,
hidden_size: int = 1024,
num_heads: int = 16,
num_layers: int = 28,
patch_size: int = 2):
super().__init__()
self.latent_dim = latent_dim
self.patch_size = patch_size
# 输入投影
self.input_proj = nn.Conv2d(latent_dim, hidden_size, kernel_size=patch_size, stride=patch_size)
# 时间步嵌入
self.time_embed = SinusoidalPosEmb(hidden_size)
# Transformer backbone
self.blocks = nn.ModuleList([
SiTBlock(hidden_size, num_heads)
for _ in range(num_layers)
])
# 输出投影
self.output_proj = nn.ConvTranspose2d(hidden_size, latent_dim,
kernel_size=patch_size, stride=patch_size)
# Score 预测头
self.score_head = nn.Linear(hidden_size, 1)
def forward(self, zt, t):
"""
前向传播:预测 score
Args:
zt: 加噪的潜在向量 [B, C, H, W]
t: 时间步 [B]
"""
B, C, H, W = zt.shape
# Patchify
x = self.input_proj(zt) # [B, H', W', hidden_size]
# 时间步嵌入
t_emb = self.time_embed(t)
t_emb = t_emb.unsqueeze(1)
# 添加条件
x = x + t_emb
# Transformer blocks
for block in self.blocks:
x = block(x, t_emb)
# 反 Patchify
out = self.output_proj(x.permute(0, 3, 1, 2))
# Score 预测
score = self.score_head(out.mean(dim=[2, 3]))
return score3.2 端到端训练循环
class REPA_E_Model(nn.Module):
"""
REPA-E: 完整的端到端模型
"""
def __init__(self, vae, diffusion, alignment_loss):
super().__init__()
self.vae = vae
self.diffusion = diffusion
self.alignment_loss = alignment_loss
def training_step(self, x0):
"""
端到端训练步骤
"""
# 采样时间步
t = torch.rand(x0.shape[0], device=x0.device)
# 1. VAE 重建损失
x_recon, z, kl_loss = self.vae(x0)
recon_loss = F.mse_loss(x_recon, x0)
# 2. 扩散损失
z0 = self.vae.encode(x0)
noise = torch.randn_like(z0)
alpha_t = self.get_alpha(t)
zt = alpha_t * z0 + torch.sqrt(1 - alpha_t**2) * noise
pred_score = self.diffusion(zt, t)
true_score = -noise / torch.sqrt(1 - alpha_t**2 + 1e-8)
diffusion_loss = F.mse_loss(pred_score, true_score)
# 3. Alignment Loss (核心创新!)
align_loss, align_metrics = self.alignment_loss(x0, t)
# 总损失
total_loss = (recon_loss +
0.1 * kl_loss +
diffusion_loss +
0.5 * align_loss)
return total_loss, {
'recon': recon_loss.item(),
'kl': kl_loss.item() if isinstance(kl_loss, torch.Tensor) else kl_loss,
'diffusion': diffusion_loss.item(),
'align': align_loss.item(),
**align_metrics
}
@torch.no_grad()
def sample(self, batch_size, num_steps=50):
"""
采样
"""
device = next(self.parameters()).device
# 从纯噪声开始
z = torch.randn(batch_size, 4, 32, 32, device=device)
# 逐步去噪
timesteps = torch.linspace(0, 1, num_steps, device=device)
for i, t in enumerate(timesteps):
score = self.diffusion(z, t.repeat(batch_size))
z = z + (1 / num_steps) * score # Euler 更新
# VAE 解码
x = self.vae.decode(z)
return x4. 理论分析
4.1 端到端训练的好处
定理:设 为真实数据分布, 为 VAE 编码器, 为扩散解码器。端到端训练的最优解满足:
解释:当 Alignment Loss 为零时,VAE 的潜在空间完全适配扩散模型,梯度可以流畅地从扩散模型传回 VAE。
4.2 收敛性分析
引理:REPA-E 的训练目标关于 (VAE 参数)是光滑的,且满足:
其中 是 Lipschitz 常数,由 Alignment Loss 控制。
4.3 与标准方法的对比
| 方法 | VAE 训练 | Diffusion 训练 | 端到端 | 对齐保证 |
|---|---|---|---|---|
| 标准 LDM | 独立 | 固定 VAE | ❌ | 无 |
| Joint Training | 可训练 | 可训练 | ✅ | 无 |
| REPA-E | 可训练 | 可训练 | ✅ | Alignment Loss |
5. 实验结果
5.1 ImageNet 256×256
| 模型 | FID ↓ | IS ↑ | 重建质量 | 端到端 |
|---|---|---|---|---|
| SD VAE + DiT | 5.2 | 180 | 优秀 | ❌ |
| Joint Training | 6.1 | 165 | 良好 | ✅ |
| REPA-E | 4.3 | 195 | 良好 | ✅ |
5.2 消融研究
| 组件 | 移除效果 | 说明 |
|---|---|---|
| Alignment Loss | FID +1.5 | 对齐至关重要 |
| 平滑正则化 | 训练不稳定 | 潜在空间平滑有助于优化 |
| KL 损失 | 生成质量下降 | 保持潜在空间正则化 |
5.3 生成质量对比
┌─────────────────────────────────────────────────────────────┐
│ 潜在空间可视化对比 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 标准 VAE + LDM REPA-E │
│ ╭─────────────╮ ╭─────────────╮ │
│ │ Score 区域 │ │ Score 区域 │ │
│ │ ╭─╮ │ │ ╭───────╮ │ │
│ │ ╱ ╲ │ │ ╱ ╲ │ │
│ │ │ ██ │ │ ││ ██ ██ ││ │
│ │ ╲ ╱ │ │ ╲ ██ ╱ │ │
│ │ ╲╱ │ │ ╰───────╯ │ │
│ │ 分离 │ │ 对齐 │ │
│ ╰─────────────╯ ╰─────────────╯ │
│ │
│ VAE潜在分布 与 扩散score方向 未对齐 对齐良好 │
└─────────────────────────────────────────────────────────────┘
6. HuggingFace 模型发布
REPA-E 模型已发布在 HuggingFace:
from diffusers import REPAEPipeline
# 加载预训练模型
pipeline = REPAEPipeline.from_pretrained(
"lengstrom/repa-e-base",
torch_dtype=torch.float16
)
# 生成图像
image = pipeline(
prompt="a beautiful landscape",
num_inference_steps=50,
guidance_scale=7.5
).images[0]模型配置
# REPA-E 配置示例
repa_e_config = {
'vae': {
'type': 'REPAVAE',
'latent_dim': 4,
'compression_ratio': 8,
'hidden_channels': [128, 256, 512, 512]
},
'diffusion': {
'backbone': 'SiT-XL',
'hidden_size': 1152,
'num_heads': 16,
'num_layers': 28
},
'training': {
'lambda_align': 0.5,
'lambda_smooth': 0.05,
'batch_size': 2048,
'learning_rate': 1e-4,
'warmup_steps': 5000
}
}