扩散模型架构演进

扩散模型的架构设计经历了从像素空间到潜在空间的演进。本文解析主流扩散模型的架构设计,包括自编码器、文本编码器、U-Net变体等核心组件。12

扩散模型架构分类

按处理空间分类

类型代表模型特点
像素空间DDPM, ADM, Imagen直接在像素级别操作,计算量大
潜在空间Stable Diffusion, Kandinsky先压缩到潜在空间,更高效
多头空间DiT, UViTTransformer-based,多tokens处理

架构演进时间线

2020: DDPM (像素空间, U-Net)
    │
2021: ADM (像素空间, ADM U-Net + classifier guidance)
    │
2022: Stable Diffusion v1 (潜在空间, CLIP文本编码器)
    │
2022: DALL-E 2 (CLIP语义空间, 先验+解码器)
    │
2022: Imagen (T5文本编码器, 超分辨率级联)
    │
2023: SDXL (潜在空间, 更大UNet + Refiner)
    │
2023: DiT (Transformer架构, 自注意力的扩散)
    │
2024: FLUX (Transformer + 更大的文本编码器)

自编码器与潜在空间

为什么需要潜在空间

像素级扩散的计算复杂度为 ,对于 的图像:

而如果使用8倍压缩的自编码器:

压缩比:约48倍!

VQ-VAE架构

Stable Diffusion使用变分量化自编码器(VQ-VAE)

class VQVAE(torch.nn.Module):
    """
    Vector Quantized VAE for Stable Diffusion
    
    核心组件:
    1. 编码器:将图像压缩到潜在空间
    2. 量化层:将连续特征映射到离散codebook
    3. 解码器:从离散code重建图像
    """
    def __init__(self, in_channels=3, hidden_channels=128, latent_channels=4, num_codes=8192):
        super().__init__()
        
        # 编码器(下采样8倍)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 4, stride=2, padding=1),
            ResBlock(hidden_channels),
            nn.Conv2d(hidden_channels, hidden_channels, 4, stride=2, padding=1),
            ResBlock(hidden_channels),
            nn.Conv2d(hidden_channels, hidden_channels, 4, stride=2, padding=1),
            ResBlock(hidden_channels),
            nn.Conv2d(hidden_channels, latent_channels, 3, padding=1),
        )
        
        # Codebook(可学习)
        self.codebook = nn.Embedding(num_codes, latent_channels)
        self.codebook.weight.data.uniform_(-1.0 / num_codes, 1.0 / num_codes)
        
        # 解码器(上采样8倍)
        self.decoder = nn.Sequential(
            nn.Conv2d(latent_channels, hidden_channels, 3, padding=1),
            ResBlock(hidden_channels),
            nn.ConvTranspose2d(hidden_channels, hidden_channels, 4, stride=2, padding=1),
            ResBlock(hidden_channels),
            nn.ConvTranspose2d(hidden_channels, hidden_channels, 4, stride=2, padding=1),
            ResBlock(hidden_channels),
            nn.ConvTranspose2d(hidden_channels, in_channels, 4, stride=2, padding=1),
        )
    
    def encode(self, x):
        """编码"""
        z = self.encoder(x)
        
        # 量化
        b, c, h, w = z.shape
        z_flat = z.permute(0, 2, 3, 1).contiguous().view(-1, c)
        
        # 找最近邻code
        distances = torch.cdist(z_flat, self.codebook.weight)
        indices = torch.argmin(distances, dim=1)
        
        # 量化向量
        z_q = self.codebook(indices)
        z_q = z_q.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        
        return z_q, indices
    
    def decode(self, z_q):
        """解码"""
        return self.decoder(z_q)

KL正则化 vs VQ

两种潜在空间表示方式:

特性VQ-VAEVAE (KL正则化)
表示类型离散(Codebook)连续(高斯)
重建质量中等
训练稳定性需要EMA更新codebook较稳定
典型应用Stable Diffusion其他潜在扩散模型

Stable Diffusion v1.x使用VQ-VAE,v2.x改用自动编码器(类似VAE但更稳定)。

文本编码器

CLIP文本编码器

Stable Diffusion v1.x使用OpenCLIP的ViT-L/14文本编码器:

