Transfusion:统一多模态模型——下一个token预测与图像扩散

概述

Transfusion 是 ICLR 2025 的一篇 Oral 论文,提出了一种训练多模态模型的统一方法。该方法将语言建模(下一个token预测)与扩散模型相结合,用单一Transformer架构处理离散(文本)和连续(图像)数据的混合模态序列。1

Transfusion 的核心创新在于:为不同模态使用不同的损失函数——语言建模损失用于文本,下一代扩散(DDPM)损失用于图像,同时保持共享的数据和参数。这种方法避免了将图像量化为离散token所带来的信息损失,同时保留了语言模型和扩散模型各自的优势。

在 7B 参数规模、2T 多模态token的训练规模下,Transfusion 实现了与同规模纯扩散模型和语言模型相当的性能,展现了统一多模态模型的巨大潜力。


1. 问题背景:多模态模型的两大范式

1.1 离散与连续数据的本质差异

多模态生成模型需要处理两类本质不同的数据:

数据类型示例建模方法代表模型
离散数据文本、代码语言建模(下一个token预测)GPT、LLaMA
连续数据图像、音频、视频扩散模型(去噪)Stable Diffusion、DDPM

语言模型通过最大化似然概率 来预测下一个离散token,而扩散模型则通过学习去噪过程来生成连续向量。这种本质差异使得统一建模变得困难。1

1.2 现有方法的局限性

将语言模型和扩散模型结合的现有方法主要分为三类:

(1)工具调用式

将预训练的扩散模型作为语言模型的外部工具调用。这种方法需要两个独立模型,跨模态交互受限。

# 工具调用式的简化示意
text_output = language_model.generate(input_text)
if need_image:
    image_output = diffusion_model.generate(text_output)

(2)嫁接式

将预训练扩散模型嫁接到语言模型上,通过特征对齐实现多模态交互。但训练复杂且存在架构兼容问题。

(3)量化式(Chameleon等)

将连续图像数据量化为离散token,然后用标准语言模型处理。这种方法简化了架构,但不可避免地丢失信息——量化本身就是一种有损压缩。1

1.3 Transfusion的核心思想

能否用单一模型无缝生成离散和连续模态,无需量化,不损失信息?

Transfusion 给出了肯定的答案:通过在混合模态序列上组合语言建模和扩散目标函数,实现两种模态的完全整合。


2. Transfusion方法详解

2.1 数据表示

文本数据

文本字符串被tokenize为来自固定词表的离散token序列,每个token表示为一个整数:

图像数据

图像使用预训练的变分自编码器(VAE)编码为潜在patch表示:

  1. VAE编码:将图像编码为低维潜在表示(通常为8×8像素patch对应一个8维向量)
  2. Patch化:将潜在表示按从左到右、从上到下的顺序组织为patch向量序列

混合模态序列

对于混合模态样本,使用特殊的BOI(Begin of Image)EOI(End of Image) token包裹每个图像序列:

[文本token] [BOI] [patch_1] [patch_2] ... [patch_n] [EOI] [更多文本token]

这样得到一个可能同时包含离散元素(文本token整数)和连续元素(图像patch向量)的统一序列。

2.2 模型架构

Transfusion的模型架构包括:

核心Transformer

模型参数的绝大部分属于一个单一的Transformer,处理所有模态的序列。Transformer接收 中的高维向量序列作为输入,输出相似的高维向量。

采用 Llama 风格的 transformer 块,包含:

  • SwiGLU激活函数
  • 旋转位置编码(RoPE)

模态特定的轻量级组件

为将数据转换到Transformer空间,使用轻量级的模态特定组件(参数不共享):

模态输入处理输出处理
文本嵌入矩阵(整数→向量)线性层(向量→词表分布)
图像Patch化层 + 时间步嵌入反Patch化层

图像Patch编码/解码层

Transfusion实验了两种将 patch向量压缩为单个Transformer向量(以及反向)的方法:

(1)简单线性层

对于每个patch向量,在输入线性层之前添加时间步嵌入

(2)U-Net上下块

采用SDXL中U-Net的降采样和上采样块(0.27B参数),可大幅减少图像patch数量,同时保持较小性能损失:

使用U-Net块可将每张图像压缩至仅16个patch,潜在减少服务成本高达64倍

2.3 Transfusion注意力机制

关键挑战

  • 语言模型通常使用因果掩码,避免未来token的信息泄露
  • 图像通常使用双向(无限制)注意力,因为图像本身没有自然顺序

Transfusion的解决方案

Transfusion创新性地结合了两种注意力模式:

