# Evolution of Diffusion Model Architectures

Diffusion model architectures have evolved from operating in pixel space to operating in latent space. This article walks through the architecture of mainstream diffusion models, covering the core components: autoencoders, text encoders, U-Net variants, and Transformer backbones.[^1][^2]
## A Taxonomy of Diffusion Architectures

### By Processing Space

| Type | Representative models | Characteristics |
|---|---|---|
| Pixel space | DDPM, ADM, Imagen | Operates directly on pixels; computationally expensive |
| Latent space | Stable Diffusion, Kandinsky | Compresses images into a latent space first; much more efficient |
| Token-based | DiT, U-ViT | Transformer backbones operating on sequences of patch tokens |
### Architecture Timeline

```text
2020: DDPM (pixel space, U-Net)
  │
2021: ADM (pixel space, ADM U-Net + classifier guidance)
  │
2022: Stable Diffusion v1 (latent space, CLIP text encoder)
  │
2022: DALL-E 2 (CLIP semantic space, prior + decoder)
  │
2022: Imagen (T5 text encoder, super-resolution cascade)
  │
2023: SDXL (latent space, larger U-Net + refiner)
  │
2023: DiT (Transformer backbone for diffusion)
  │
2024: FLUX (Transformer + larger text encoders)
```
## Autoencoders and Latent Space

### Why a Latent Space?

Pixel-space diffusion operates on every pixel, so its cost scales with H × W × C. For a 512 × 512 RGB image:

512 × 512 × 3 = 786,432 values

With an autoencoder that downsamples 8× spatially into a 4-channel latent:

64 × 64 × 4 = 16,384 values

Compression ratio: 786,432 / 16,384 ≈ 48×.
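To make the shapes concrete, here is a minimal sketch using the `diffusers` library's `AutoencoderKL` (the repo id is the fine-tuned SD VAE mentioned later in this post; the random tensor stands in for a normalized image):

```python
# Encode an image into the SD latent space and compare element counts.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

img = torch.randn(1, 3, 512, 512)                    # stand-in for a normalized image
with torch.no_grad():
    latents = vae.encode(img).latent_dist.sample()   # (1, 4, 64, 64)
print(img.numel() / latents.numel())                 # ~48x fewer values
```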
### VQ-VAE Architecture

The latent diffusion paper experimented with two first-stage autoencoders: a vector-quantized one (VQ-VAE/VQGAN-style) and a KL-regularized one. A vector quantized VAE (VQ-VAE) looks like this:
```python
import torch
import torch.nn as nn

class VQVAE(nn.Module):
    """
    Vector Quantized VAE (VQ-VAE).
    Core components:
      1. Encoder: compresses the image into a latent grid
      2. Quantizer: maps continuous features to a discrete codebook
      3. Decoder: reconstructs the image from the quantized codes
    Relies on the ResBlock defined later in this post.
    """
    def __init__(self, in_channels=3, hidden_channels=128, latent_channels=4, num_codes=8192):
        super().__init__()
        # Encoder (8x spatial downsampling via three stride-2 convs)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 4, stride=2, padding=1),
            ResBlock(hidden_channels, hidden_channels),
            nn.Conv2d(hidden_channels, hidden_channels, 4, stride=2, padding=1),
            ResBlock(hidden_channels, hidden_channels),
            nn.Conv2d(hidden_channels, hidden_channels, 4, stride=2, padding=1),
            ResBlock(hidden_channels, hidden_channels),
            nn.Conv2d(hidden_channels, latent_channels, 3, padding=1),
        )
        # Learnable codebook
        self.codebook = nn.Embedding(num_codes, latent_channels)
        self.codebook.weight.data.uniform_(-1.0 / num_codes, 1.0 / num_codes)
        # Decoder (8x spatial upsampling)
        self.decoder = nn.Sequential(
            nn.Conv2d(latent_channels, hidden_channels, 3, padding=1),
            ResBlock(hidden_channels, hidden_channels),
            nn.ConvTranspose2d(hidden_channels, hidden_channels, 4, stride=2, padding=1),
            ResBlock(hidden_channels, hidden_channels),
            nn.ConvTranspose2d(hidden_channels, hidden_channels, 4, stride=2, padding=1),
            ResBlock(hidden_channels, hidden_channels),
            nn.ConvTranspose2d(hidden_channels, in_channels, 4, stride=2, padding=1),
        )

    def encode(self, x):
        """Encode and quantize. Note: training additionally needs a
        straight-through gradient estimator plus codebook/commitment
        losses, omitted here for brevity."""
        z = self.encoder(x)
        b, c, h, w = z.shape
        z_flat = z.permute(0, 2, 3, 1).contiguous().view(-1, c)
        # Nearest-neighbor lookup in the codebook
        distances = torch.cdist(z_flat, self.codebook.weight)
        indices = torch.argmin(distances, dim=1)
        # Replace each latent vector with its nearest code
        z_q = self.codebook(indices)
        z_q = z_q.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        return z_q, indices

    def decode(self, z_q):
        """Reconstruct the image from quantized latents."""
        return self.decoder(z_q)
```
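A quick round-trip shape check for the sketch above (hypothetical usage; assumes the ResBlock defined later in this post):

```python
vqvae = VQVAE()
x = torch.randn(1, 3, 256, 256)
z_q, indices = vqvae.encode(x)   # z_q: (1, 4, 32, 32) after 8x downsampling
recon = vqvae.decode(z_q)        # (1, 3, 256, 256)
print(z_q.shape, recon.shape)
```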
### KL Regularization vs. VQ

Two ways to regularize the latent space:

| Property | VQ-VAE | VAE (KL-regularized) |
|---|---|---|
| Latent type | Discrete (codebook) | Continuous (Gaussian) |
| Reconstruction quality | High | Moderate |
| Training stability | Needs EMA codebook updates | Fairly stable |
| Typical use | VQGAN-style latent diffusion variants | Stable Diffusion |

Although the latent diffusion paper trained both variants, the released Stable Diffusion checkpoints (v1.x and v2.x alike) ship a KL-regularized autoencoder, conventionally just called "the VAE".
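For contrast, a minimal sketch of the KL-regularized alternative: the encoder trunk predicts a mean and log-variance per latent position, sampling uses the reparameterization trick, and a small KL penalty keeps the latent distribution close to a standard Gaussian. Illustrative code, not the SD implementation:

```python
import torch
import torch.nn as nn

class KLAutoencoderHead(nn.Module):
    """Sketch of the KL-regularized latent head: the encoder trunk outputs
    2 * latent_channels feature maps, split into mean and log-variance."""
    def __init__(self, latent_channels=4):
        super().__init__()
        self.latent_channels = latent_channels

    def forward(self, moments):
        # moments: (B, 2*latent_channels, H, W) from the encoder trunk
        mean, logvar = moments.chunk(2, dim=1)
        # Reparameterization trick: z = mu + sigma * eps
        z = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
        # KL divergence to N(0, I), used as a small-weighted training penalty
        kl = 0.5 * torch.sum(mean**2 + logvar.exp() - 1.0 - logvar, dim=[1, 2, 3])
        return z, kl.mean()
```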
## Text Encoders

### CLIP Text Encoder

Stable Diffusion v1.x uses the OpenAI CLIP ViT-L/14 text encoder (loaded below through the open_clip library; SD v2.x later switched to OpenCLIP weights):
```python
import torch
import torch.nn as nn
import open_clip

class CLIPTextEncoder(nn.Module):
    """
    Frozen CLIP text encoder.
    Maps text to a sequence of 768-dim token embeddings.
    """
    def __init__(self):
        super().__init__()
        # Load the pretrained CLIP model (OpenAI weights via open_clip)
        self.model, _, _ = open_clip.create_model_and_transforms(
            'ViT-L-14', pretrained='openai'
        )
        self.tokenizer = open_clip.get_tokenizer('ViT-L-14')

    @torch.no_grad()
    def encode(self, texts):
        """
        Args:
            texts: list of strings, e.g. ["a cat", "a dog"]
        Returns:
            text_embeddings: (batch, seq_len, 768)
        """
        device = next(self.model.parameters()).device
        text_tokens = self.tokenizer(texts).to(device)
        # Run the text transformer (causally masked, as in CLIP)
        x = self.model.token_embedding(text_tokens)
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # (seq, batch, dim)
        x = self.model.transformer(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # (batch, seq, dim)
        x = self.model.ln_final(x)
        # Either pool the [EOS] token for a sentence-level embedding,
        # or return the full token sequence (SD uses the full sequence
        # for cross-attention)
        return x
```
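In practice, SD pipelines usually load this encoder through Hugging Face transformers instead; a minimal equivalent (assuming the `transformers` library):

```python
# Equivalent using Hugging Face transformers, as SD pipelines commonly do.
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer(["a cat", "a dog"], padding="max_length",
                   max_length=77, return_tensors="pt")
with torch.no_grad():
    emb = text_encoder(**tokens).last_hidden_state  # (2, 77, 768)
```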
### T5 Text Encoder

Imagen uses Google's T5-XXL encoder, which is larger and stronger at language understanding:
```python
import torch

class T5TextEncoder:
    """
    T5 text encoder, as used by Imagen.
    Maps text to a sequence of 4096-dim token embeddings.
    """
    def __init__(self, model_name='google/t5-v1_1-xxl'):
        from transformers import T5EncoderModel, T5Tokenizer
        self.model = T5EncoderModel.from_pretrained(model_name)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        # T5-XXL is an 11B-parameter encoder-decoder; the encoder alone
        # (all that is needed here) is roughly 4.7B parameters.
        # Stronger text understanding than CLIP, but slower inference.

    @torch.no_grad()
    def encode(self, texts):
        inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        outputs = self.model(**inputs)
        return outputs.last_hidden_state  # (batch, seq_len, 4096)
```
### Text Encoder Comparison

| Encoder | Parameters | Dim | Characteristics |
|---|---|---|---|
| CLIP (ViT-L/14) | 428M | 768 | Good vision-language alignment, fast inference |
| T5-XXL | 11B full / ~4.7B encoder-only | 4096 | Stronger text understanding, slower inference |
| GIT | 340M | 768 | Image-text pretraining |
| UL2 | 20B | 4096 | Mixture-of-denoisers objective |
## U-Net Variants

### Basic U-Net
```python
import math
import torch
import torch.nn as nn

class SinusoidalPosEmb(nn.Module):
    """Standard sinusoidal timestep embedding."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
        args = t[:, None].float() * freqs[None, :]
        return torch.cat([args.sin(), args.cos()], dim=-1)

class UNet(nn.Module):
    """
    Simplified U-Net for diffusion models: 2 ResBlocks per resolution,
    no attention layers (see the next section). Relies on the ResBlock
    defined later in this post.
    """
    def __init__(self, in_channels=4, out_channels=4, base_channels=320, channel_mults=(1, 2, 4, 4)):
        super().__init__()
        self.base_channels = base_channels
        # Timestep embedding MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(base_channels),
            nn.Linear(base_channels, base_channels * 4),
            nn.GELU(),
            nn.Linear(base_channels * 4, base_channels)
        )
        # Input convolution
        self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        # Encoder (downsampling path); record skip-connection channel counts
        self.encoder_blocks = nn.ModuleList()
        self.downs = nn.ModuleList()
        skip_channels = []
        channels = base_channels
        for i, mult in enumerate(channel_mults):
            out_ch = base_channels * mult
            for _ in range(2):  # 2 ResBlocks per resolution
                self.encoder_blocks.append(
                    ResBlock(channels, out_ch, time_emb_dim=base_channels)
                )
                channels = out_ch
                skip_channels.append(channels)
            if i != len(channel_mults) - 1:
                self.downs.append(nn.Conv2d(channels, channels, 3, stride=2, padding=1))
        # Bottleneck
        self.bottleneck = nn.ModuleList([
            ResBlock(channels, channels, time_emb_dim=base_channels),
            ResBlock(channels, channels, time_emb_dim=base_channels),
        ])
        # Decoder (upsampling path); consume the recorded skips in reverse order
        self.decoder_blocks = nn.ModuleList()
        self.ups = nn.ModuleList()
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_ch = base_channels * mult
            for _ in range(2):
                self.decoder_blocks.append(
                    ResBlock(channels + skip_channels.pop(), out_ch, time_emb_dim=base_channels)
                )
                channels = out_ch
            if i != 0:
                self.ups.append(nn.ConvTranspose2d(channels, channels, 4, stride=2, padding=1))
        # Output convolution
        self.output_conv = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1)
        )

    def forward(self, x, t, cond_emb=None):
        """
        Args:
            x: noisy latents (batch, 4, H, W)
            t: timesteps (batch,)
            cond_emb: conditioning embedding (batch, d_cond); ignored by
                      the plain ResBlocks used here
        """
        t_emb = self.time_mlp(t)
        # Encoder: store one skip feature per block, downsample between levels
        h = self.input_conv(x)
        encoder_features = []
        for i, block in enumerate(self.encoder_blocks):
            h = block(h, t_emb, cond_emb)
            encoder_features.append(h)
            if i % 2 == 1 and i // 2 < len(self.downs):  # after every 2nd block
                h = self.downs[i // 2](h)
        # Bottleneck
        for block in self.bottleneck:
            h = block(h, t_emb, cond_emb)
        # Decoder: concatenate skips in reverse order, upsample between levels
        for i, block in enumerate(self.decoder_blocks):
            h = torch.cat([h, encoder_features.pop()], dim=1)
            h = block(h, t_emb, cond_emb)
            if i % 2 == 1 and i // 2 < len(self.ups):
                h = self.ups[i // 2](h)
        return self.output_conv(h)
```
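A quick sanity check of the sketch above (hypothetical usage; the 64×64×4 shape follows the SD v1.x latent layout, and it relies on the ResBlock defined below):

```python
unet = UNet()
x = torch.randn(2, 4, 64, 64)
t = torch.randint(0, 1000, (2,))
eps = unet(x, t)        # predicted noise, same shape as the input
print(eps.shape)        # torch.Size([2, 4, 64, 64])
```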
### Attention Mechanisms

Modern U-Nets use cross-attention in the bottleneck and the lower-resolution levels to inject the text condition:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, d_model, d_cond, num_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.to_q = nn.Linear(d_model, d_model)
        self.to_k = nn.Linear(d_cond, d_model)
        self.to_v = nn.Linear(d_cond, d_model)
        self.to_out = nn.Linear(d_model, d_model)
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

    def forward(self, x, cond):
        """
        Args:
            x: image features (B, N, C), where N = H*W
            cond: conditioning embeddings (B, M, d_cond)
        """
        B, N, C = x.shape
        _, M, _ = cond.shape
        # Pre-norm; queries come from the image, keys/values from the condition
        x_norm = self.norm(x)
        q = self.to_q(x_norm)
        k = self.to_k(cond)
        v = self.to_v(cond)
        # Split heads: (B, num_heads, seq, head_dim)
        q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention (uses fused kernels where available)
        attn = F.scaled_dot_product_attention(q, k, v)
        attn = attn.transpose(1, 2).contiguous().view(B, N, C)
        # Residual connection
        return x + self.to_out(attn)
```
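Hypothetical usage with SD-like dimensions (320-dim image tokens attending over 77 CLIP tokens):

```python
attn = CrossAttention(d_model=320, d_cond=768)
img_tokens = torch.randn(2, 64 * 64, 320)   # flattened 64x64 feature map
text_tokens = torch.randn(2, 77, 768)       # CLIP token sequence
out = attn(img_tokens, text_tokens)         # (2, 4096, 320)
```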
### Common ResBlock Variant

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """
    Residual block with timestep-conditioning injection.
    """
    def __init__(self, in_channels, out_channels, time_emb_dim=None):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channels)  # in_channels must be divisible by 32
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        # Projection of the timestep embedding into the channel dimension
        self.time_emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels)
        ) if time_emb_dim else None
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        # 1x1 shortcut when the channel count changes
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, t_emb=None, cond_emb=None):
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)
        # Add the timestep condition, broadcast over spatial dims
        if t_emb is not None and self.time_emb is not None:
            h = h + self.time_emb(t_emb).unsqueeze(-1).unsqueeze(-1)
        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)
        return h + self.shortcut(x)
```

## DiT: Transformer Architecture
### Core Design

DiT (Diffusion Transformer)[^3] replaces the U-Net with a Transformer:
```python
import math
import torch
import torch.nn as nn

class DiT(nn.Module):
    """
    Diffusion Transformer (simplified sketch).
    Core components:
      1. Patch embedding: turns the latent image into tokens
      2. Transformer blocks: standard Transformer layers
      3. Output projection: predicts noise/velocity per patch
    """
    def __init__(self,
                 img_size=32,
                 patch_size=2,
                 in_channels=4,
                 hidden_size=384,
                 depth=12,
                 num_heads=6,
                 mlp_ratio=4.0,
                 cond_dim=768):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.num_patches = (img_size // patch_size) ** 2
        # Patch embedding via a strided convolution
        self.x_embed = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
        # Timestep embedding (sinusoidal features + MLP; SinusoidalPosEmb from above)
        self.t_pos = SinusoidalPosEmb(hidden_size)
        self.t_embed = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )
        # Conditioning projection (e.g. 768-dim CLIP embeddings), optional at call time
        self.c_embed = nn.Linear(cond_dim, hidden_size)
        # Learnable positional embedding and CLS token
        # (the CLS token is a simplification of this sketch; the official DiT has none)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        # Transformer blocks
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio)
            for _ in range(depth)
        ])
        # Output head
        self.norm_final = nn.LayerNorm(hidden_size)
        self.proj_out = nn.Linear(hidden_size, patch_size ** 2 * in_channels)

    def forward(self, x, t, c=None):
        """
        Args:
            x: noisy latents (B, C, H, W)
            t: timesteps (B,)
            c: conditioning embedding (B, cond_dim), optional
        """
        B = x.shape[0]
        # Patchify: (B, hidden_size, H/P, W/P) -> (B, N, hidden_size)
        x = self.x_embed(x)
        x = x.flatten(2).transpose(1, 2)
        # Timestep embedding, plus optional condition
        t_emb = self.t_embed(self.t_pos(t))
        if c is not None:
            t_emb = t_emb + self.c_embed(c)
        # Prepend CLS token and add positional embeddings
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        # Transformer blocks, conditioned on t_emb
        for block in self.blocks:
            x = block(x, t_emb)
        # Predict per-patch outputs (drop the CLS token)
        x = self.norm_final(x[:, 1:])
        x = self.proj_out(x)  # (B, N, P*P*C)
        # Unpatchify back to (B, C, H, W)
        h = w = int(math.sqrt(x.shape[1]))
        x = x.view(B, h, w, self.patch_size, self.patch_size, self.in_channels)
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        x = x.view(B, self.in_channels, h * self.patch_size, w * self.patch_size)
        return x
```
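Shape check (hypothetical usage for 32×32×4 latents):

```python
dit = DiT()
x = torch.randn(2, 4, 32, 32)
t = torch.randint(0, 1000, (2,))
c = torch.randn(2, 768)     # e.g. a pooled CLIP embedding
out = dit(x, t, c)          # (2, 4, 32, 32)
print(out.shape)
```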
### DiT Block Variants

```python
import torch
import torch.nn as nn

class DiTBlock(nn.Module):
    """
    Minimal DiT block: LayerNorm + self-attention + MLP, with the
    timestep embedding injected additively. (The DiT paper's preferred
    conditioning, adaLN-Zero, is sketched afterwards.)
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size)
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, hidden_size)
        )

    def forward(self, x, t_emb):
        # Inject the timestep/condition embedding into every token
        x = x + t_emb.unsqueeze(1)
        # Self-attention with pre-norm
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]
        # MLP with pre-norm
        x = x + self.mlp(self.norm2(x))
        return x
```
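The conditioning mechanism the DiT paper actually settles on is adaLN-Zero: the condition regresses per-block shift/scale/gate parameters, with the gates initialized to zero so every block starts out as the identity. A sketch:

```python
class DiTBlockAdaLN(nn.Module):
    """adaLN-Zero DiT block: the condition modulates the LayerNorm output
    and gates each residual branch; gates start at zero."""
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, int(hidden_size * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(hidden_size * mlp_ratio), hidden_size),
        )
        # Regress 6 modulation vectors: shift/scale/gate for attn and for MLP
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
        nn.init.zeros_(self.adaLN[-1].weight)  # zero-init => identity block at start
        nn.init.zeros_(self.adaLN[-1].bias)

    def forward(self, x, t_emb):
        shift1, scale1, gate1, shift2, scale2, gate2 = self.adaLN(t_emb).chunk(6, dim=-1)
        h = self.norm1(x) * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)
        x = x + gate1.unsqueeze(1) * self.attn(h, h, h)[0]
        h = self.norm2(x) * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)
        x = x + gate2.unsqueeze(1) * self.mlp(h)
        return x
```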
### DiT vs. U-Net

| Property | U-Net | DiT (Transformer) |
|---|---|---|
| Architecture | Convolutions + skip connections | Transformer |
| Parameters | 860M (SD v1.5) | 675M (DiT-XL/2) |
| GPU memory | High (skip activations kept around) | Lower |
| Scalability | Limited | Good (Transformers scale well) |
| Quality | Very good | DiT-XL/2 surpassed U-Net baselines on ImageNet |
| Used in | SD v1.x | SD 3.0, Sora |
## Architectures of Mainstream Models

### Stable Diffusion Family
```python
# Stable Diffusion v1.5 architecture (illustrative summary, not a real config file)
SD_v15_config = {
    "model": "runwayml/stable-diffusion-v1-5",
    "vae": "stabilityai/sd-vae-ft-mse",               # KL autoencoder, fine-tuned with MSE
    "text_encoder": "openai/clip-vit-large-patch14",  # CLIP ViT-L/14
    "unet": {
        "in_channels": 4,
        "out_channels": 4,
        "base_channels": 320,
        "channel_mults": (1, 2, 4, 4),
        "num_res_blocks": 2,
        "attention_resolutions": (4, 2, 1),
        "num_heads": 8,
    },
    "latent_scale_factor": 8,
}

# Stable Diffusion XL architecture (illustrative summary)
SDXL_config = {
    "model": "stabilityai/stable-diffusion-xl-base-1.0",
    "vae": "stabilityai/sdxl-vae",
    "text_encoder": ["openai/clip-vit-large-patch14", "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"],
    "unet": {
        "in_channels": 4,
        "out_channels": 4,
        "base_channels": 320,
        "depth": 2,
        "attention_downsample": [1, 2, 8],
    },
    "refiner": {
        "enabled": True,
        "num_steps": 20,
    }
}
```
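These are the components the diffusers pipelines assemble under the hood; a minimal load to inspect them (assuming the `diffusers` library):

```python
# Load SD v1.5 and inspect the components listed in the config above.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
print(type(pipe.vae).__name__)           # AutoencoderKL
print(type(pipe.text_encoder).__name__)  # CLIPTextModel
print(type(pipe.unet).__name__)          # UNet2DConditionModel
```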
### Component Comparison

| Model | Text encoder | U-Net/DiT | VAE | Distinctive design |
|---|---|---|---|---|
| SD v1.5 | CLIP-L | U-Net 860M | KL-VAE (f8) | Mature ecosystem |
| SD v2.1 | OpenCLIP-H | U-Net 865M | KL-VAE (f8) | Better image-text alignment |
| SDXL | CLIP-L + OpenCLIP-bigG | U-Net 2.6B | SDXL-VAE | Cascaded refiner |
| DALL-E 3 | Undisclosed | Undisclosed | Undisclosed | Training on recaptioned data (GPT-4V-style captioner) |
| Imagen | T5-XXL | ADM U-Net 2B | None (pixel space, 64x64 base) | Super-resolution cascade |
| DiT-XL/2 | None (class labels) | DiT 675M | SD KL-VAE (f8) | Pure Transformer backbone |
## Efficient Architecture Design

### Attention Optimization
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """
    Self-attention + cross-attention pair built on
    F.scaled_dot_product_attention, which dispatches to fused
    FlashAttention/memory-efficient kernels when available (PyTorch >= 2.0).
    """
    def __init__(self, d_model, d_cond, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Self-attention projections
        self.to_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        # Cross-attention projections (keys/values from the text condition)
        self.to_q = nn.Linear(d_model, d_model, bias=False)
        self.to_k = nn.Linear(d_cond, d_model, bias=False)
        self.to_v = nn.Linear(d_cond, d_model, bias=False)

    def _heads(self, x):
        B, L, _ = x.shape
        return x.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, cond):
        """
        Args:
            x: image features (B, N, C)
            cond: text embeddings (B, M, d_cond)
        """
        B, N, C = x.shape
        # Self-attention with pre-norm
        q, k, v = self.to_qkv(self.norm1(x)).chunk(3, dim=-1)
        h = F.scaled_dot_product_attention(self._heads(q), self._heads(k), self._heads(v))
        x = x + h.transpose(1, 2).reshape(B, N, C)
        # Cross-attention over the text condition
        q = self._heads(self.to_q(self.norm2(x)))
        k = self._heads(self.to_k(cond))
        v = self._heads(self.to_v(cond))
        h = F.scaled_dot_product_attention(q, k, v)
        return x + h.transpose(1, 2).reshape(B, N, C)
```

### Inference Optimization Techniques
- FLUX replaces DDPM-style diffusion with flow matching (see the sketch below)
- SD3 uses MMDiT (a multimodal DiT with separate weights for text and image tokens)
- SDXL Turbo uses adversarial distillation for few-step sampling
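As a taste of the first bullet, here is a minimal sketch of a rectified-flow-style flow-matching training step: sample a straight-line interpolation between data and noise, and regress the model onto the constant velocity of that path. The function name and the model interface are illustrative, not FLUX's actual code:

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x0, cond):
    """One flow-matching training step (rectified flow parameterization).
    model(x_t, t, cond) is any network predicting a velocity field."""
    B = x0.shape[0]
    noise = torch.randn_like(x0)
    t = torch.rand(B, device=x0.device)       # uniform time in [0, 1]
    t_ = t.view(B, 1, 1, 1)
    x_t = (1 - t_) * x0 + t_ * noise          # straight-line interpolation
    target_velocity = noise - x0              # d x_t / d t along that path
    pred = model(x_t, t, cond)
    return F.mse_loss(pred, target_velocity)
```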
## References

[^1]: Rombach et al., "High-Resolution Image Synthesis with Latent Diffusion Models", CVPR 2022. https://arxiv.org/abs/2112.10752
[^2]: Saharia et al., "Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding", arXiv 2022. https://arxiv.org/abs/2205.11487
[^3]: Peebles & Xie, "Scalable Diffusion Models with Transformers", ICCV 2023. https://arxiv.org/abs/2212.09748