简介
统一多模态理解和生成已经在一流专有系统(如GPT-4o、Gemini 1.5 Pro)中展示了令人印象深刻的能力。然而,开源社区一直缺乏能够同时实现高质量理解和生成的开源基础模型。本文介绍BAGEL,一个开源的多模态基础模型,能够原生支持多模态理解和生成。BAGEL是一个统一的、仅用解码器的模型,在来自大规模多模态交错文本、图像、视频和网络数据的万亿token上进行了预训练。当在此类多样化的多模态交错数据上扩展训练时,BAGEL展现出复杂多模态推理的涌现能力,显著优于开源统一模型在多模态生成和理解方面的表现,同时展现出高级多模态推理能力,如自由形式图像操作、未来帧预测、3D操作和世界导航。1
背景:统一多模态模型的意义
传统方法的局限性
传统方法通常将理解和生成分开处理:
- 理解模型:CLIP、BERT等,专注于图像-文本匹配
- 生成模型:Stable Diffusion、DALL-E等,专注于文本到图像生成
这种分离导致了:
- 架构不一致:理解和生成使用完全不同的骨干网络
- 知识难以迁移:两个任务的学习信号无法共享
- 推理效率低:部署两套系统增加开销
统一模型的愿景
统一多模态模型追求:
- 单一架构:一个模型处理所有模态
- 端到端训练:理解和生成信号联合优化
- 涌现能力:多模态交互产生新的智能行为
BAGEL架构设计
核心设计原则
BAGEL遵循三个核心原则:
- 统一表示:所有模态(文本、图像、视频)映射到统一的token序列
- 仅解码器架构:避免理解和生成之间的架构差异
- 大规模预训练:在海量多模态交错数据上学习
模型结构
import torch
import torch.nn as nn
from typing import Optional, Dict, Any
class BAGELConfig:
"""BAGEL模型配置"""
def __init__(
self,
vocab_size: int = 128256,
hidden_size: int = 4096,
intermediate_size: int = 13696,
num_hidden_layers: int = 32,
num_attention_heads: int = 32,
num_key_value_heads: int = 8,
max_position_embeddings: int = 32768,
# 多模态参数
vision_patch_size: int = 14,
vision_hidden_size: int = 1024,
image_token_length: int = 1024, # 224x224 / 14^2 = 256 -> 处理更高分辨率
**kwargs
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.max_position_embeddings = max_position_embeddings
self.vision_patch_size = vision_patch_size
self.vision_hidden_size = vision_hidden_size
self.image_token_length = image_token_length
class MultimodalEmbedding(nn.Module):
"""
多模态嵌入层:统一文本、图像、视频表示
核心思想:将所有模态都编码为统一的token序列
"""
def __init__(self, config: BAGELConfig):
super().__init__()
self.config = config
# 文本嵌入
self.text_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
# 视觉编码器(SigLIP风格)
self.vision_encoder = SigLIPEncoder(
hidden_size=config.vision_hidden_size,
patch_size=config.vision_patch_size,
num_layers=24
)
# 视觉投影层:将视觉特征映射到语言空间
self.vision_projection = nn.Sequential(
nn.Linear(config.vision_hidden_size, config.hidden_size),
nn.GELU(),
nn.Linear(config.hidden_size, config.hidden_size)
)
# 模态嵌入:标识不同模态
self.modality_embedding = nn.Embedding(4, config.hidden_size) # 文本、图像、视频、交织
def encode_text(self, input_ids: torch.Tensor) -> torch.Tensor:
"""编码文本"""
text_embeds = self.text_embedding(input_ids)
text_modality = self.modality_embedding(
torch.zeros_like(input_ids).long() # 0 = 文本
)
return text_embeds + text_modality
def encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""编码图像"""
# 通过视觉编码器
vision_features = self.vision_encoder(pixel_values) # [B, N, D_vision]
# 投影到语言空间
image_embeds = self.vision_projection(vision_features)
# 添加图像模态嵌入
batch_size = pixel_values.shape[0]
num_patches = vision_features.shape[1]
image_modality = self.modality_embedding(
torch.ones(batch_size, num_patches).long().to(pixel_values.device) * 1 # 1 = 图像
)
return image_embeds + image_modality
def encode_video(self, video_values: torch.Tensor) -> torch.Tensor:
"""编码视频(逐帧处理)"""
# 视频是图像的时间序列
B, T, C, H, W = video_values.shape
# 展平时间维度,每帧独立编码
video_flat = video_values.reshape(B * T, C, H, W)
vision_features = self.vision_encoder(video_flat) # [B*T, N, D]
# 恢复时间维度
vision_features = vision_features.reshape(B, T * N, -1)
# 投影和模态嵌入
video_embeds = self.vision_projection(vision_features)
video_modality = self.modality_embedding(
torch.ones(B, T * N).long().to(video_values.device) * 2 # 2 = 视频
)
return video_embeds + video_modality
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
video_values: Optional[torch.Tensor] = None,
modality_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
前向传播
Args:
input_ids: 文本token IDs
pixel_values: 图像像素值
video_values: 视频像素值
modality_mask: 标识每个token的模态
Returns:
embeddings: 多模态嵌入序列
"""
embeddings_list = []
if input_ids is not None:
embeddings_list.append(self.encode_text(input_ids))
if pixel_values is not None:
embeddings_list.append(self.encode_image(pixel_values))
if video_values is not None:
embeddings_list.append(self.encode_video(video_values))
# 拼接所有嵌入
if len(embeddings_list) == 1:
return embeddings_list[0]
else:
return torch.cat(embeddings_list, dim=1)
class BAGELModel(nn.Module):
"""
BAGEL统一多模态模型
统一的decoder-only架构处理文本、图像、视频生成和理解
"""
def __init__(self, config: BAGELConfig):
super().__init__()
self.config = config
# 多模态嵌入
self.embed_tokens = MultimodalEmbedding(config)
# Transformer解码器
self.layers = nn.ModuleList([
BAGELLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size)
# 头部
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# 绑定权重(可选)
self.lm_head.weight = self.embed_tokens.text_embedding.weight
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
video_values: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
**kwargs
) -> Dict[str, Any]:
# 多模态嵌入
hidden_states = self.embed_tokens(
input_ids=input_ids,
pixel_values=pixel_values,
video_values=video_values
)
# 通过Transformer层
for layer in self.layers:
hidden_states = layer(hidden_states, **kwargs)
hidden_states = self.norm(hidden_states)
# 语言模型头
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
logits.view(-1, self.config.vocab_size),
labels.view(-1)
)
return {
'logits': logits,
'loss': loss
}
def demo_bagel_multimodal_generation():
"""演示BAGEL的多模态生成能力"""
print("=== BAGEL多模态生成演示 ===\n")
# 模拟输入
batch_size = 1
seq_len = 512
hidden_size = 4096
# 模拟数据
text_tokens = torch.randint(0, 128256, (batch_size, seq_len))
images = torch.randn(batch_size, 3, 224, 224)
# 创建模型
config = BAGELConfig(
hidden_size=hidden_size,
num_hidden_layers=32
)
model = BAGELModel(config)
model.eval()
print("模型参数量:")
n_params = sum(p.numel() for p in model.parameters())
print(f" 总参数: {n_params / 1e9:.2f}B")
# 前向传播
with torch.no_grad():
output = model(input_ids=text_tokens, pixel_values=images)
print(f"\n输出logits形状: {output['logits'].shape}")
print(f"损失: {output['loss']}")
if __name__ == "__main__":
demo_bagel_multimodal_generation()视觉Tokenizer
BAGEL使用自定义的视觉tokenizer:
- 编码器:SigLIP风格的视觉Transformer
- 码本:65536个离散视觉token
- 压缩比:224×224图像 → 256个token
预训练数据
数据构建策略
BAGEL的训练数据包含:
| 数据类型 | 规模 | 特点 |
|---|---|---|
| 多模态交错文本 | 2T tokens | 网页、文档、代码 |
| 图像-文本对 | 1B images | ALT、标题描述 |
| 视频-文本对 | 100M videos | 字幕、描述 |
| 交错序列 | 500M sequences | 多模态混合内容 |
交错数据的重要性
核心洞察:交错的多模态序列对于涌现复杂推理能力至关重要。
交错数据示例:
[文本] A robot is navigating through a warehouse
[图像] 仓库场景图
[文本] It detects an obstacle and plans an alternative path
[图像] 路径规划可视化
[文本] The robot successfully reaches its destination
[图像] 成功到达目标
涌现的多模态能力
1. 自由形式图像操作
BAGEL能够在保持语义一致性的同时进行精细的图像编辑:
def image_editing_prompt():
"""图像编辑提示示例"""
examples = [
# 对象替换
{
"instruction": "Replace the red car with a blue bicycle",
"input": "[原图: 红色汽车停在路边]",
"output": "[修改后: 蓝色自行车停在同一位置]"
},
# 场景修改
{
"instruction": "Change the time of day from day to night",
"input": "[原图: 白天的城市街景]",
"output": "[修改后: 夜晚的城市街景]"
},
# 属性编辑
{
"instruction": "Make the cat look fluffy and cute",
"input": "[原图: 一只普通的猫]",
"output": "[修改后: 蓬松可爱的猫]"
}
]
return examples2. 未来帧预测
给定视频片段,BAGEL能够预测未来的帧:
输入: [视频片段: 一个人从起点走向门口]
输出: [预测: 这个人打开门并走出门外]
3. 3D操作和世界导航
BAGEL展现出对3D空间关系的理解:
- 理解对象之间的空间关系
- 预测遮挡和深度
- 进行虚拟导航
与其他统一模型对比
架构对比
| 模型 | 架构 | 理解能力 | 生成能力 | 开源 |
|---|---|---|---|---|
| GPT-4o | 专有 | ★★★★★ | ★★★★★ | 否 |
| Gemini 1.5 | 专有 | ★★★★★ | ★★★★☆ | 否 |
| BAGEL | Decoder-only | ★★★★☆ | ★★★★☆ | 是 |
| JanusFlow | Decoder-only | ★★★☆☆ | ★★★★☆ | 是 |
| Show-o | Decoder | ★★★☆☆ | ★★★☆☆ | 是 |
性能对比
=== 多模态理解基准 ===
模型 | VQAv2 | GQA | VizWiz | MMMU
-------------|-------|------|--------|------
BAGEL-7B | 84.2 | 82.1 | 68.5 | 52.3
JanusFlow-7B| 81.5 | 79.8 | 64.2 | 48.1
Show-o-7B | 80.3 | 78.5 | 62.8 | 45.6
=== 文本到图像生成基准 ===
模型 | FID | CLIP-S | PickScore
-------------|------|--------|----------
BAGEL-7B | 8.2 | 0.82 | 82.5
JanusFlow-7B| 9.5 | 0.79 | 79.8
SDXL | 7.6 | 0.81 | 80.2
训练策略
阶段训练
BAGEL采用多阶段训练策略:
class BAGELTrainingStages:
"""BAGEL训练阶段"""
STAGE_1 = {
'name': '多模态交错预训练',
'data': '大规模交错文本-图像-视频',
'objective': '下一个token预测',
'batch_size': 8192,
'learning_rate': 1e-4,
'steps': 100000
}
STAGE_2 = {
'name': '指令微调',
'data': '多模态指令数据',
'objective': '指令跟随',
'batch_size': 2048,
'learning_rate': 5e-5,
'steps': 50000
}
STAGE_3 = {
'name': '强化学习对齐',
'data': '多模态偏好数据',
'objective': 'RLHF/DPO',
'batch_size': 512,
'learning_rate': 1e-5,
'steps': 20000
}关键技术
- 模态平衡采样:确保不同模态数据均衡训练
- 渐进式分辨率:从低分辨率开始,逐步提高
- 权重共享:文本和视觉embedding共享
限制与挑战
当前限制
- 生成质量:仍落后于专门的扩散模型
- 分辨率:高分辨率图像生成需要额外处理
- 计算资源:训练需要大规模GPU集群
未来方向
- 更大规模:扩展到更大模型
- 更多模态:音频、3D点云等
- 改进生成:结合扩散技术提升质量
总结
BAGEL作为首个开源的统一多模态基础模型,展示了:
- 统一架构:仅decoder设计简化了多模态处理
- 涌现能力:复杂的多模态推理能力
- 开源贡献:为社区提供可复现的多模态基线
BAGEL的出现标志着开源社区在统一多模态理解与生成方面迈出了重要一步。