对序列中的所有元素应用因果注意力,对每个图像内部的元素应用双向注意力。

这使得:

  • 每个图像patch可以 attending 到同一图像的所有其他patch
  • 只能 attending 到序列中之前出现的文本或其他图像的patch
def transfusion_attention_mask(seq_elements, image_boundaries):
    """
    构建Transfusion注意力掩码
    seq_elements: 序列中的元素列表
    image_boundaries: 记录每个图像的起止位置
    """
    n = len(seq_elements)
    mask = torch.zeros(n, n, dtype=torch.bool)
    
    # 因果掩码:所有元素只能看向过去
    mask = torch.tril(mask)
    
    # 图像内部双向:同一图像的patch可以互相看
    for start, end in image_boundaries:
        mask[start:end, start:end] = True
    
    return mask

2.4 训练目标

Transfusion的总体损失函数为两个损失的加权和:

语言建模损失(文本)

对文本token应用标准的下一个token预测损失:

扩散损失(图像)

对图像patch应用DDPM去噪损失。首先对输入潜在图像 按扩散过程添加噪声:

然后计算噪声预测损失:

其中 是文本条件(caption), 是扩散时间步。

噪声调度

采用常用的余弦调度:

关键设计

  • 图像级扩散损失:每个图像可能跨越序列中的多个元素(多个patch),但扩散损失在图像级别计算
  • 噪声后patch化:在patch化之前对潜在图像添加噪声,因此下游token在训练时对带噪声的图像进行条件化

2.5 推理过程

Transfusion的解码算法同样在两种模式之间切换:

LM模式

遵循标准语言模型实践,从预测的分布中逐token采样。

扩散模式

当采样到BOI token时,解码切换到扩散模式:

  1. 初始化:将纯噪声 作为 个图像patch附加到输入序列
  2. 迭代去噪:对于每个时间步
    • 模型预测噪声
    • 根据噪声调度,移除相应比例的预测噪声
    • 产生
  3. 模式切换:扩散过程结束后,附加EOI token,切换回LM模式

这种算法支持任意文本和图像模态的混合生成。


3. 完整PyTorch实现

以下是一个简化但完整的Transfusion模型实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
import math
 
# ============ 配置参数 ============
class TransfusionConfig:
    def __init__(self,
                 vocab_size: int = 32000,
                 hidden_dim: int = 4096,
                 num_heads: int = 32,
                 num_layers: int = 32,
                 patch_size: int = 2,          # VAE latent中的patch大小
                 num_patches: int = 256,        # 每个图像的patch数
                 latent_dim: int = 16,          # VAE潜在空间维度
                 num diffusion_steps: int = 1000,
                 max_seq_len: int = 8192,
                 use_unet_blocks: bool = False):
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.latent_dim = latent_dim
        self.num_diffusion_steps = num_diffusion_steps
        self.max_seq_len = max_seq_len
        self.use_unet_blocks = use_unet_blocks
 
 
