Emu3 原生多模态统一模型
概述
Emu3 是智谱 AI 提出的原生多模态统一模型,发表在 arXiv:2510.26583。该模型的核心思想是:使用统一的 next-token 预测框架同时处理视觉理解和生成任务,打破了传统方法中理解和生成分离的架构设计。
1. 背景与动机
1.1 传统多模态模型的局限
传统多模态大模型通常采用编码器-解码器分离架构:
| 方法 | 视觉编码器 | 语言解码器 | 问题 |
|---|---|---|---|
| GPT-4V | ViT | GPT-4 | 理解与生成分离 |
| LLaVA | CLIP ViT | LLaMA | 架构不统一 |
| IDEFICS | FLAVA | LLaMA | 预训练目标不一致 |
核心问题:
- 理解和生成使用不同的架构
- 预训练目标不一致
- 难以端到端优化
1.2 原生统一的优势
原生统一模型的优势:
- 架构统一:单一模型处理所有任务
- 目标统一:统一的 next-token 预测目标
- 优化统一:端到端梯度优化
- 部署简单:一个模型服务多种需求
2. Emu3 架构设计
2.1 整体架构
Emu3 采用 Next-Token Prediction 统一框架:
输入
│
├── 文本 ──→ Tokenizer ──→ 文本 Token
│
└── 图像 ──→ Visual Tokenizer ──→ 视觉 Token
│
▼
┌──────────────────────────────────────────────┐
│ Transformer 解码器 │
│ (统一的自回归生成,预测下一个 Token) │
└──────────────────────────────────────────────┘
│
▼
输出
│
├── 文本 Token ──→ Detokenizer ──→ 文本
│
└── 视觉 Token ──→ Visual Generator ──→ 图像
2.2 视觉 Tokenizer
核心创新:使用离散视觉 Tokenizer 将图像转换为 Token 序列。
Tokenizer 架构:
图像
│
▼
┌─────────────────────────────────┐
│ 视觉编码器 (ViT-based) │
│ 输出: 特征图 H×W×D │
└─────────────────────────────────┘
│
▼
┌─────────────────────────────────┐
│ Vector Quantization (VQ) │
│ Codebook: K 个离散 Token │
│ 输出: H×W 的 Token IDs │
└─────────────────────────────────┘
│
▼
视觉 Token 序列
关键设计:
| 组件 | 设计选择 |
|---|---|
| 编码器 | Swin Transformer (冻结) |
| 量化器 | VQ-GAN 风格 |
| Codebook 大小 | 32,768 (16-bit) |
| 空间分辨率 | 16×16 (256 tokens) |
2.3 统一 Transformer
Emu3 使用 Llama-style Transformer 作为骨干:
| 配置 | 值 |
|---|---|
| 隐藏维度 | 4096 |
| 层数 | 40 |
| 注意力头 | 32 |
| QKV 维度 | 128 |
| FFN 维度 | 13,712 |
| 上下文长度 | 8,192 |
| 词汇表大小 | 200,064 (文本+视觉) |
2.4 训练目标
统一的 Next-Token 预测损失:
其中 可以是任意模态的 Token(文本或视觉)。
关键洞察:文本 Token 和视觉 Token 在同一个词汇表中,可以无缝交替生成。
3. 训练策略
3.1 预训练数据
| 数据类型 | 规模 | 来源 |
|---|---|---|
| 图像-文本对 | 10B | 互联网爬取 |
| 视频-文本对 | 1B | 视频平台 |
| 纯文本 | 1T | 书籍、网页 |
| 交错数据 | 100M | 图文混合网页 |
3.2 多阶段训练
阶段1: 视觉Token学习
├── 目标: 学习好的视觉Tokenizer
├── 数据: 纯图像
└── 方法: VQ重构 + GAN损失
阶段2: 联合预训练
├── 目标: 统一架构学习
├── 数据: 图像-文本对、视频-文本对、纯文本
└── 方法: Next-token预测
阶段3: 指令微调
├── 目标: 提升指令遵循能力
├── 数据: SFT数据集
└── 方法: 监督微调
阶段4: RLHF对齐
├── 目标: 与人类偏好对齐
├── 数据: 人类偏好数据
└── 方法: PPO/DPO
3.3 训练细节
| 超参数 | 值 |
|---|---|
| 批量大小 | 4,096 (tokens) |
| 学习率 | 3e-4 |
| 预热步数 | 2,000 |
| 余弦衰减 | 是 |
| 权重衰减 | 0.1 |
| 梯度裁剪 | 1.0 |
| 训练步数 | 100,000+ |
4. 能力评估
4.1 视觉理解基准
| 基准 | Emu3 | GPT-4V | Gemini Pro | Qwen-VL |
|---|---|---|---|---|
| VQAv2 | 85.2% | 92.6% | 86.4% | 84.4% |
| GQA | 78.3% | 85.2% | - | 77.9% |
| SEED-Bench | 76.8% | 84.1% | 79.4% | 73.7% |
| MMBench | 79.5% | 86.1% | 78.6% | 76.8% |
| MathVista | 58.2% | 72.4% | 65.0% | 54.1% |
4.2 图像生成基准
| 基准 | Emu3 | DALL-E 3 | Stable Diffusion | SDXL |
|---|---|---|---|---|
| FID (COCO) | 8.7 | 6.3 | 12.4 | 9.5 |
| CLIP Score | 0.85 | 0.89 | 0.78 | 0.82 |
| 人力评估 | 4.2/5 | 4.5/5 | 3.8/5 | 4.0/5 |
4.3 多模态推理
| 任务 | Emu3 | GPT-4V | Gemini Pro |
|---|---|---|---|
| 图表理解 | 82.3% | 88.7% | 85.2% |
| 视频问答 | 76.8% | 82.1% | 78.9% |
| 多图像推理 | 74.5% | 81.3% | - |
| 跨模态数学 | 71.2% | 79.8% | 73.5% |
5. 技术创新
5.1 统一 Tokenizer 设计
关键创新:文本和视觉 Token 共享词汇表。
class UnifiedTokenizer:
def __init__(self):
# 文本Tokenizer
self.text_tokenizer = SentencePieceTokenizer()
self.text_vocab_size = 128000
# 视觉Tokenizer
self.visual_tokenizer = VQTokenizer()
self.visual_vocab_size = 32768
# 合并词汇表
self.total_vocab_size = self.text_vocab_size + self.visual_vocab_size
def encode(self, inputs):
if isinstance(inputs, str):
return self.text_tokenizer.encode(inputs)
elif is_image(inputs):
return self.visual_tokenizer.encode(inputs)
else:
raise ValueError("Unsupported input type")
def decode(self, token_ids):
if token_ids < self.text_vocab_size:
return self.text_tokenizer.decode([token_ids])
else:
visual_id = token_ids - self.text_vocab_size
return self.visual_tokenizer.decode([visual_id])5.2 模态感知注意力
创新:模态感知的注意力掩码
class ModalityAwareAttention(nn.Module):
def __init__(self):
super().__init__()
self.attention = MultiHeadAttention()
self.modality_embedding = nn.Embedding(3, hidden_dim) # text, image, video
def forward(self, x, modality_type):
# 添加模态嵌入
mod_emb = self.modality_embedding(modality_type)
x = x + mod_emb
# 模态感知的注意力
attention_mask = self.compute_modality_mask(modality_type)
return self.attention(x, attention_mask)
def compute_modality_mask(self, modality_type):
# 视觉 Token 之间可以完全交互
# 文本 Token 之间可以完全交互
# 视觉和文本 Token 根据任务动态交互
...5.3 渐进式训练
创新:从易到难的课程学习
Step 1: 纯文本生成 (1T tokens)
↓
Step 2: 图像描述 (1B samples)
↓
Step 3: 图像生成 (500M samples)
↓
Step 4: 图文交错生成 (100M samples)
↓
Step 5: 视频理解/生成 (1B frames)
6. 与其他工作的对比
6.1 Show-o2 对比
| 方面 | Emu3 | Show-o2 |
|---|---|---|
| 统一方式 | AR (自回归) | AR + Flow |
| 视觉表示 | VQ Token | 连续特征 |
| 生成质量 | 中等 | 较高 |
| 推理速度 | 快 | 较慢 |
| 架构复杂度 | 简单 | 复杂 |
6.2 Chameleon 对比
| 方面 | Emu3 | Chameleon |
|---|---|---|
| 视觉 Token | VQ (离散) | VQ (离散) |
| 模型规模 | 40B | 34B |
| 训练数据 | 10B 图像 | 10B 图像 |
| 开源 | 是 | 部分 |
| 性能 | 相当 | 相当 |
6.3 架构选择分析
为什么选择纯 AR 框架?
| 优点 | 缺点 |
|---|---|
| 推理简单 | 生成质量可能不如 Diffusion |
| 部署容易 | 长序列生成效率低 |
| 理论基础成熟 | - |
| 与 LLM 兼容 | - |
7. PyTorch 实现
7.1 简化实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaConfig
class Emu3Config:
"""Emu3 配置"""
def __init__(self):
# Transformer 配置
self.hidden_size = 4096
self.num_hidden_layers = 40
self.num_attention_heads = 32
self.intermediate_size = 13712
# Tokenizer 配置
self.text_vocab_size = 128000
self.visual_vocab_size = 32768
self.total_vocab_size = self.text_vocab_size + self.visual_vocab_size
# 视觉配置
self.visual_token_size = 16 # 16x16 visual tokens
self.max_position_embeddings = 8192
class VisualTokenizer(nn.Module):
"""视觉 Tokenizer"""
def __init__(self, config):
super().__init__()
# 编码器
self.encoder = nn.Sequential(
nn.Conv2d(3, 128, 7, 2, 3),
nn.ReLU(),
*self._make_encoder_blocks(128, 256),
)
# Vector Quantization
self.codebook = nn.Parameter(
torch.randn(32768, 512) # 32K tokens, 512-dim
)
# 投影层
self.projection = nn.Linear(512, config.hidden_size)
def _make_encoder_blocks(self, in_ch, out_ch):
blocks = []
for _ in range(4):
blocks.extend([
nn.Conv2d(in_ch, out_ch, 3, 2, 1),
nn.ReLU(),
])
in_ch = out_ch
return blocks
def forward(self, images):
# 编码
features = self.encoder(images) # B, C, H, W
# 量化
B, C, H, W = features.shape
features = features.permute(0, 2, 3, 1).reshape(B * H * W, C)
# 最近邻查询
distances = torch.cdist(features, self.codebook)
indices = distances.argmin(dim=-1)
# 获取量化特征
quantized = self.codebook[indices] # (B*H*W), 512
# 重塑
quantized = quantized.reshape(B, H, W, -1).permute(0, 3, 1, 2)
# 投影到隐藏维度
B, C, H, W = quantized.shape
quantized = self.projection(quantized.permute(0, 2, 3, 1)) # B, H, W, hidden
quantized = quantized.permute(0, 3, 1, 2) # B, hidden, H, W
return {
'tokens': indices.reshape(B, H, W), # 离散 token IDs
'quantized': quantized, # 连续特征
}
class Emu3Model(nn.Module):
"""Emu3 主模型"""
def __init__(self, config):
super().__init__()
self.config = config
# Token embeddings
self.token_embeddings = nn.Embedding(
config.total_vocab_size,
config.hidden_size
)
# 视觉 tokenizer
self.visual_tokenizer = VisualTokenizer(config)
# Transformer backbone
self.transformer = LlamaModel(config)
# Output projection
self.lm_head = nn.Linear(
config.hidden_size,
config.total_vocab_size,
bias=False
)
def encode_image(self, images):
"""编码图像为 tokens"""
visual_output = self.visual_tokenizer(images)
return visual_output['tokens']
def forward(self, input_ids, pixel_values=None, labels=None):
# 获取 embeddings
if pixel_values is not None:
# 替换视觉部分的 embeddings
visual_tokens = self.encode_image(pixel_values)
visual_embeds = self.token_embeddings(
visual_tokens + self.config.text_vocab_size
)
# 混合到 input_ids 中
...
hidden_states = self.token_embeddings(input_ids)
# Transformer forward
outputs = self.transformer(hidden_states=hidden_states)
# Output projection
logits = self.lm_head(outputs.last_hidden_state)
return {'logits': logits}
def generate(self, input_ids, max_length=100, pixel_values=None):
"""自回归生成"""
self.eval()
generated = input_ids.clone()
for _ in range(max_length):
outputs = self.forward(generated, pixel_values)
next_token = outputs['logits'][:, -1, :].argmax(dim=-1)
generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
if next_token.item() == EOS_TOKEN:
break
return generated7.2 训练脚本
def train_emu3():
from torch.utils.data import DataLoader
from emu3_model import Emu3Model, Emu3Config
# 初始化模型
config = Emu3Config()
model = Emu3Model(config)
# 优化器
optimizer = torch.optim.AdamW(
model.parameters(),
lr=3e-4,
weight_decay=0.1
)
# 训练循环
for epoch in range(num_epochs):
for batch in dataloader:
text_ids = batch['text_ids']
pixel_values = batch['images']
# 前向传播
outputs = model(
input_ids=text_ids,
pixel_values=pixel_values
)
# 计算损失
logits = outputs['logits']
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1)
)
# 反向传播
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()8. 局限性与未来方向
8.1 当前局限性
| 局限性 | 影响 |
|---|---|
| 生成质量不如专用 Diffusion 模型 | 图像生成任务 |
| 长序列推理效率低 | 视频生成任务 |
| 视觉 Tokenizer 训练复杂 | 训练成本 |
| 词汇表管理复杂 | 工程实现 |
8.2 未来改进方向
- 混合生成:结合 AR 和 Flow/Diffusion
- 更长上下文:扩展到视频级别
- 更好的视觉 Tokenizer:更高分辨率、更精细表示
- 多模态指令遵循:更强的复杂推理能力
9. 总结
9.1 主要贡献
- 提出原生统一架构:使用 Next-token prediction 统一多模态理解和生成
- 离散视觉 Token:基于 VQ 的视觉表示,简化与文本的融合
- 端到端优化:统一的预训练和微调流程
- 开源:模型权重和代码公开
9.2 关键洞察
- 统一框架是多模态 AI 的重要方向
- 离散 Token 可以有效桥接不同模态
- 端到端 训练比分离组件更有效
9.3 展望
Emu3 代表了多模态 AI 的一个重要方向。随着训练数据和模型规模的增长,原生统一模型有望在更多任务上取得突破。