Show-o2 统一多模态模型

概述

Show-o2 是由多机构联合提出的统一多模态模型,发表在 NeurIPS 2025。该模型的核心创新是:结合自回归(AR)建模流匹配(Flow Matching),实现同时支持多模态理解和生成的高效统一架构。


1. 背景与动机

1.1 现有方法的局限

纯 AR 方法的局限(如 Emu3、Chameleon):

  • 生成质量不如 Diffusion/Flow 模型
  • 长序列生成效率低
  • 对连续输出(图像)建模困难

纯 Diffusion 方法的局限(如 DALL-E 3、Stable Diffusion):

  • 理解和生成需要分离的组件
  • 推理速度慢(需要多步采样)
  • 难以处理离散 token(如文本)

1.2 Show-o2 的解决方案

核心洞察:AR 和 Flow 各有优势,应该扬长避短

任务最适合格式原因
文本生成自回归离散 token,序列依赖
图像理解自回归编码信息到离散表示
图像生成流匹配连续空间,高质量

Show-o2 策略

  • 理解:AR 编码所有模态
  • 生成文本:AR 解码
  • 生成图像:Flow 解码

2. 架构设计

2.1 整体架构

输入
  │
  ├── 文本 ──→ Tokenizer ──→ 文本 Token
  │
  ├── 图像 ──→ Visual Encoder ──→ 连续特征
  │
  └── 视频 ──→ Video Encoder ──→ 连续特征
  │
  ▼
┌──────────────────────────────────────────────────────────┐
│                    统一 Transformer                        │
│  ┌─────────────────────────────────────────────────┐    │
│  │          理解模式 (自回归编码)                    │    │
│  │  处理文本/图像/视频输入,生成统一的上下文表示       │    │
│  └─────────────────────────────────────────────────┘    │
│                          │                              │
│                          ▼                              │
│  ┌─────────────────────────────────────────────────┐    │
│  │          生成模式 (双路径解码)                    │    │
│  │  文本 → AR解码器                               │    │
│  │  图像 → Flow解码器 (Continuous)                │    │
│  └─────────────────────────────────────────────────┘    │
└──────────────────────────────────────────────────────────┘
  │
  ▼
输出
  │
  ├── 文本 Token ──→ Detokenizer ──→ 文本
  │
  └── 图像特征 ──→ VAE Decoder ──→ 图像

2.2 理解编码器

基于 Phi-3.5-mini 的语言模型骨干

配置
隐藏维度3072
层数32
注意力头32
FFN 维度-
上下文长度8192

多模态理解头

class MultimodalUnderstandingHead(nn.Module):
    def __init__(self, vision_dim, lang_dim):
        super().__init__()
        # 视觉投影
        self.vision_proj = nn.Sequential(
            nn.Linear(vision_dim, lang_dim),
            nn.GELU(),
            nn.Linear(lang_dim, lang_dim)
        )
        
        # 视频时序建模
        self.temporal_agg = TemporalAttention()
        
        # 模态融合
        self.fusion = CrossModalAttention()
    
    def forward(self, vision_features, lang_features):
        # 投影视觉特征
        vision_features = self.vision_proj(vision_features)
        
        # 时序聚合(视频)
        vision_features = self.temporal_agg(vision_features)
        
        # 跨模态融合
        fused = self.fusion(vision_features, lang_features)
        
        return fused

2.3 AR 解码器(文本生成)

文本生成使用标准 Next-Token Prediction

class ARDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transformer = TransformerDecoder(config)
        self.lm_head = nn.Linear(config.hidden, config.vocab_size)
    
    def forward(self, context, text_tokens):
        hidden = self.transformer(context, text_tokens)
        logits = self.lm_head(hidden)
        return logits
    
    @torch.no_grad()
    def generate(self, context, max_len=100):
        tokens = []
        for _ in range(max_len):
            logits = self.forward(context, tokens)
            next_token = logits[-1].argmax()
            tokens.append(next_token)
            if next_token == EOS_TOKEN:
                break
        return tokens

2.4 Flow 解码器(图像生成)

核心创新:使用 Conditional Flow Matching 生成图像。

Flow Matching 回顾

给定数据分布 和噪声分布 ,Flow Matching 定义一个插值路径:

其中 是数据点 的均值。

速度场

Show-o2 的 Flow 解码器

class FlowDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        # U-Net 风格的 denoiser
        self.denoiser = UNetDiffusion(config)
        
        # 条件注入
        self.context_proj = nn.Linear(config.hidden, config.hidden)
        
        # 时间步嵌入
        self.time_embed = SinusoidalPosEmb(config.time_dim)
    
    def forward(self, z_t, t, context):
        """训练:预测速度场"""
        # 时间嵌入
        t_embed = self.time_embed(t)
        
        # 条件投影
        context_embed = self.context_proj(context)
        
        # 预测速度
        v = self.denoiser(z_t, t_embed, context_embed)
        
        return v
    
    @torch.no_grad()
    def sample(self, context, num_steps=50):
        """推理:从噪声生成图像"""
        # 从纯噪声开始
        z = torch.randn(B, C, H, W, device=context.device)
        
        # 离散化 ODE 求解
        dt = 1.0 / num_steps
        for t in reversed(range(num_steps)):
            t_batch = torch.full((B,), t / num_steps, device=context.device)
            v = self.forward(z, t_batch, context)
            z = z - dt * v
        
        return z