# ============ 模态特定组件 ============
class TextEmbedding(nn.Module):
    """文本嵌入层和输出层"""
    def __init__(self, config: TransfusionConfig):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.output_projection = nn.Linear(config.hidden_dim, config.vocab_size)
        
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """token_ids: (batch, seq_len) 整数token ID"""
        return self.token_embedding(token_ids)
    
    def output(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """将隐藏状态投影回词表分布"""
        return self.output_projection(hidden_states)
 
 
class PatchifyLayer(nn.Module):
    """将图像patch转换为Transformer向量"""
    def __init__(self, config: TransfusionConfig):
        super().__init__()
        self.config = config
        self.patch_proj = nn.Linear(config.latent_dim, config.hidden_dim)
        self.time_embed = nn.Embedding(config.num_diffusion_steps, config.hidden_dim)
        
        if config.use_unet_blocks:
            # 使用U-Net块进行更高级的patch处理
            self.unet_down = UNetDownBlock(config.latent_dim, config.hidden_dim)
            self.unet_up = UNetUpBlock(config.hidden_dim, config.latent_dim)
        
    def encode(self, image_patches: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        image_patches: (batch, num_patches, latent_dim) VAE编码的patch
        timesteps: (batch,) 扩散时间步
        """
        # 添加时间步嵌入
        t_embed = self.time_embed(timesteps)  # (batch, hidden_dim)
        h = self.patch_proj(image_patches)     # (batch, num_patches, hidden_dim)
        
        # broadcast时间步嵌入到所有patch
        h = h + t_embed.unsqueeze(1)
        return h
    
    def decode(self, hidden_states: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        hidden_states: (batch, num_patches, hidden_dim) Transformer输出
        返回: (batch, num_patches, latent_dim) 重建的patch
        """
        return self.patch_proj_inv(hidden_states)
 
 
class UNetDownBlock(nn.Module):
    """U-Net下采样块"""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_dim, out_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(out_dim, out_dim, 3, padding=1)
        self.norm1 = nn.LayerNorm(out_dim)
        self.norm2 = nn.LayerNorm(out_dim)
        self.downsample = nn.Conv2d(out_dim, out_dim, 3, stride=2, padding=1)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, num_patches, dim) -> (B, dim, h, w)
        B, N, D = x.shape
        h = int(math.sqrt(N))
        x = x.transpose(1, 2).reshape(B, D, h, h)
        
        x = self.conv1(F.silu(self.norm1(x.transpose(2, 3)).transpose(2, 3)))
        x = self.conv2(F.silu(self.norm2(x.transpose(2, 3)).transpose(2, 3)))
        x = self.downsample(x)
        
        return x.flatten(2).transpose(1, 2)
 
 
class UNetUpBlock(nn.Module):
    """U-Net上采样块"""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_dim, out_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(out_dim, out_dim, 3, padding=1)
        self.norm1 = nn.LayerNorm(out_dim)
        self.norm2 = nn.LayerNorm(out_dim)
        self.upsample = nn.ConvTranspose2d(in_dim, out_dim, 4, stride=2, padding=1)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, D = x.shape
        h = int(math.sqrt(N))
        x = x.transpose(1, 2).reshape(B, D, h, h)
        
        x = self.upsample(x)
        x = self.conv1(F.silu(self.norm1(x.transpose(2, 3)).transpose(2, 3)))
        x = self.conv2(F.silu(self.norm2(x.transpose(2, 3)).transpose(2, 3)))
        
        return x.flatten(2).transpose(1, 2)
 
 
# ============ Transfusion注意力 ============
class TransfusionAttention(nn.Module):
    """支持模态特定注意力的多头注意力"""
    def __init__(self, config: TransfusionConfig):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_dim // config.num_heads
        
        self.qkv = nn.Linear(config.hidden_dim, 3 * config.hidden_dim)
        self.proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        
    def forward(self, x: torch.Tensor, 
                image_boundaries: List[Tuple[int, int]]) -> torch.Tensor:
        """
        x: (batch, seq_len, hidden_dim)
        image_boundaries: [(start, end), ...] 每个图像在序列中的边界
        """
        B, N, D = x.shape
        
        # 计算QKV
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)  # 各 (B, N, num_heads, head_dim)
        
        # 重新排列为 (B, num_heads, N, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # 构建Transfusion注意力掩码
        attn_mask = self._build_transfusion_mask(N, image_boundaries)
        
        # 计算注意力分数
        scale = math.sqrt(self.head_dim)
        attn = (q @ k.transpose(-2, -1)) / scale
        
        # 应用掩码
        if attn_mask is not None:
            attn = attn.masked_fill(~attn_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        
        attn = F.softmax(attn, dim=-1)
        
        # 应用注意力
        out = attn @ v  # (B, num_heads, N, head_dim)
        out = out.transpose(1, 2).reshape(B, N, D)
        return self.proj(out)
    
    def _build_transfusion_mask(self, N: int, 
                               image_boundaries: List[Tuple[int, int]]) -> Optional[torch.Tensor]:
        """构建Transfusion特定的注意力掩码"""
        # 基础因果掩码
        causal_mask = torch.tril(torch.ones(N, N, dtype=torch.bool, device='cuda'))
        
        # 图像内部双向掩码
        for start, end in image_boundaries:
            causal_mask[start:end, start:end] = True
        
        return causal_mask
 
 
# ============ Transformer块 ============
class TransfusionBlock(nn.Module):
    """Transfusion Transformer块"""
    def __init__(self, config: TransfusionConfig):
        super().__init__()
        self.attention = TransfusionAttention(config)
        self.ffn = nn.Sequential(
            nn.Linear(config.hidden_dim, config.hidden_dim * 4),
            nn.SiLU(),
            nn.Linear(config.hidden_dim * 4, config.hidden_dim)
        )
        self.attention_norm = nn.RMSNorm(config.hidden_dim)
        self.ffn_norm = nn.RMSNorm(config.hidden_dim)
        
    def forward(self, x: torch.Tensor, image_boundaries: List[Tuple[int, int]]) -> torch.Tensor:
        x = x + self.attention(self.attention_norm(x), image_boundaries)
        x = x + self.ffn(self.ffn_norm(x))
        return x
 
 
# ============ 主模型 ============
class TransfusionModel(nn.Module):
    """完整的Transfusion模型"""
    def __init__(self, config: TransfusionConfig):
        super().__init__()
        self.config = config
        
        # 文本组件
        self.text_embed = TextEmbedding(config)
        
        # 图像组件
        self.patchify = PatchifyLayer(config)
        
        # Transformer主体
        self.blocks = nn.ModuleList([
            TransfusionBlock(config) for _ in range(config.num_layers)
        ])
        self.final_norm = nn.RMSNorm(config.hidden_dim)
        
        # 特殊token
        self.boi_token = nn.Parameter(torch.randn(config.hidden_dim))
        self.eoi_token = nn.Parameter(torch.randn(config.hidden_dim))
        
        # 扩散组件(用于推理)
        self._init_diffusion_weights()
        
    def _init_diffusion_weights(self):
        """初始化扩散相关权重"""
        # 在实际实现中,这里会初始化VAE和解码器
        pass
    
    def forward(self, 
                input_ids: Optional[torch.Tensor] = None,
                image_patches: Optional[torch.Tensor] = None,
                image_boundaries: Optional[List[List[Tuple[int, int]]]] = None,
                timesteps: Optional[torch.Tensor] = None,
                is_training: bool = True):
        """
        前向传播
        input_ids: (batch, text_len) 文本token ID
        image_patches: (batch, num_images, num_patches, latent_dim) 图像patch
        image_boundaries: 每个样本中图像在序列中的边界
        timesteps: (batch,) 扩散时间步(仅用于图像)
        """
        hidden_states = []
        boundaries = []  # 记录每个序列元素的模态边界
        
        # 处理文本
        if input_ids is not None:
            text_hidden = self.text_embed(input_ids)
            hidden_states.append(text_hidden)
            boundaries.extend([('text', i) for i in range(text_hidden.shape[1])])
        
        # 处理图像
        if image_patches is not None:
            B = image_patches.shape[0]
            for b in range(B):
                # 添加BOI token
                hidden_states.append(self.boi_token.unsqueeze(0))
                boundaries.append(('boi', len(hidden_states) - 1))
                
                # 添加图像patch
                patch_hidden = self.patchify.encode(
                    image_patches[b], timesteps[b]
                )
                start_pos = len(hidden_states)
                hidden_states.append(patch_hidden)
                boundaries.append(('image', (start_pos, start_pos + patch_hidden.shape[1])))
                
                # 添加EOI token
                hidden_states.append(self.eoi_token.unsqueeze(0))
                boundaries.append(('eoi', len(hidden_states) - 1))
        
        # 合并所有隐藏状态
        hidden_states = torch.cat(hidden_states, dim=1)
        
        # 通过Transformer块
        for block in self.blocks:
            hidden_states = block(hidden_states, image_boundaries)
        
        hidden_states = self.final_norm(hidden_states)
        
        # 分离文本和图像输出
        text_logits = None
        image_hidden = None
        
        return {
            'hidden_states': hidden_states,
            'text_logits': text_logits,
            'image_hidden': image_hidden
        }
 
 
# ============ 训练损失计算 ============
class TransfusionLoss(nn.Module):
    """Transfusion损失函数"""
    def __init__(self, config: TransfusionConfig):
        super().__init__()
        self.config = config
        self.lm_loss = nn.CrossEntropyLoss(ignore_index=-100)
        
    def compute_lm_loss(self, text_logits: torch.Tensor, 
                       target_ids: torch.Tensor) -> torch.Tensor:
        """计算语言建模损失"""
        # text_logits: (batch, seq_len, vocab_size)
        # target_ids: (batch, seq_len)
        shift_logits = text_logits[:, :-1, :].contiguous()
        shift_labels = target_ids[:, 1:].contiguous()
        
        return self.lm_loss(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1)
        )
    
    def compute_diffusion_loss(self, 
                              noise: torch.Tensor,
                              predicted_noise: torch.Tensor) -> torch.Tensor:
        """计算扩散损失"""
        return F.mse_loss(noise, predicted_noise)
    
    def forward(self, model_output, targets):
        """总体损失 = LM损失 + λ × DDPM损失"""
        lm_loss = self.compute_lm_loss(
            model_output['text_logits'], 
            targets['text_labels']
        )
        diff_loss = self.compute_diffusion_loss(
            targets['noise'],
            model_output['predicted_noise']
        )
        
        # λ默认设为1.0
        total_loss = lm_loss + self.config.lambda_ddpm * diff_loss
        return total_loss, {'lm_loss': lm_loss, 'diff_loss': diff_loss}
 
 