class CLIPTextEncoder(torch.nn.Module):
    """
    CLIP Text Encoder (Frozen)
    
    将文本映射到768维的语义空间
    """
    def __init__(self):
        super().__init__()
        # 使用预训练的OpenCLIP模型
        self.model, _, _ = open_clip.create_model_and_transforms(
            'ViT-L/14', pretrained='openai'
        )
        # 只保留文本编码器
        self.text_encoder = self.model.transformer
        self.tokenizer = open_clip.get_tokenizer('ViT-L/14')
    
    @torch.no_grad()
    def encode(self, texts):
        """
        Args:
            texts: 文本列表 ["a cat", "a dog"]
        Returns:
            text_embeddings: (batch, seq_len, 768)
        """
        # Tokenize
        text_tokens = self.tokenizer(texts).to(self.device)
        
        # 编码
        x = self.model.token_embedding(text_tokens)
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)
        x = self.text_encoder(x)
        x = x.permute(1, 0, 2)
        
        # 提取[EOS] token作为句子级表示
        # 或者返回完整的token序列
        return x

T5文本编码器

Imagen使用Google的T5-XXL编码器,更大更强:

class T5TextEncoder:
    """
    T5 Text Encoder for Imagen
    
    将文本映射到4096维的语义空间
    """
    def __init__(self, model_name='google/t5-xxl'):
        from transformers import T5EncoderModel, T5Tokenizer
        
        self.model = T5EncoderModel.from_pretrained(model_name)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        
        # T5比CLIP更大更强,但推理更慢
        # google/t5-v1_1-xxl: 4.7B参数
        # google/t5-xxl: 11B参数
    
    def encode(self, texts):
        inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        outputs = self.model(**inputs)
        return outputs.last_hidden_state  # (batch, seq_len, 4096)

文本编码器对比

编码器参数量维度特点
CLIP (ViT-L/14)428M768视觉-语言对齐好,推理快
T5-XXL11B4096文本理解更强,推理慢
T5-XXL (1.1)4.7B4096更新的T5变体
GIT340M768图文预训练
UL220B4096混合去噪目标

U-Net架构变体

基础U-Net

class UNet(torch.nn.Module):
    """
    Standard U-Net for Diffusion Models
    """
    def __init__(self, in_channels=4, out_channels=4, base_channels=320, channel_mults=(1,2,4,4)):
        super().__init__()
        self.base_channels = base_channels
        
        # 时间嵌入
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(base_channels),
            nn.Linear(base_channels, base_channels * 4),
            nn.GELU(),
            nn.Linear(base_channels * 4, base_channels)
        )
        
        # 输入卷积
        self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        
        # 编码器(下采样)
        self.encoder_blocks = nn.ModuleList()
        self.downs = nn.ModuleList()
        
        channels = base_channels
        for i, mult in enumerate(channel_mults):
            out_ch = base_channels * mult
            for _ in range(2):  # 每层2个ResBlock
                self.encoder_blocks.append(
                    ResBlock(channels, out_ch, time_emb)
                )
                channels = out_ch
            if i != len(channel_mults) - 1:
                self.downs.append(nn.Conv2d(channels, channels, 3, stride=2, padding=1))
        
        # 瓶颈
        self.bottleneck = nn.ModuleList([
            ResBlock(channels, channels, time_emb),
            ResBlock(channels, channels, time_emb),
        ])
        
        # 解码器(上采样)
        self.decoder_blocks = nn.ModuleList()
        self.ups = nn.ModuleList()
        
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_ch = base_channels * mult
            for j in range(2):
                self.decoder_blocks.append(
                    ResBlock(channels + encoder_channels.pop(), out_ch, time_emb)
                )
                channels = out_ch
            if i != 0:
                self.ups.append(nn.ConvTranspose2d(channels, channels, 4, stride=2, padding=1))
        
        # 输出卷积
        self.output_conv = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1)
        )
    
    def forward(self, x, t, cond_emb=None):
        """
        Args:
            x: 噪声图像 (batch, 4, H, W)
            t: 时间步 (batch,)
            cond_emb: 条件嵌入 (batch, d_cond)
        """
        # 时间嵌入
        t_emb = self.time_mlp(t)
        
        # 编码器特征
        h = self.input_conv(x)
        encoder_features = []
        
        for i, block in enumerate(self.encoder_blocks):
            h = block(h, t_emb, cond_emb)
            encoder_features.append(h)
            if i < len(self.downs) and i % 2 == 1:  # 下采样时机
                h = self.downs[i // 2](h)
        
        # 瓶颈
        for block in self.bottleneck:
            h = block(h, t_emb, cond_emb)
        
        # 解码器
        for i, block in enumerate(self.decoder_blocks):
            h = torch.cat([h, encoder_features.pop()], dim=1)
            h = block(h, t_emb, cond_emb)
            if i % 2 == 1 and i < len(self.ups):
                h = self.ups[i // 2](h)
        
        return self.output_conv(h)

注意力机制

现代U-Net在瓶颈层和输出层使用交叉注意力处理文本条件:

class CrossAttention(nn.Module):
    def __init__(self, d_model, d_cond, num_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.to_q = nn.Linear(d_model, d_model)
        self.to_k = nn.Linear(d_cond, d_model)
        self.to_v = nn.Linear(d_cond, d_model)
        self.to_out = nn.Linear(d_model, d_model)
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
    
    def forward(self, x, cond):
        """
        Args:
            x: 图像特征 (B, N, C) 其中 N = H*W
            cond: 条件嵌入 (B, M, D_cond)
        """
        B, N, C = x.shape
        _, M, _ = cond.shape
        
        # LayerNorm + QKV投影
        x_norm = self.norm(x)
        q = self.to_q(x_norm)
        k = self.to_k(cond)
        v = self.to_v(cond)
        
        # 分头
        q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 注意力
        attn = F.scaled_dot_product_attention(q, k, v)
        attn = attn.transpose(1, 2).contiguous().view(B, N, C)
        
        return x + self.to_out(attn)

常用ResBlock变体

class ResBlock(nn.Module):
    """
    带时间条件注入的Residual Block
    """
    def __init__(self, in_channels, out_channels, time_emb_dim=None):
        super().__init__()
        
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        
        # 时间条件投影
        self.time_emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels)
        ) if time_emb_dim else None
        
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        
        # 残差连接
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
    
    def forward(self, x, t_emb=None, cond_emb=None):
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)
        
        # 添加时间条件
        if t_emb is not None:
            h = h + self.time_emb(t_emb).unsqueeze(-1).unsqueeze(-1)
        
        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)
        
        return h + self.shortcut(x)