3. 统一训练目标

3.1 多任务训练框架

Show-o2 同时训练理解和生成能力:

def show_o2_loss(model, batch):
    total_loss = 0.0
    
    # 1. 理解任务:图像描述
    if 'image' in batch and 'text' in batch:
        understanding_loss = model.understanding_loss(
            batch['image'], batch['text']
        )
        total_loss += 0.5 * understanding_loss
    
    # 2. 理解任务:VQA
    if 'image' in batch and 'question' in batch:
        vqa_loss = model.understanding_loss(
            batch['image'], batch['question']
        )
        total_loss += 0.3 * vqa_loss
    
    # 3. 生成任务:文本生成
    if 'prompt' in batch and 'target_text' in batch:
        text_gen_loss = model.text_generation_loss(
            batch['prompt'], batch['target_text']
        )
        total_loss += 0.4 * text_gen_loss
    
    # 4. 生成任务:图像生成
    if 'prompt' in batch and 'target_image' in batch:
        image_gen_loss = model.image_generation_loss(
            batch['prompt'], batch['target_image']
        )
        total_loss += 0.6 * image_gen_loss
    
    return total_loss

3.2 条件图像生成

Show-o2 支持基于文本条件的图像生成

def conditional_image_generation(model, text_prompt):
    # 1. 编码文本条件
    text_tokens = model.tokenizer(text_prompt)
    text_features = model.lang_encoder(text_tokens)
    
    # 2. Flow 解码
    z_T = torch.randn(1, 256, 32, 32)  # 初始噪声
    z_0 = model.flow_decoder.sample(z_T, context=text_features)
    
    # 3. VAE 解码
    image = model.vae.decoder(z_0)
    
    return image

3.3 交错生成

Show-o2 支持文本和图像的交错生成

输入: "给我画一只猫,"
       ↓
输出: [图像 Token] + "它在草地上玩耍"
            ↓
       显示图像 + 继续文本
def interleaved_generation(model, prompt):
    tokens = model.tokenizer(prompt)
    features = model.lang_encoder(tokens)
    
    outputs = []
    mode = "image"  # 初始模式
    
    for step in range(max_steps):
        if mode == "image":
            # 生成图像
            image_tokens = model.flow_decoder.sample(context=features)
            outputs.append(('image', image_tokens))
            mode = "text"  # 切换到文本模式
            
        else:
            # 生成文本
            next_token = model.lang_decoder.generate(features)
            outputs.append(('text', next_token))
            
            if next_token == IMAGE_TOKEN:
                mode = "image"  # 遇到图像 token,切换模式
            
            if next_token == EOS_TOKEN:
                break
    
    return outputs

4. 训练策略

4.1 数据混合

数据类型比例用途
图像-文本对40%理解和生成
纯文本30%语言能力
图像描述15%图像理解
VQA 数据10%视觉问答
交错数据5%复杂推理

4.2 两阶段训练

阶段1: 预训练
├── 目标: 学习通用表示
├── 数据: 大规模图文数据
├── 训练目标: 多任务联合
└── 步数: 100K

阶段2: 微调
├── 目标: 提升特定能力
├── 数据: 精标注数据
├── 方法: SFT + DPO
└── 步数: 10K

4.3 训练超参数

参数
批量大小512 (图像) / 2048 (文本)
学习率1e-4
预热1%
权重衰减0.1
梯度裁剪1.0
图像生成步数50
Flow 学习率5e-5

5. 能力评估

5.1 理解任务

基准Show-o2Emu3ChameleonLLaVA-1.5
VQAv284.7%85.2%82.1%80.2%
GQA76.8%78.3%74.5%72.1%
VisWiz62.3%61.8%58.9%55.7%
TextVQA58.9%60.1%56.3%52.4%

5.2 生成任务

基准Show-o2SDXLDALL-E 3Emu3
FID ↓7.88.56.28.7
CLIP Score ↑0.870.820.890.85
人力评估4.1/54.0/54.5/54.2/5
生成速度2.1s3.5s5.2s1.8s

5.3 统一能力

任务Show-o2专用模型
图文推理82.3%85.1%
图像生成7.8 FID6.2 FID
交错生成
端到端优化

6. 技术细节

6.1 3D 因果 VAE

Show-o2 使用 3D 因果 VAE 进行视觉表示:

class CausalVideoVAE(nn.Module):
    """3D 因果 VAE 用于视频/图像"""
    def __init__(self):
        super().__init__()
        
        # 编码器(2D 或 3D)
        self.encoder = Encoder3D()
        
        # 量化器
        self.quantizer = VectorQuantizer(
            vocab_size=8192,
            dim=256,
            commitment_cost=0.25
        )
        
        # 解码器
        self.decoder = Decoder3D()
    
    def encode(self, x):
        """编码为离散表示"""
        z = self.encoder(x)
        z_quantized, indices = self.quantizer(z)
        return z_quantized, indices
    
    def decode(self, z):
        """从连续表示解码"""
        return self.decoder(z)
    
    def forward(self, x):
        z = self.encoder(x)
        z_quantized, indices = self.quantizer(z)
        x_recon = self.decoder(z_quantized)
        return x_recon, indices

6.2 模态切换机制

创新:显式的模态切换 token

SPECIAL_TOKENS = {
    'BOS_IMAGE': 200000,
    'EOS_IMAGE': 200001,
    'BOS_VIDEO': 200002,
    'EOS_VIDEO': 200003,
    'IMAGE_START': 200004,
    'IMAGE_END': 200005,
}
 
def modality_aware_forward(model, input_ids, modal_types):
    """模态感知的自注意力"""
    
    # 获取嵌入
    embeddings = model.embed_tokens(input_ids)
    
    # 添加模态嵌入
    for i, modal in enumerate(modal_types):
        if modal == 'image':
            embeddings[i] += model.image_modality_embed
        elif modal == 'video':
            embeddings[i] += model.video_modality_embed
        else:
            embeddings[i] += model.text_modality_embed
    
    # 跨模态注意力
    hidden = model.transformer(embeddings)
    
    return hidden

6.3 Flow Matching 细节

Conditional Flow Matching 损失

def flow_matching_loss(model, x_1, context):
    """Flow Matching 训练损失"""
    batch_size = x_1.shape[0]
    
    # 采样时间步
    t = torch.rand(batch_size, device=x_1.device)
    
    # 采样噪声
    x_0 = torch.randn_like(x_1)
    
    # 插值
    t_expanded = t.view(batch_size, 1, 1, 1)
    x_t = t_expanded * x_1 + (1 - t_expanded) * x_0
    
    # 目标速度
    v_target = x_1 - x_0
    
    # 预测速度
    v_pred = model(x_t, t, context)
    
    # MSE 损失
    loss = F.mse_loss(v_pred, v_target)
    
    return loss

7. 与其他方法的对比

7.1 架构对比

方法理解方式文本生成图像生成统一程度
Show-o2ARARFlow完全统一
Emu3ARARAR完全统一
ChameleonARARAR完全统一
DALL-E 3CLIPGPT-4Diffusion部分统一
LLaVAViTLLaMASD分离

7.2 生成质量对比

Show-o2 vs 其他方法(FID 分数,越低越好):

        Show-o2  Emu3  Chameleon  SDXL  DALL-E 3
ImageNet  ████    ████   ████     ███   ██
COCO      ████    ████   ████     ███   ██
LAION     ████    ████   ████     ███   ██

████ = 8-9   ███ = 7-8   ██ = 6-7

7.3 推理效率对比

方法图像生成时间文本生成速度总效率
Show-o22.1s50 tok/s
Emu31.8s48 tok/s
SDXL3.5s-
DALL-E 35.2s-

8. 实践应用

8.1 图像生成示例

from show_o2 import ShowO2Model
 
model = ShowO2Model.from_pretrained("show-o2-base")
 
# 文本到图像
image = model.generate_image(
    prompt="A cat sitting on a green grass field",
    num_steps=50,
    guidance_scale=7.5
)
 
# 多图像生成
images = model.generate_image(
    prompt=["A cat", "A dog", "A bird"],
    batch_size=3
)

8.2 图文推理示例

# 图像问答
answer = model.answer_question(
    image="path/to/image.jpg",
    question="What is the cat doing?"
)
 
# 图文生成
result = model.generate(
    prompt="A cat on a mat. Then describe what you see."
)
# 输出: (图像, "A fluffy orange cat is sitting on a red mat.")

8.3 交错生成示例

# 交错生成
outputs = model.generate_interleaved(
    prompt="First, draw a mountain. Then write a poem about it.",
    max_image_tokens=256,
    max_text_tokens=100
)
 
# outputs: [
#   ('image', generated_image_1),
#   ('text', "Mountain peaks pierce the clouds high above..."),
# ]

9. 总结与展望

9.1 主要贡献

  1. 混合架构创新:首次结合 AR 和 Flow Matching
  2. 统一训练框架:多任务联合训练
  3. 高效推理:Flow 解码比 Diffusion 快
  4. 完整开源:模型和代码公开

9.2 关键洞察

  • 混合生成是平衡质量和效率的好方法
  • 统一架构简化了训练和部署
  • 条件图像生成可以无缝集成到 LLM

9.3 局限性

局限性影响
Flow 解码器训练复杂训练成本
对话能力不如纯 LLM多轮对话
视频生成未充分探索视频任务

9.4 未来方向

  1. 视频生成:扩展 Flow 解码器到视频
  2. 更高分辨率:改进 VAE 支持 1024+
  3. 更长上下文:扩展到视频级别
  4. 更强推理:提升复杂推理能力

参考资料