# ============ 推理/生成 ============
class TransfusionSampler:
    """Transfusion推理采样器"""
    def __init__(self, model: TransfusionModel, vae, config: TransfusionConfig):
        self.model = model
        self.vae = vae
        self.config = config
        
    @torch.no_grad()
    def generate(self, prompt: str, num_inference_steps: int = 50,
                 guidance_scale: float = 7.5) -> torch.Tensor:
        """
        生成图像
        prompt: 文本提示
        num_inference_steps: 扩散采样步数
        """
        self.model.eval()
        
        # 编码文本提示
        text_tokens = self.model.text_embed.tokenizer.encode(prompt)
        text_ids = torch.tensor([text_tokens], device='cuda')
        
        # 在LM模式下生成文本
        generated_tokens = self._sample_lm(text_ids)
        
        # 检查是否生成了BOI token
        boi_idx = self._find_token(generated_tokens, '<BOI>')
        
        if boi_idx is not None:
            # 切换到扩散模式
            image_tokens = self._sample_diffusion(
                context_tokens=generated_tokens[:boi_idx+1],
                num_steps=num_inference_steps,
                guidance_scale=guidance_scale
            )
            return image_tokens
        
        return generated_tokens
    
    def _sample_lm(self, tokens: torch.Tensor, 
                   max_new_tokens: int = 100) -> torch.Tensor:
        """从语言模型采样"""
        for _ in range(max_new_tokens):
            logits = self.model(input_ids=tokens)['text_logits']
            next_token = F.softmax(logits[:, -1, :] / 0.8, dim=-1).argmax()
            tokens = torch.cat([tokens, next_token.unsqueeze(0).unsqueeze(0)], dim=1)
            
            if next_token.item() == self.model.text_embed.tokenizer.eos_id:
                break
        
        return tokens
    
    def _sample_diffusion(self, context_tokens: torch.Tensor,
                         num_steps: int, guidance_scale: float):
        """从扩散模型采样"""
        # 从纯噪声开始
        B = 1
        num_patches = self.config.num_patches
        latent_dim = self.config.latent_dim
        
        x_t = torch.randn(B, num_patches, latent_dim, device='cuda')
        
        # 扩散去噪
        for t in reversed(range(num_steps)):
            # 预测噪声
            noise_pred = self.model(
                input_ids=context_tokens,
                image_patches=x_t.unsqueeze(1),
                timesteps=torch.tensor([t], device='cuda')
            )['predicted_noise']
            
            # 分类器自由引导
            # (简化实现,实际需要条件/无条件预测的对比)
            
            # 去噪步骤
            alpha_t = self._get_alpha(t)
            x_t = (x_t - (1 - alpha_t) * noise_pred) / alpha_t
        
        # VAE解码
        image = self.vae.decode(x_t)
        return image
    
    def _get_alpha(self, t: int) -> float:
        """获取余弦调度的alpha值"""
        # 简化的余弦调度
        return math.cos(t / self.config.num_diffusion_steps * math.pi / 2)