DiT: Transformer架构

核心设计

DiT(Diffusion Transformer)3 用Transformer替代U-Net:

class DiT(nn.Module):
    """
    Diffusion Transformer
    
    核心组件:
    1. Patch Embedding: 将图像patch化
    2. Transformer Blocks: 标准Transformer层
    3. Output Projection: 预测噪声/velocity
    """
    def __init__(self, 
                 img_size=32,
                 patch_size=2,
                 in_channels=4,
                 hidden_size=384,
                 depth=12,
                 num_heads=6,
                 mlp_ratio=4.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # Patch Embedding
        self.x_embed = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
        
        # 时间步嵌入
        self.t_embed = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )
        
        # 文本条件(可选)
        self.c_embed = nn.Linear(768, hidden_size) if use_condition else nn.Identity()
        
        # 位置编码(可学习)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, hidden_size))
        
        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        
        # Transformer Blocks
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio)
            for _ in range(depth)
        ])
        
        # 输出层
        self.norm_final = nn.LayerNorm(hidden_size)
        self.proj_out = nn.Linear(hidden_size, patch_size ** 2 * in_channels)
    
    def forward(self, x, t, c=None):
        """
        Args:
            x: 噪声图像 (B, C, H, W)
            t: 时间步 (B,)
            c: 文本嵌入 (B, D)
        """
        B = x.shape[0]
        
        # Patch化
        x = self.x_embed(x)  # (B, hidden_size, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, N, hidden_size)
        
        # 时间步嵌入
        t_emb = self.get_timestep_embedding(t, self.hidden_size)
        t_emb = self.t_embed(t_emb)
        
        # 文本条件
        if c is not None:
            c_emb = self.c_embed(c)
            t_emb = t_emb + c_emb
        
        # 添加CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        
        # 位置编码
        x = x + self.pos_embed
        
        # Transformer块
        for block in self.blocks:
            x = block(x, t_emb)
        
        # 预测
        x = self.norm_final(x[:, 1:])  # 去掉CLS
        x = self.proj_out(x)  # (B, N, P*P*C)
        
        # 重建图像
        h = w = int(math.sqrt(x.shape[1]))
        x = x.view(B, h, w, self.patch_size, self.patch_size, 4)
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        x = x.view(B, 4, h * self.patch_size, w * self.patch_size)
        
        return x

