Emu3 原生多模态统一模型

概述

Emu3 是智谱 AI 提出的原生多模态统一模型,发表在 arXiv:2510.26583。该模型的核心思想是:使用统一的 next-token 预测框架同时处理视觉理解和生成任务,打破了传统方法中理解和生成分离的架构设计。


1. 背景与动机

1.1 传统多模态模型的局限

传统多模态大模型通常采用编码器-解码器分离架构

方法视觉编码器语言解码器问题
GPT-4VViTGPT-4理解与生成分离
LLaVACLIP ViTLLaMA架构不统一
IDEFICSFLAVALLaMA预训练目标不一致

核心问题

  1. 理解和生成使用不同的架构
  2. 预训练目标不一致
  3. 难以端到端优化

1.2 原生统一的优势

原生统一模型的优势:

  • 架构统一:单一模型处理所有任务
  • 目标统一:统一的 next-token 预测目标
  • 优化统一:端到端梯度优化
  • 部署简单:一个模型服务多种需求

2. Emu3 架构设计

2.1 整体架构

Emu3 采用 Next-Token Prediction 统一框架:

输入
  │
  ├── 文本 ──→ Tokenizer ──→ 文本 Token
  │
  └── 图像 ──→ Visual Tokenizer ──→ 视觉 Token
  │
  ▼
┌──────────────────────────────────────────────┐
│              Transformer 解码器                 │
│  (统一的自回归生成,预测下一个 Token)           │
└──────────────────────────────────────────────┘
  │
  ▼
输出
  │
  ├── 文本 Token ──→ Detokenizer ──→ 文本
  │
  └── 视觉 Token ──→ Visual Generator ──→ 图像

2.2 视觉 Tokenizer

核心创新:使用离散视觉 Tokenizer 将图像转换为 Token 序列。

Tokenizer 架构

图像
  │
  ▼
┌─────────────────────────────────┐
│      视觉编码器 (ViT-based)      │
│   输出: 特征图 H×W×D            │
└─────────────────────────────────┘
  │
  ▼
┌─────────────────────────────────┐
│      Vector Quantization (VQ)    │
│   Codebook: K 个离散 Token       │
│   输出: H×W 的 Token IDs        │
└─────────────────────────────────┘
  │
  ▼
视觉 Token 序列

关键设计

组件设计选择
编码器Swin Transformer (冻结)
量化器VQ-GAN 风格
Codebook 大小32,768 (16-bit)
空间分辨率16×16 (256 tokens)

2.3 统一 Transformer

Emu3 使用 Llama-style Transformer 作为骨干:

配置
隐藏维度4096
层数40
注意力头32
QKV 维度128
FFN 维度13,712
上下文长度8,192
词汇表大小200,064 (文本+视觉)

2.4 训练目标

统一的 Next-Token 预测损失

其中 可以是任意模态的 Token(文本或视觉)。

关键洞察:文本 Token 和视觉 Token 在同一个词汇表中,可以无缝交替生成。


3. 训练策略

3.1 预训练数据

数据类型规模来源
图像-文本对10B互联网爬取
视频-文本对1B视频平台
纯文本1T书籍、网页
交错数据100M图文混合网页

3.2 多阶段训练

阶段1: 视觉Token学习
├── 目标: 学习好的视觉Tokenizer
├── 数据: 纯图像
└── 方法: VQ重构 + GAN损失

阶段2: 联合预训练
├── 目标: 统一架构学习
├── 数据: 图像-文本对、视频-文本对、纯文本
└── 方法: Next-token预测

阶段3: 指令微调
├── 目标: 提升指令遵循能力
├── 数据: SFT数据集
└── 方法: 监督微调

阶段4: RLHF对齐
├── 目标: 与人类偏好对齐
├── 数据: 人类偏好数据
└── 方法: PPO/DPO

3.3 训练细节

超参数
批量大小4,096 (tokens)
学习率3e-4
预热步数2,000
余弦衰减
权重衰减0.1
梯度裁剪1.0
训练步数100,000+

4. 能力评估

4.1 视觉理解基准

基准Emu3GPT-4VGemini ProQwen-VL
VQAv285.2%92.6%86.4%84.4%
GQA78.3%85.2%-77.9%
SEED-Bench76.8%84.1%79.4%73.7%
MMBench79.5%86.1%78.6%76.8%
MathVista58.2%72.4%65.0%54.1%

4.2 图像生成基准

基准Emu3DALL-E 3Stable DiffusionSDXL
FID (COCO)8.76.312.49.5
CLIP Score0.850.890.780.82
人力评估4.2/54.5/53.8/54.0/5

