Show-o2 统一多模态模型
概述
Show-o2 是由多机构联合提出的统一多模态模型,发表在 NeurIPS 2025。该模型的核心创新是:结合自回归(AR)建模和流匹配(Flow Matching),实现同时支持多模态理解和生成的高效统一架构。
1. 背景与动机
1.1 现有方法的局限
纯 AR 方法的局限(如 Emu3、Chameleon):
- 生成质量不如 Diffusion/Flow 模型
- 长序列生成效率低
- 对连续输出(图像)建模困难
纯 Diffusion 方法的局限(如 DALL-E 3、Stable Diffusion):
- 理解和生成需要分离的组件
- 推理速度慢(需要多步采样)
- 难以处理离散 token(如文本)
1.2 Show-o2 的解决方案
核心洞察:AR 和 Flow 各有优势,应该扬长避短。
| 任务 | 最适合格式 | 原因 |
|---|---|---|
| 文本生成 | 自回归 | 离散 token,序列依赖 |
| 图像理解 | 自回归 | 编码信息到离散表示 |
| 图像生成 | 流匹配 | 连续空间,高质量 |
Show-o2 策略:
- 理解:AR 编码所有模态
- 生成文本:AR 解码
- 生成图像:Flow 解码
2. 架构设计
2.1 整体架构
输入
│
├── 文本 ──→ Tokenizer ──→ 文本 Token
│
├── 图像 ──→ Visual Encoder ──→ 连续特征
│
└── 视频 ──→ Video Encoder ──→ 连续特征
│
▼
┌──────────────────────────────────────────────────────────┐
│ 统一 Transformer │
│ ┌─────────────────────────────────────────────────┐ │
│ │ 理解模式 (自回归编码) │ │
│ │ 处理文本/图像/视频输入,生成统一的上下文表示 │ │
│ └─────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ 生成模式 (双路径解码) │ │
│ │ 文本 → AR解码器 │ │
│ │ 图像 → Flow解码器 (Continuous) │ │
│ └─────────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────────────┘
│
▼
输出
│
├── 文本 Token ──→ Detokenizer ──→ 文本
│
└── 图像特征 ──→ VAE Decoder ──→ 图像
2.2 理解编码器
基于 Phi-3.5-mini 的语言模型骨干:
| 配置 | 值 |
|---|---|
| 隐藏维度 | 3072 |
| 层数 | 32 |
| 注意力头 | 32 |
| FFN 维度 | - |
| 上下文长度 | 8192 |
多模态理解头:
class MultimodalUnderstandingHead(nn.Module):
def __init__(self, vision_dim, lang_dim):
super().__init__()
# 视觉投影
self.vision_proj = nn.Sequential(
nn.Linear(vision_dim, lang_dim),
nn.GELU(),
nn.Linear(lang_dim, lang_dim)
)
# 视频时序建模
self.temporal_agg = TemporalAttention()
# 模态融合
self.fusion = CrossModalAttention()
def forward(self, vision_features, lang_features):
# 投影视觉特征
vision_features = self.vision_proj(vision_features)
# 时序聚合(视频)
vision_features = self.temporal_agg(vision_features)
# 跨模态融合
fused = self.fusion(vision_features, lang_features)
return fused2.3 AR 解码器(文本生成)
文本生成使用标准 Next-Token Prediction:
class ARDecoder(nn.Module):
def __init__(self, config):
super().__init__()
self.transformer = TransformerDecoder(config)
self.lm_head = nn.Linear(config.hidden, config.vocab_size)
def forward(self, context, text_tokens):
hidden = self.transformer(context, text_tokens)
logits = self.lm_head(hidden)
return logits
@torch.no_grad()
def generate(self, context, max_len=100):
tokens = []
for _ in range(max_len):
logits = self.forward(context, tokens)
next_token = logits[-1].argmax()
tokens.append(next_token)
if next_token == EOS_TOKEN:
break
return tokens2.4 Flow 解码器(图像生成)
核心创新:使用 Conditional Flow Matching 生成图像。
Flow Matching 回顾:
给定数据分布 和噪声分布 ,Flow Matching 定义一个插值路径:
其中 是数据点 的均值。
速度场:
Show-o2 的 Flow 解码器:
class FlowDecoder(nn.Module):
def __init__(self, config):
super().__init__()
# U-Net 风格的 denoiser
self.denoiser = UNetDiffusion(config)
# 条件注入
self.context_proj = nn.Linear(config.hidden, config.hidden)
# 时间步嵌入
self.time_embed = SinusoidalPosEmb(config.time_dim)
def forward(self, z_t, t, context):
"""训练:预测速度场"""
# 时间嵌入
t_embed = self.time_embed(t)
# 条件投影
context_embed = self.context_proj(context)
# 预测速度
v = self.denoiser(z_t, t_embed, context_embed)
return v
@torch.no_grad()
def sample(self, context, num_steps=50):
"""推理:从噪声生成图像"""
# 从纯噪声开始
z = torch.randn(B, C, H, W, device=context.device)
# 离散化 ODE 求解
dt = 1.0 / num_steps
for t in reversed(range(num_steps)):
t_batch = torch.full((B,), t / num_steps, device=context.device)
v = self.forward(z, t_batch, context)
z = z - dt * v
return z3. 统一训练目标
3.1 多任务训练框架
Show-o2 同时训练理解和生成能力:
def show_o2_loss(model, batch):
total_loss = 0.0
# 1. 理解任务:图像描述
if 'image' in batch and 'text' in batch:
understanding_loss = model.understanding_loss(
batch['image'], batch['text']
)
total_loss += 0.5 * understanding_loss
# 2. 理解任务:VQA
if 'image' in batch and 'question' in batch:
vqa_loss = model.understanding_loss(
batch['image'], batch['question']
)
total_loss += 0.3 * vqa_loss
# 3. 生成任务:文本生成
if 'prompt' in batch and 'target_text' in batch:
text_gen_loss = model.text_generation_loss(
batch['prompt'], batch['target_text']
)
total_loss += 0.4 * text_gen_loss
# 4. 生成任务:图像生成
if 'prompt' in batch and 'target_image' in batch:
image_gen_loss = model.image_generation_loss(
batch['prompt'], batch['target_image']
)
total_loss += 0.6 * image_gen_loss
return total_loss3.2 条件图像生成
Show-o2 支持基于文本条件的图像生成:
def conditional_image_generation(model, text_prompt):
# 1. 编码文本条件
text_tokens = model.tokenizer(text_prompt)
text_features = model.lang_encoder(text_tokens)
# 2. Flow 解码
z_T = torch.randn(1, 256, 32, 32) # 初始噪声
z_0 = model.flow_decoder.sample(z_T, context=text_features)
# 3. VAE 解码
image = model.vae.decoder(z_0)
return image3.3 交错生成
Show-o2 支持文本和图像的交错生成:
输入: "给我画一只猫,"
↓
输出: [图像 Token] + "它在草地上玩耍"
↓
显示图像 + 继续文本
def interleaved_generation(model, prompt):
tokens = model.tokenizer(prompt)
features = model.lang_encoder(tokens)
outputs = []
mode = "image" # 初始模式
for step in range(max_steps):
if mode == "image":
# 生成图像
image_tokens = model.flow_decoder.sample(context=features)
outputs.append(('image', image_tokens))
mode = "text" # 切换到文本模式
else:
# 生成文本
next_token = model.lang_decoder.generate(features)
outputs.append(('text', next_token))
if next_token == IMAGE_TOKEN:
mode = "image" # 遇到图像 token,切换模式
if next_token == EOS_TOKEN:
break
return outputs4. 训练策略
4.1 数据混合
| 数据类型 | 比例 | 用途 |
|---|---|---|
| 图像-文本对 | 40% | 理解和生成 |
| 纯文本 | 30% | 语言能力 |
| 图像描述 | 15% | 图像理解 |
| VQA 数据 | 10% | 视觉问答 |
| 交错数据 | 5% | 复杂推理 |
4.2 两阶段训练
阶段1: 预训练
├── 目标: 学习通用表示
├── 数据: 大规模图文数据
├── 训练目标: 多任务联合
└── 步数: 100K
阶段2: 微调
├── 目标: 提升特定能力
├── 数据: 精标注数据
├── 方法: SFT + DPO
└── 步数: 10K
4.3 训练超参数
| 参数 | 值 |
|---|---|
| 批量大小 | 512 (图像) / 2048 (文本) |
| 学习率 | 1e-4 |
| 预热 | 1% |
| 权重衰减 | 0.1 |
| 梯度裁剪 | 1.0 |
| 图像生成步数 | 50 |
| Flow 学习率 | 5e-5 |
5. 能力评估
5.1 理解任务
| 基准 | Show-o2 | Emu3 | Chameleon | LLaVA-1.5 |
|---|---|---|---|---|
| VQAv2 | 84.7% | 85.2% | 82.1% | 80.2% |
| GQA | 76.8% | 78.3% | 74.5% | 72.1% |
| VisWiz | 62.3% | 61.8% | 58.9% | 55.7% |
| TextVQA | 58.9% | 60.1% | 56.3% | 52.4% |
5.2 生成任务
| 基准 | Show-o2 | SDXL | DALL-E 3 | Emu3 |
|---|---|---|---|---|
| FID ↓ | 7.8 | 8.5 | 6.2 | 8.7 |
| CLIP Score ↑ | 0.87 | 0.82 | 0.89 | 0.85 |
| 人力评估 | 4.1/5 | 4.0/5 | 4.5/5 | 4.2/5 |
| 生成速度 | 2.1s | 3.5s | 5.2s | 1.8s |
5.3 统一能力
| 任务 | Show-o2 | 专用模型 |
|---|---|---|
| 图文推理 | 82.3% | 85.1% |
| 图像生成 | 7.8 FID | 6.2 FID |
| 交错生成 | ✅ | ❌ |
| 端到端优化 | ✅ | ❌ |
6. 技术细节
6.1 3D 因果 VAE
Show-o2 使用 3D 因果 VAE 进行视觉表示:
class CausalVideoVAE(nn.Module):
"""3D 因果 VAE 用于视频/图像"""
def __init__(self):
super().__init__()
# 编码器(2D 或 3D)
self.encoder = Encoder3D()
# 量化器
self.quantizer = VectorQuantizer(
vocab_size=8192,
dim=256,
commitment_cost=0.25
)
# 解码器
self.decoder = Decoder3D()
def encode(self, x):
"""编码为离散表示"""
z = self.encoder(x)
z_quantized, indices = self.quantizer(z)
return z_quantized, indices
def decode(self, z):
"""从连续表示解码"""
return self.decoder(z)
def forward(self, x):
z = self.encoder(x)
z_quantized, indices = self.quantizer(z)
x_recon = self.decoder(z_quantized)
return x_recon, indices6.2 模态切换机制
创新:显式的模态切换 token
SPECIAL_TOKENS = {
'BOS_IMAGE': 200000,
'EOS_IMAGE': 200001,
'BOS_VIDEO': 200002,
'EOS_VIDEO': 200003,
'IMAGE_START': 200004,
'IMAGE_END': 200005,
}
def modality_aware_forward(model, input_ids, modal_types):
"""模态感知的自注意力"""
# 获取嵌入
embeddings = model.embed_tokens(input_ids)
# 添加模态嵌入
for i, modal in enumerate(modal_types):
if modal == 'image':
embeddings[i] += model.image_modality_embed
elif modal == 'video':
embeddings[i] += model.video_modality_embed
else:
embeddings[i] += model.text_modality_embed
# 跨模态注意力
hidden = model.transformer(embeddings)
return hidden6.3 Flow Matching 细节
Conditional Flow Matching 损失:
def flow_matching_loss(model, x_1, context):
"""Flow Matching 训练损失"""
batch_size = x_1.shape[0]
# 采样时间步
t = torch.rand(batch_size, device=x_1.device)
# 采样噪声
x_0 = torch.randn_like(x_1)
# 插值
t_expanded = t.view(batch_size, 1, 1, 1)
x_t = t_expanded * x_1 + (1 - t_expanded) * x_0
# 目标速度
v_target = x_1 - x_0
# 预测速度
v_pred = model(x_t, t, context)
# MSE 损失
loss = F.mse_loss(v_pred, v_target)
return loss7. 与其他方法的对比
7.1 架构对比
| 方法 | 理解方式 | 文本生成 | 图像生成 | 统一程度 |
|---|---|---|---|---|
| Show-o2 | AR | AR | Flow | 完全统一 |
| Emu3 | AR | AR | AR | 完全统一 |
| Chameleon | AR | AR | AR | 完全统一 |
| DALL-E 3 | CLIP | GPT-4 | Diffusion | 部分统一 |
| LLaVA | ViT | LLaMA | SD | 分离 |
7.2 生成质量对比
Show-o2 vs 其他方法(FID 分数,越低越好):
Show-o2 Emu3 Chameleon SDXL DALL-E 3
ImageNet ████ ████ ████ ███ ██
COCO ████ ████ ████ ███ ██
LAION ████ ████ ████ ███ ██
████ = 8-9 ███ = 7-8 ██ = 6-7
7.3 推理效率对比
| 方法 | 图像生成时间 | 文本生成速度 | 总效率 |
|---|---|---|---|
| Show-o2 | 2.1s | 50 tok/s | 高 |
| Emu3 | 1.8s | 48 tok/s | 高 |
| SDXL | 3.5s | - | 中 |
| DALL-E 3 | 5.2s | - | 低 |
8. 实践应用
8.1 图像生成示例
from show_o2 import ShowO2Model
model = ShowO2Model.from_pretrained("show-o2-base")
# 文本到图像
image = model.generate_image(
prompt="A cat sitting on a green grass field",
num_steps=50,
guidance_scale=7.5
)
# 多图像生成
images = model.generate_image(
prompt=["A cat", "A dog", "A bird"],
batch_size=3
)8.2 图文推理示例
# 图像问答
answer = model.answer_question(
image="path/to/image.jpg",
question="What is the cat doing?"
)
# 图文生成
result = model.generate(
prompt="A cat on a mat. Then describe what you see."
)
# 输出: (图像, "A fluffy orange cat is sitting on a red mat.")8.3 交错生成示例
# 交错生成
outputs = model.generate_interleaved(
prompt="First, draw a mountain. Then write a poem about it.",
max_image_tokens=256,
max_text_tokens=100
)
# outputs: [
# ('image', generated_image_1),
# ('text', "Mountain peaks pierce the clouds high above..."),
# ]9. 总结与展望
9.1 主要贡献
- 混合架构创新:首次结合 AR 和 Flow Matching
- 统一训练框架:多任务联合训练
- 高效推理:Flow 解码比 Diffusion 快
- 完整开源:模型和代码公开
9.2 关键洞察
- 混合生成是平衡质量和效率的好方法
- 统一架构简化了训练和部署
- 条件图像生成可以无缝集成到 LLM
9.3 局限性
| 局限性 | 影响 |
|---|---|
| Flow 解码器训练复杂 | 训练成本 |
| 对话能力不如纯 LLM | 多轮对话 |
| 视频生成未充分探索 | 视频任务 |
9.4 未来方向
- 视频生成:扩展 Flow 解码器到视频
- 更高分辨率:改进 VAE 支持 1024+
- 更长上下文:扩展到视频级别
- 更强推理:提升复杂推理能力