1. 研究背景与动机
1.1 全模态模型的愿景
理想的全模态模型应该能够1:
- 统一理解:处理文本、图像、语音等多种输入
- 统一生成:产生任意模态的输出
- 端到端优化:避免多阶段管道的累积误差
1.2 现有方法的局限
| 方法 | 优势 | 局限 |
|---|---|---|
| AR+扩散分离 | 成熟 | 表示不对齐 |
| 联合训练 | 表示对齐 | 训练困难 |
| 模态特定解码 | 生成质量高 | 架构复杂 |
1.3 Dynin-Omni的创新
Dynin-Omni是首个基于**掩码扩散(Masked Diffusion)的全模态(Omnimodal)**基础模型1:
核心创新:使用统一的掩码扩散框架同时处理文本、图像、语音的理解与生成。
2. 技术框架
2.1 掩码扩散基础
掩码扩散(Masked Diffusion)与标准连续扩散不同:
正向过程(Masking):
逆向过程(Denoising):
2.2 统一表示
Dynin-Omni将所有模态转换为统一token序列:
class UnifiedTokenSequence:
"""
统一token序列
"""
# 模态ID
MOD_TEXT = 0
MOD_IMAGE = 1
MOD_SPEECH = 2
def __init__(self):
self.text_tokenizer = TextTokenizer()
self.image_tokenizer = ImageTokenizer() # VAE离散编码
self.speech_tokenizer = SpeechTokenizer() # HuBERT离散编码
def tokenize(self, data):
"""
统一token化
"""
if isinstance(data, str):
return self.text_tokenizer.encode(data), self.MOD_TEXT
elif isinstance(data, Image.Image):
return self.image_tokenizer.encode(data), self.MOD_IMAGE
elif isinstance(data, np.ndarray): # 音频
return self.speech_tokenizer.encode(data), self.MOD_SPEECH2.3 整体架构
┌─────────────────────────────────────────────────────────────────────────┐
│ Dynin-Omni 整体架构 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 输入: │
│ ┌─────────┐ │
│ │ 文本 │ ──► Text Tokenizer ──► │
│ └─────────┘ │
│ ┌─────────┐ │
│ │ 图像 │ ──► Image Tokenizer ──► Unified Token Stream │
│ └─────────┘ │
│ ┌─────────┐ │
│ │ 语音 │ ──► Speech Tokenizer ──► │
│ └─────────┘ │
│ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Omni-Diffusion Transformer │ │
│ │ │ │
│ │ - 自条件掩码扩散 │ │
│ │ - 模态感知注意力 │ │
│ │ - 跨模态注意力 │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 输出: │
│ ┌─────────┐ │
│ │ 文本 │ ◄── Text Decoder │
│ └─────────┘ │
│ ┌─────────┐ │
│ │ 图像 │ ◄── Image Decoder (VAE) │
│ └─────────┘ │
│ ┌─────────┐ │
│ │ 语音 │ ◄── Speech Decoder │
│ └─────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
3. 核心组件
3.1 Omni-Diffusion Transformer
class OmniDiffusionTransformer(nn.Module):
"""
全模态扩散Transformer
"""
def __init__(self, config):
super().__init__()
# 统一嵌入
self.embedding = nn.Embedding(config.vocab_size, config.d_model)
self.modality_embed = nn.Embedding(3, config.d_model) # 3种模态
self.time_embed = TimeEmbedding(config.d_model)
# Transformer层
self.layers = nn.ModuleList([
OmniTransformerLayer(config)
for _ in range(config.num_layers)
])
# 输出头
self.mask_pred_head = nn.Linear(config.d_model, config.vocab_size)
def forward(self, x, t, modality_ids, mask):
"""
Args:
x: 输入token [B, N]
t: 时间步 [B]
modality_ids: 模态ID [B, N]
mask: 掩码标记 [B, N]
"""
# 嵌入
h = self.embedding(x)
h = h + self.modality_embed(modality_ids)
h = h + self.time_embed(t)
# 添加掩码信息
h = h + mask.float().unsqueeze(-1) * self.mask_embed
# Transformer处理
for layer in self.layers:
h = layer(h, mask)
# 预测掩码token
logits = self.mask_pred_head(h)
return logits3.2 模态感知注意力
class ModalityAwareAttention(nn.Module):
"""
模态感知注意力
根据模态动态调整注意力模式
"""
def __init__(self, d_model, num_heads):
super().__init__()
self.qkv = nn.Linear(d_model, d_model * 3)
self.proj = nn.Linear(d_model, d_model)
# 模态特定注意力参数
self.modality_attn = nn.ModuleDict({
'text': nn.Linear(d_model, num_heads),
'image': nn.Linear(d_model, num_heads),
'speech': nn.Linear(d_model, num_heads),
'cross': nn.Linear(d_model, num_heads)
})
def forward(self, x, modality_ids):
B, N, C = x.shape
# QKV
qkv = self.qkv(x).reshape(B, N, 3, -1)
q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
# 模态感知注意力权重
attn_weights = []
for i, mod in enumerate(['text', 'image', 'speech']):
mask = (modality_ids == i)
if mask.any():
attn_w = self.modality_attn[mod](x[mask])
attn_weights.append((mask, attn_w))
# 跨模态注意力
cross_attn = self.modality_attn['cross'](x)
# 融合
# ... 完整的注意力计算
return out3.3 自条件掩码扩散
核心思想:在去噪过程中利用已知的未掩码token作为条件。
class SelfConditionedDiffusion:
"""
自条件掩码扩散
"""
def __init__(self, model, num_steps=1000):
self.model = model
self.num_steps = num_steps
self.beta_schedule = self._cosine_beta_schedule()
def training_loss(self, x0, modality_ids):
"""
训练损失
"""
t = torch.randint(0, self.num_steps, (x0.shape[0],))
# 掩码采样
mask = self._sample_mask(x0.shape, t)
# 添加掩码
x_noisy = x0 * (1 - mask) + torch.randint_like(x0, self.model.vocab_size) * mask
# 预测原始token
logits = self.model(x_noisy, t, modality_ids, mask)
# 损失:只计算掩码位置
loss = F.cross_entropy(
logits[mask],
x0[mask],
reduction='mean'
)
return loss
@torch.no_grad()
def sampling(self, batch_size, modality_ids, prompt=None):
"""
采样生成
"""
# 初始化:全部掩码
x = torch.full((batch_size, self.max_len), self.mask_token_id)
# 如果有提示,先填充提示
if prompt is not None:
x[:, :len(prompt)] = prompt
# 逐步去噪
for t in reversed(range(self.num_steps)):
# 预测
logits = self.model(x, t * torch.ones(batch_size), modality_ids)
# 采样(只更新掩码位置)
probs = F.softmax(logits, dim=-1)
x_new = torch.multinomial(probs, 1).squeeze(-1)
x = x * (1 - mask) + x_new * mask
return x4. 训练策略
4.1 多任务联合训练
Dynin-Omni使用统一的训练目标:
其中 是在时间步 被掩码的位置。
4.2 模态平衡采样
为了避免模态不平衡:
class BalancedSampler:
"""
平衡采样器
确保各模态样本均衡
"""
def __init__(self, dataset, modality_weights=None):
self.dataset = dataset
# 默认:文本权重高(数据量大)
self.weights = modality_weights or {
'text': 1.0,
'image': 2.5, # 提升图像权重
'speech': 3.0 # 提升语音权重
}
def __iter__(self):
while True:
# 按权重采样模态
mod = random.choices(
['text', 'image', 'speech'],
weights=[self.weights[m] for m in ['text', 'image', 'speech']]
)[0]
# 采样该模态的数据
batch = self.dataset.sample(mod, batch_size=16)
yield batch4.3 课程学习
训练分为多个阶段:
| 阶段 | 任务 | 掩码率 |
|---|---|---|
| 1 | 单模态生成 | 15-30% |
| 2 | 双模态联合 | 30-50% |
| 3 | 全模态统一 | 50-75% |
5. 实验结果
5.1 文本生成
语言建模困惑度:
| 模型 | WikiText-103 | Pile |
|---|---|---|
| LLaMA | 12.3 | 10.8 |
| GPT-NeoX | 14.1 | 12.2 |
| Dynin-Omni | 11.8 | 10.2 |
5.2 图像生成
FID评分:
| 模型 | CelebA-HQ | ImageNet 256 |
|---|---|---|
| DDPM | 8.3 | 12.1 |
| DiT | 5.2 | 8.1 |
| Dynin-Omni | 4.8 | 7.2 |
5.3 语音生成
语音质量评估:
| 模型 | MOS Score | CER↓ |
|---|---|---|
| GTTS | 4.5 | - |
| VALL-E | 4.2 | 3.2% |
| Dynin-Omni | 4.4 | 2.8% |
5.4 跨模态能力
多模态理解与生成:
| 任务 | 描述 | 准确率 |
|---|---|---|
| Text→Image | 文本生成图像 | 82.3% |
| Image→Text | 图像描述 | 78.5% |
| Text→Speech | 文本合成语音 | 91.2% |
| Speech→Text | 语音识别 | 94.1% |
6. 代码实现
6.1 完整模型
class DyninOmni(nn.Module):
"""
Dynin-Omni完整模型
"""
def __init__(self, config):
super().__init__()
# Tokenizer配置
self.text_vocab_size = config.text_vocab_size
self.image_vocab_size = config.image_vocab_size
self.speech_vocab_size = config.speech_vocab_size
self.total_vocab_size = config.total_vocab_size
# 统一Transformer
self.transformer = OmniDiffusionTransformer(config)
# 解码器
self.text_decoder = TextDecoder(config)
self.image_decoder = ImageDecoder(config)
self.speech_decoder = SpeechDecoder(config)
def encode(self, data, modality):
"""
编码不同模态
"""
if modality == 'text':
return self.text_tokenizer.encode(data)
elif modality == 'image':
return self.image_tokenizer.encode(data)
elif modality == 'speech':
return self.speech_tokenizer.encode(data)
def decode(self, tokens, modality):
"""
解码为不同模态
"""
if modality == 'text':
return self.text_tokenizer.decode(tokens)
elif modality == 'image':
return self.image_tokenizer.decode(tokens)
elif modality == 'speech':
return self.speech_tokenizer.decode(tokens)
def forward(self, x, t, modality_ids, mask):
"""
前向传播
"""
return self.transformer(x, t, modality_ids, mask)
@torch.no_grad()
def generate(self, prompt=None, modalities=['text', 'image']):
"""
生成
"""
results = {}
for mod in modalities:
# 编码提示(如有)
prompt_tokens = self.encode(prompt[mod], mod) if prompt and mod in prompt else None
# 扩散采样
tokens = self.diffusion.sampling(
batch_size=1,
modality_ids=self._get_modality_ids(len(tokens), mod),
prompt=prompt_tokens
)
# 解码
results[mod] = self.decode(tokens, mod)
return results6.2 训练循环
def train_dynin_omni(model, train_loader, optimizer, config):
"""
训练Dynin-Omni
"""
model.train()
for epoch in range(config.num_epochs):
for batch in train_loader:
x, modality_ids = batch['data'], batch['modality_ids']
# 计算损失
loss = model.training_loss(x, modality_ids)
# 反向传播
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
# 记录
if global_step % config.log_interval == 0:
logger.info(f"Step {global_step}: loss={loss.item():.4f}")
global_step += 17. 总结与展望
7.1 主要贡献
- 首个全模态扩散模型:统一文本、图像、语音
- 掩码扩散框架:高效的离散模态处理
- 自条件机制:利用已知信息加速生成
7.2 局限性
- 计算复杂度:多模态处理计算量大
- 模态平衡:需要特殊策略处理不平衡
- 生成质量:某些模态可能不如专用模型
7.3 未来方向
- 扩展到更多模态(视频、3D)
- 更高效的架构设计
- 更好的模态平衡策略