4. 实验结果与分析

4.1 与Chameleon(量化方法)的对比

Transfusion的核心实验是与Chameleon方法的控制变量对比。Chameleon将图像量化为离散token,然后使用标准语言模型处理。

计算效率对比

指标TransfusionChameleon效率提升
文本→图像生成
FID达到相同水平1x计算量~3x计算量3倍
CLIP达到相同水平1x计算量~3x计算量3倍
图像→文本生成
达到相同准确率21.8% FLOPs100% FLOPs4.6倍
文本→文本生成
困惑度达到相同水平50-60% FLOPs100% FLOPs~2倍

关键发现

当控制FLOPs时,Transfusion实现约2倍更低的FID分数。

这表明Transfusion在同等计算资源下能更有效地学习图像生成。

4.2 7B模型的结果

训练配置

参数
模型规模7B参数
U-Net上/下块0.27B参数
总参数~7.27B
训练token1T文本 + 1T图像patch
图像数据~692M图像及对应caption
数据混合50%文本 + 50%图像

基准测试结果

基准Transfusion 7B对比模型
GenEval超越DALL-E 2、SDXL图像生成指令遵循
文本基准与LLaMA 1相当证明统一建模不损害文本能力
图像质量与同规模扩散模型相当FID、CLIP分数

生成样例

论文展示了7B模型生成的多种图像,包括:

  • 物体组合(如”牛油果椅子”)
  • 场景生成(如”面包、苹果、刀”)
  • 风格化图像(如”分形艺术”)
  • 文本渲染(如”手握蓝色T恤”)

4.3 消融实验关键发现

双向注意力的重要性