DiT Block变体

class DiTBlock(nn.Module):
    """
    DiT Block: LayerNorm + Attention + MLP
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size)
        
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, hidden_size)
        )
    
    def forward(self, x, t_emb):
        # 自注意力
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        # MLP
        x = x + self.mlp(self.norm2(x))
        return x

DiT vs U-Net

特性U-NetDiT (Transformer)
架构类型卷积 + 跳跃连接Transformer
参数量860M (SD v1.5)675M (DiT-XL/2)
GPU内存高(跳跃连接存储)较低
扩展性有限好(Transformer可扩展)
质量很好DiT-XL/2 超过U-Net
主流应用SD v1.xSD 3.0, Sora

主流模型架构对比

Stable Diffusion系列

# Stable Diffusion v1.5 架构
SD_v15_config = {
    "model": "runwayml/stable-diffusion-v1-5",
    "vae": "stabilityai/sd-vae-ft-mse",  # VQ-VAE with MSE loss
    "text_encoder": "openai/clip-vit-large-patch14",  # CLIP ViT-L/14
    "unet": {
        "in_channels": 4,
        "out_channels": 4,
        "base_channels": 320,
        "channel_mults": (1, 2, 4, 4),
        "num_res_blocks": 2,
        "attention_resolutions": (4, 2, 1),
        "num_heads": 8,
    },
    "latent_scale_factor": 8,
}
 
# Stable Diffusion XL 架构
SDXL_config = {
    "model": "stabilityai/stable-diffusion-xl-base-1.0",
    "vae": "stabilityai/sdxl-vae",
    "text_encoder": ["openai/clip-vit-large-patch14", "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"],
    "unet": {
        "in_channels": 4,
        "out_channels": 4,
        "base_channels": 320,
        "depth": 2,
        "attention_downsample": [1, 2, 8],
    },
    "refiner": {
        "enabled": True,
        "num_steps": 20,
    }
}

架构组件对比表

模型文本编码器U-Net/DiTVAE特殊设计
SD v1.5CLIP-LU-Net 860MVQ-VAE成熟生态
SD v2.1OpenCLIP-LU-Net 865MAutoencoder更好的图像-文本对齐
SDXLCLIP-L + CLIP-GU-Net 3.5BSDXL-VAE级联Refiner
DALL-E 3GPT-4不公开不公开重标注训练
ImagenT5-XXL 11BADM U-Net 2Bx64超分辨率级联
DiT-XL/2CLIP-LDiT 675M不需要纯Transformer

高效架构设计

注意力机制优化

class Attention(nn.Module):
    """
    Flash Attention + Cross-Attention组合
    """
    def __init__(self, d_model, d_cond, num_heads=8, use_flash=True):
        super().__init__()
        self.use_flash = use_flash
        
        if use_flash:
            # Flash Attention 2
            self.attn = FlashAttention(d_model, num_heads)
        else:
            # 标准注意力
            self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        
        # Cross-Attention 用于文本条件
        self.norm = nn.LayerNorm(d_model)
        self.to_k = nn.Linear(d_cond, d_model, bias=False)
        self.to_v = nn.Linear(d_cond, d_model, bias=False)
    
    def forward(self, x, cond):
        """
        Args:
            x: 图像特征 (B, N, C)
            cond: 文本嵌入 (B, M, D_cond)
        """
        # 自注意力
        x = x + self.attn(self.norm(x), self.norm(x), self.norm(x))
        
        # 交叉注意力(文本条件)
        k = self.to_k(cond)
        v = self.to_v(cond)
        x = x + self.attn(self.norm(x), k, v)
        
        return x

推理优化技巧

  1. Flux 使用流匹配(Flow Matching)替代DDPM
  2. SD3 使用MMDiT(多模态DiT)
  3. SDXL Turbo 使用对抗蒸馏

参考

相关链接:扩散模型理论基础条件扩散与CFG高效采样技术

Footnotes

  1. Rombach et al., “High-Resolution Image Synthesis with Latent Diffusion Models”, CVPR 2022. https://arxiv.org/abs/2112.10752

  2. Saharia et al., “Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding”, arXiv 2022. https://arxiv.org/abs/2205.11487

  3. Peebles & Xie, “Scalable Diffusion Models with Transformers”, ICCV 2023. https://arxiv.org/abs/2212.09748