4.3 多模态推理

任务Emu3GPT-4VGemini Pro
图表理解82.3%88.7%85.2%
视频问答76.8%82.1%78.9%
多图像推理74.5%81.3%-
跨模态数学71.2%79.8%73.5%

5. 技术创新

5.1 统一 Tokenizer 设计

关键创新:文本和视觉 Token 共享词汇表。

class UnifiedTokenizer:
    def __init__(self):
        # 文本Tokenizer
        self.text_tokenizer = SentencePieceTokenizer()
        self.text_vocab_size = 128000
        
        # 视觉Tokenizer
        self.visual_tokenizer = VQTokenizer()
        self.visual_vocab_size = 32768
        
        # 合并词汇表
        self.total_vocab_size = self.text_vocab_size + self.visual_vocab_size
    
    def encode(self, inputs):
        if isinstance(inputs, str):
            return self.text_tokenizer.encode(inputs)
        elif is_image(inputs):
            return self.visual_tokenizer.encode(inputs)
        else:
            raise ValueError("Unsupported input type")
    
    def decode(self, token_ids):
        if token_ids < self.text_vocab_size:
            return self.text_tokenizer.decode([token_ids])
        else:
            visual_id = token_ids - self.text_vocab_size
            return self.visual_tokenizer.decode([visual_id])

5.2 模态感知注意力

创新:模态感知的注意力掩码

class ModalityAwareAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = MultiHeadAttention()
        self.modality_embedding = nn.Embedding(3, hidden_dim)  # text, image, video
    
    def forward(self, x, modality_type):
        # 添加模态嵌入
        mod_emb = self.modality_embedding(modality_type)
        x = x + mod_emb
        
        # 模态感知的注意力
        attention_mask = self.compute_modality_mask(modality_type)
        return self.attention(x, attention_mask)
    
    def compute_modality_mask(self, modality_type):
        # 视觉 Token 之间可以完全交互
        # 文本 Token 之间可以完全交互
        # 视觉和文本 Token 根据任务动态交互
        ...

5.3 渐进式训练

创新:从易到难的课程学习

Step 1: 纯文本生成 (1T tokens)
   ↓
Step 2: 图像描述 (1B samples)
   ↓
Step 3: 图像生成 (500M samples)
   ↓
Step 4: 图文交错生成 (100M samples)
   ↓
Step 5: 视频理解/生成 (1B frames)

6. 与其他工作的对比

6.1 Show-o2 对比

方面Emu3Show-o2
统一方式AR (自回归)AR + Flow
视觉表示VQ Token连续特征
生成质量中等较高
推理速度较慢
架构复杂度简单复杂

6.2 Chameleon 对比

方面Emu3Chameleon
视觉 TokenVQ (离散)VQ (离散)
模型规模40B34B
训练数据10B 图像10B 图像
开源部分
性能相当相当

6.3 架构选择分析

为什么选择纯 AR 框架?

优点缺点
推理简单生成质量可能不如 Diffusion
部署容易长序列生成效率低
理论基础成熟-
与 LLM 兼容-

7. PyTorch 实现

7.1 简化实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaConfig
 
 
class Emu3Config:
    """Emu3 配置"""
    def __init__(self):
        # Transformer 配置
        self.hidden_size = 4096
        self.num_hidden_layers = 40
        self.num_attention_heads = 32
        self.intermediate_size = 13712
        
        # Tokenizer 配置
        self.text_vocab_size = 128000
        self.visual_vocab_size = 32768
        self.total_vocab_size = self.text_vocab_size + self.visual_vocab_size
        
        # 视觉配置
        self.visual_token_size = 16  # 16x16 visual tokens
        
        self.max_position_embeddings = 8192
 
 