实验表明,图像内部的的双向注意力对模型性能至关重要:

  • 移除图像内部双向注意力(替换为因果注意力)会损害文本→图像生成
  • 这证实了图像patch之间的全连接建模是必要的

U-Net块的效果

添加U-Net上/下块可以:

  • 将图像patch压缩至16个(原始可能需要1024个)
  • 服务成本潜在减少高达64倍
  • 性能损失相对较小

5. 技术细节与设计考量

5.1 为什么Transfusion优于量化方法

信息论角度

量化方法(如Chameleon)将连续图像压缩为离散token,存在信息瓶颈

其中 是码本的信息熵上限。

Transfusion直接保留连续表示,避免了这种信息损失。

表征空间角度

  • 量化:每个图像patch被强制映射到码本中的某个向量,引入重建误差
  • Transfusion:每个patch保留完整连续表示,可以精确重建

5.2 混合注意力掩码的实现

Transfusion的注意力掩码设计是其核心创新之一。考虑一个混合序列:

[文本1] [文本2] [BOI] [patch1] [patch2] [patch3] [EOI] [文本3]

其注意力掩码结构如下:

t1t2BOIp1p2p3EOIt3
t1
t2
BOI
p1
p2
p3
EOI
t3

其中:

  • = 可 attending(因果或图像内部双向)
  • = 不可 attending(违反因果或跨图像)

5.3 扩散时间步的融合

Transfusion将扩散时间步 融入模型的方式:

  1. 文本侧:时间步嵌入通过交叉注意力或特征相加传递给文本部分
  2. 图像侧:在patch编码时直接添加到每个patch向量

这种设计允许模型根据当前的噪声水平调整处理策略。


6. 与相关工作的比较

6.1 与纯扩散模型比较

方面纯扩散模型Transfusion
图像生成质量✓ 最优✓ 相当
文本生成能力✗ 无✓ 支持
架构统一性✗ 分离的LM+DM✓ 单一模型
模态交互✗ 外部拼接✓ 内部融合

6.2 与自回归图像模型比较

方面自回归图像模型(Chameleon等)Transfusion
表示方式离散token连续向量
信息保留✗ 量化损失✓ 无损失
计算效率✗ 长序列AR✓ 并行图像处理
与语言模型的兼容性✓ 统一token空间✓ 保留各自优势

6.3 与多模态融合方法比较

方面工具调用/嫁接方法Transfusion
模型数量多个独立模型单一统一模型
跨模态交互✗ 受限✓ 深度融合
训练复杂度中等
推理一致性✗ 模态间差异✓ 统一框架

7. 局限性与发展方向

7.1 当前局限性

  1. 训练复杂度:同时优化两个不同的目标函数需要仔细的平衡
  2. 推理延迟:扩散解码通常比自回归生成慢
  3. 超参数敏感 系数的选择对性能有显著影响

7.2 未来研究方向

  1. Flow Matching替代:用flow-matching替代扩散可能简化训练
  2. 更多模态扩展:将Transfusion扩展到音频、视频等其他连续模态
  3. 高效推理:开发更快的扩散解码策略
  4. 条件引导改进:探索更先进的CFG策略

7.3 扩展到更多模态的理论分析

Transfusion的核心思想——为不同数据模态使用最适合同态分布的损失函数——具有广泛的适用性。理论上,任何包含离散和连续成分的混合数据都可以采用类似方法:

音频模态

音频数据本质上是连续的,可以通过扩散或Flow Matching建模:

# 音频处理的Transfusion扩展
class AudioPatchify(nn.Module):
    """将音频转换为patch序列"""
    def __init__(self, sample_rate=16000, patch_duration=0.02):
        super().__init__()
        self.patch_size = int(sample_rate * patch_duration)  # 20ms patches
        self.hop_size = int(sample_rate * patch_duration // 2)  # 10ms hop
        
    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        audio: (batch, seq_len) 原始音频波形
        返回: (batch, num_patches, feat_dim) 频谱图patch
        """
        # STFT变换
        stft = torch.stft(
            audio, 
            n_fft=512, 
            hop_length=self.hop_size,
            window=torch.hann_window(512)
        )
        # 提取幅度和相位
        magnitude = torch.sqrt(stft[..., 0]**2 + stft[..., 1]**2)
        return magnitude.transpose(1, 2)  # (batch, time, freq)

视频模态

视频可以视为图像的时间扩展,可以直接在Transfusion框架中处理:

  1. 将视频帧视为图像序列
  2. 每个帧内部使用扩散损失
  3. 帧之间使用时间注意力(可加入额外的注意力机制)
  4. 文本caption作为全局条件

7.4 训练稳定性与调优技巧

Transfusion的训练涉及同时优化两个不同性质的目标函数,以下是关键的调优经验:

损失平衡

参数控制两个损失的相对重要性:

效果
文本能力优先,图像生成较弱
平衡训练(论文默认值)
图像质量优先,可能损害文本

学习率调度

  • 预热阶段:建议使用学习率预热,避免早期梯度不稳定
  • 余弦衰减:训练后期使用余弦衰减,平衡两个目标的收敛
  • 分别调度:可考虑对文本组件和图像组件使用不同的学习率

批处理策略

为保证训练稳定性,建议:

  • 每个batch中同时包含文本和图像样本
  • 或按固定比例(如1:1)交替训练不同模态
  • 避免模态不平衡导致的灾难性遗忘

8. 数学形式化与理论分析

8.1 统一损失函数的数学框架

Transfusion的训练目标可以形式化为:

表示文本数据分布, 表示图像数据分布。模型的参数 通过优化以下目标学习:

其中:

  • 是标准的语言建模损失
  • 是扩散去噪损失
  • 是平衡系数

损失梯度的期望

假设文本和图像数据按比例 采样,则参数的梯度为:

这意味着在每个训练步骤中,模型同时接收来自两个目标的梯度信号,实现了模态间的隐式对齐。

8.2 注意力机制的理论保证

因果注意力的表达能力

Transfusion中的因果注意力保证了:

  1. 无信息泄露:对于任意时间步 ,模型无法从 获取信息
  2. 自回归建模:可以正确计算

图像内部双向注意力的必要性

对于图像数据,由于不存在自然的因果顺序,使用双向注意力:

  1. 捕获空间依赖:每个patch可以 attending 到所有其他patch
  2. 建模全局特征:可以学习图像级别的全局结构
  3. 保留等变性:避免引入不自然的空间偏差

数学表达

设图像patch序列为 ,则双向注意力允许:

而因果注意力要求:


9. 实践建议与最佳实践

9.1 实现要点

数据预处理

class TransfusionDataset(torch.utils.data.Dataset):
    """Transfusion训练数据集"""
    def __init__(self, text_data, image_data, vae, tokenizer):
        self.text_data = text_data
        self.image_data = image_data
        self.vae = vae
        self.tokenizer = tokenizer
        
    def __getitem__(self, idx):
        # 随机选择文本或图像样本
        if random.random() < 0.5:
            # 文本样本
            text = self.text_data[idx]
            tokens = self.tokenizer.encode(text)
            return {
                'modality': 'text',
                'input_ids': tokens,
                'labels': tokens
            }
        else:
            # 图像样本
            image = self.image_data[idx]
            # VAE编码
            with torch.no_grad():
                latent = self.vae.encode(image)
            # 添加噪声
            t = random.randint(0, self.num_timesteps - 1)
            noise = torch.randn_like(latent)
            noisy_latent = self._add_noise(latent, noise, t)
            
            return {
                'modality': 'image',
                'noisy_latent': noisy_latent,
                'timestep': t,
                'noise': noise
            }

训练循环

def train_step(model, batch, optimizer, config):
    """单个训练步骤"""
    if batch['modality'] == 'text':
        # 语言建模损失
        logits = model(input_ids=batch['input_ids'])
        loss = F.cross_entropy(
            logits[..., :-1, :].reshape(-1, config.vocab_size),
            batch['labels'][..., 1:].reshape(-1)
        )
    else:
        # 扩散损失
        output = model(
            image_patches=batch['noisy_latent'],
            timesteps=batch['timestep']
        )
        loss = F.mse_loss(output['noise_pred'], batch['noise'])
    
    # 反向传播
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()
    
    return loss.item()

9.2 常见问题与解决方案

问题可能原因解决方案
文本困惑度上升 过大减小 或增加文本数据比例
图像FID退化训练不足或 过小增加训练步数或增大
模态崩溃学习率过高使用学习率预热和梯度裁剪
生成质量不稳定推理步数不足增加采样步数,使用DDIM调度

9.3 资源需求估算

模型规模GPU显存需求推荐GPU
125M~4GBRTX 3080
1B~16GBA100 40GB
7B~64GB8×A100 80GB

9.4 高级技巧与优化策略

梯度累积与大批量训练

当GPU显存受限时,可以使用梯度累积来模拟大批量训练:

def training_loop(model, dataloader, optimizer, config):
    model.train()
    optimizer.zero_grad()
    
    accumulated_loss = 0
    for step, batch in enumerate(dataloader):
        loss = compute_loss(model, batch, config)
        loss = loss / config.gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
        
        if (step + 1) % config.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            
            # 日志记录
            print(f"Step {step}: Loss = {accumulated_loss:.4f}")
            accumulated_loss = 0

混合精度训练

使用FP16或BF16混合精度训练可以显著加速并减少显存占用:

from torch.cuda.amp import autocast, GradScaler
 
scaler = GradScaler()
 
for batch in dataloader:
    with autocast():
        output = model(batch)
        loss = compute_loss(output, batch)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

推理加速

扩散模型的推理通常较慢,以下是几种加速策略:

  1. DDIM调度:使用DDIM(Denoising Diffusion Implicit Models)可以大幅减少采样步数
  2. LCM(Latent Consistency Models):学习一致性模型,只需2-4步采样
  3. 蒸馏:通过知识蒸馏压缩扩散模型
def ddim_sample(model, noisy_latent, guidance_scale=7.5):
    """DDIM采样加速"""
    num_steps = 20  # 远少于标准的1000步
    eta = 0.0  # DDIM超参数
    
    for t in reversed(range(num_steps)):
        # 预测噪声
        noise_pred = model(noisy_latent, timestep=t)
        
        # 计算DDIM更新
        alpha_t = get_alpha_cosine(t)
        pred_x0 = (noisy_latent - (1-alpha_t)**0.5 * noise_pred) / alpha_t**0.5
        
        # 选择性应用DDIM
        if t > 0:
            noise = torch.randn_like(noisy_latent)
            noisy_latent = alpha_t**0.5 * pred_x0 + (1-alpha_t)**0.5 * noise
        else:
            noisy_latent = pred_x0
    
    return noisy_latent

10. 总结

Transfusion提出了一种优雅的统一多模态建模方案,其核心贡献包括:

  1. 方法创新:首次证明可以在单一Transformer中结合语言建模和扩散训练
  2. 效率提升:相比量化方法,在相同计算资源下实现更好的扩展性
  3. 性能验证:7B模型在图像和文本生成上均达到同规模最优模型水平
  4. 架构启示:为真正的统一多模态基础模型开辟了新道路

Transfusion的成功表明,离散的文本和连续的图像可以在统一的框架下被有效建模,无需牺牲任一模态的优势。这为构建更强大的多模态AI系统提供了重要的理论基础和实践指导。

核心要点回顾

  • 统一损失
  • 混合注意力:序列级因果 + 图像内部双向
  • 模态特定组件:轻量级编码/解码层(线性或U-Net)
  • 推理切换:根据BOI token在LM和扩散模式间切换
  • 扩展性:相比量化方法,在控制FLOPs时FID降低约2倍

启示与展望

Transfusion的工作为多模态AI研究指明了几个重要方向:

  1. 任务适配的损失函数:不一定要求所有模态使用相同的损失函数
  2. 连续表示的优势:在可能的情况下,保留连续表示避免量化损失
  3. 统一架构的可能:单一模型可以有效处理多种数据类型
  4. 跨模态对齐:通过联合训练实现模态间的隐式对齐

随着计算资源的增长和算法的发展,我们可以期待看到更多基于Transfusion思想的多模态模型,可能整合文本、图像、音频、视频等多种模态,实现真正通用的多模态AI系统。

与其他统一多模态方法的对比总结

方法架构表示方式损失函数代表工作
Transfusion单一Transformer连续(图像)LM + DDPM本文
Chameleon单一Transformer离散(全部量化)仅LMMeta
Show-o单一Transformer混合(量化为离散)LM + Diffusion
EmuViT + LMCLIP特征LM百度
Flamingo独立LM + 视觉编码器连续LMDeepMind

从表中可以看出,Transfusion是唯一一种在单一模型中同时保留文本离散表示和图像连续表示的工作,同时使用各自最优的损失函数进行训练。

关键创新点总结

  1. 模态原生设计:不是强制统一表示,而是保留各模态的自然形式
  2. 任务适配损失:为每个模态选择最适合的损失函数
  3. 注意力机制创新:混合因果和双向注意力以适应不同模态
  4. 推理模式切换:通过特殊token实现不同生成模式的平滑切换
  5. 规模化验证:首次在7B规模验证了统一多模态模型的可行性

参考


相关主题

Footnotes

  1. Zhou, C., Yu, L., Babu, A., Tirumala, K., Yasunaga, M., Shamis, L., Kahn, J., Ma, X., Zettlemoyer, L., & Levy, O. (2025). Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model. International Conference on Learning Representations (ICLR 2025). https://arxiv.org/abs/2408.11039 2 3