class VisualTokenizer(nn.Module):
    """视觉 Tokenizer"""
    def __init__(self, config):
        super().__init__()
        # 编码器
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 128, 7, 2, 3),
            nn.ReLU(),
            *self._make_encoder_blocks(128, 256),
        )
        
        # Vector Quantization
        self.codebook = nn.Parameter(
            torch.randn(32768, 512)  # 32K tokens, 512-dim
        )
        
        # 投影层
        self.projection = nn.Linear(512, config.hidden_size)
    
    def _make_encoder_blocks(self, in_ch, out_ch):
        blocks = []
        for _ in range(4):
            blocks.extend([
                nn.Conv2d(in_ch, out_ch, 3, 2, 1),
                nn.ReLU(),
            ])
            in_ch = out_ch
        return blocks
    
    def forward(self, images):
        # 编码
        features = self.encoder(images)  # B, C, H, W
        
        # 量化
        B, C, H, W = features.shape
        features = features.permute(0, 2, 3, 1).reshape(B * H * W, C)
        
        # 最近邻查询
        distances = torch.cdist(features, self.codebook)
        indices = distances.argmin(dim=-1)
        
        # 获取量化特征
        quantized = self.codebook[indices]  # (B*H*W), 512
        
        # 重塑
        quantized = quantized.reshape(B, H, W, -1).permute(0, 3, 1, 2)
        
        # 投影到隐藏维度
        B, C, H, W = quantized.shape
        quantized = self.projection(quantized.permute(0, 2, 3, 1))  # B, H, W, hidden
        quantized = quantized.permute(0, 3, 1, 2)  # B, hidden, H, W
        
        return {
            'tokens': indices.reshape(B, H, W),  # 离散 token IDs
            'quantized': quantized,  # 连续特征
        }
 
 
class Emu3Model(nn.Module):
    """Emu3 主模型"""
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Token embeddings
        self.token_embeddings = nn.Embedding(
            config.total_vocab_size, 
            config.hidden_size
        )
        
        # 视觉 tokenizer
        self.visual_tokenizer = VisualTokenizer(config)
        
        # Transformer backbone
        self.transformer = LlamaModel(config)
        
        # Output projection
        self.lm_head = nn.Linear(
            config.hidden_size, 
            config.total_vocab_size, 
            bias=False
        )
    
    def encode_image(self, images):
        """编码图像为 tokens"""
        visual_output = self.visual_tokenizer(images)
        return visual_output['tokens']
    
    def forward(self, input_ids, pixel_values=None, labels=None):
        # 获取 embeddings
        if pixel_values is not None:
            # 替换视觉部分的 embeddings
            visual_tokens = self.encode_image(pixel_values)
            visual_embeds = self.token_embeddings(
                visual_tokens + self.config.text_vocab_size
            )
            # 混合到 input_ids 中
            ...
        
        hidden_states = self.token_embeddings(input_ids)
        
        # Transformer forward
        outputs = self.transformer(hidden_states=hidden_states)
        
        # Output projection
        logits = self.lm_head(outputs.last_hidden_state)
        
        return {'logits': logits}
    
    def generate(self, input_ids, max_length=100, pixel_values=None):
        """自回归生成"""
        self.eval()
        generated = input_ids.clone()
        
        for _ in range(max_length):
            outputs = self.forward(generated, pixel_values)
            next_token = outputs['logits'][:, -1, :].argmax(dim=-1)
            generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
            
            if next_token.item() == EOS_TOKEN:
                break
        
        return generated

7.2 训练脚本

def train_emu3():
    from torch.utils.data import DataLoader
    from emu3_model import Emu3Model, Emu3Config
    
    # 初始化模型
    config = Emu3Config()
    model = Emu3Model(config)
    
    # 优化器
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=3e-4, 
        weight_decay=0.1
    )
    
    # 训练循环
    for epoch in range(num_epochs):
        for batch in dataloader:
            text_ids = batch['text_ids']
            pixel_values = batch['images']
            
            # 前向传播
            outputs = model(
                input_ids=text_ids, 
                pixel_values=pixel_values
            )
            
            # 计算损失
            logits = outputs['logits']
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1)
            )
            
            # 反向传播
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

8. 局限性与未来方向

8.1 当前局限性

局限性影响
生成质量不如专用 Diffusion 模型图像生成任务
长序列推理效率低视频生成任务
视觉 Tokenizer 训练复杂训练成本
词汇表管理复杂工程实现

8.2 未来改进方向

  1. 混合生成:结合 AR 和 Flow/Diffusion
  2. 更长上下文:扩展到视频级别
  3. 更好的视觉 Tokenizer:更高分辨率、更精细表示
  4. 多模态指令遵循:更强的复杂推理能力

9. 总结

9.1 主要贡献

  1. 提出原生统一架构:使用 Next-token prediction 统一多模态理解和生成
  2. 离散视觉 Token:基于 VQ 的视觉表示,简化与文本的融合
  3. 端到端优化:统一的预训练和微调流程
  4. 开源:模型权重和代码公开

9.2 关键洞察

  • 统一框架是多模态 AI 的重要方向
  • 离散 Token 可以有效桥接不同模态
  • 端到端 训练比分离组件更有效

9.3 展望

Emu3 代表了多模态 AI 的一个重要方向。随着训练数据和模型规模的增长,原生统一模型有望在更多任务上取得突破。


参考资料