混合SSM-Transformer架构

混合SSM-Transformer架构结合了Transformer的全局注意力机制与状态空间模型(SSM)的高效线性复杂度,旨在兼顾两者的优势:Transformer的表达能力与SSM的推理效率。

1. Jamba架构

1.1 核心设计理念

Jamba是首个生产级别的混合SSM-Transformer模型,由AI21 Labs于2024年3月发布。其核心设计采用blocks-and-layers结构1

Jamba Block 结构
┌─────────────────────────────────────┐
│  Block (重复 L 次)                   │
│  ┌─────────────────────────────────┐ │
│  │ Attention Layer (MHA/GQA)       │ │
│  │ Mamba Layer × 7                 │ │
│  │ [MoE Layer] (每2层)             │ │
│  └─────────────────────────────────┘ │
└─────────────────────────────────────┘

1.2 关键架构参数

Jamba-1.5-Large的配置如下2

参数
总参数量398B
激活参数量94B
每Block层数8
Attention : Mamba 比例1 : 7
MoE专家数16
Top-K路由2
隐藏维度8192
注意力头数64
KV头数8

1.3 256K上下文窗口

Jamba的核心优势之一是支持长达256K tokens的上下文窗口,这得益于:

  1. Mamba层的选择性机制:SSM可以高效处理长序列
  2. Transformer层的补充:保留全局建模能力
  3. KV缓存优化:仅需4GB(256K上下文),而同等规模Transformer需56GB+2
模型KV缓存 (256K, FP16)
LLaMA-3.1 70B80GB
Mixtral 8x22B56GB
Jamba-1.5-Large9GB
Jamba-1.5-Mini4GB

1.4 ExpertsInt8量化

为支持大模型高效推理,Jamba-1.5开发了ExpertsInt8量化技术2

  • 观测到85%+权重位于MoE层
  • 将MoE和MLP权重量化至INT8
  • 在融合核内即时反量化至BF16
  • 优势
    • 无需校准(calibration-free)
    • 延迟与FP8相当
    • 支持A100 GPU(FP8仅H100可用)

1.5 激活值稳定性问题

训练过程中发现某些激活值会逐渐增大至 量级,可能导致FP16溢出。解决方案是引入激活损失(Activation Loss)

其中 可有效将激活值控制在2K-3K范围内。

2. Nemotron-H架构

2.1 NVIDIA的混合架构方案

Nemotron-H是NVIDIA提出的混合Mamba-Transformer模型系列,包含8B和56B两种规模3

模型层数Attention层Mamba-2层FFN层
Nemotron-H-8B524 (8%)2424
Nemotron-H-56B11810 (8%)5454

2.2 关键设计原则

  1. 注意力层比例:约8%的总层数为自注意力层
  2. 均匀分散:注意力层在模型中均匀分布
  3. 首尾约束
    • 首层必须是Mamba-2层
    • 最后一层必须是FFN层
    • Attention层总是紧跟在FFN层之前

2.3 Mamba-2集成

Nemotron-H采用Mamba-2而非Mamba-1:

# Mamba-2配置
head_dim = 64
expansion_factor = 2
conv_kernel_size = 4
state_dim_8B = 128
state_dim_56B = 256

为何Mamba-2-Attention组合:实验表明,在混合架构中Mamba-1-Attention表现优于Mamba-2-Attention,因为注意力层可以补充Mamba-2缺失的部分能力。

2.4 FP8训练配方

Nemotron-H-56B采用FP8训练3

  • per-tensor动态量化:整个张量使用单一缩放因子
  • 首尾保留:首尾各4层保持BF16以保证稳定性
  • 混合精度:E4M3用于权重/激活,E5M2用于梯度
  • 性能:训练损失差距<0.1%

2.5 MiniPuzzle剪枝蒸馏

通过MiniPuzzle技术将56B模型压缩至47B:

  • 使用63B训练tokens
  • FP8训练
  • 精度损失极小
  • 可部署于单卡RTX 5090(32GiB)

3. 混合策略对比

3.1 层间混合 vs 层内混合

有两种主要的混合策略4

层间混合 (Inter-layer):
┌────┐ ┌────┐ ┌────┐ ┌────┐
│Attn│ │Mamba│ │Mamba│ │Mamba│
└────┘ └────┘ └────┘ └────┘

层内混合 (Intra-layer):
┌────────────────────────────┐
│   Half Attn  │  Half Mamba  │
└────────────────────────────┘
策略优势劣势
层间混合实现简单,灵活性高需要更多Mamba层才能匹配
层内混合更紧凑实现复杂度高

3.2 注意力层分散策略

实验表明,均匀分散注意力层优于集中放置:

  • 集中注意力:短上下文好,长上下文差
  • 均匀分散:长短上下文均优
  • 比例选择:8%是性能和效率的良好平衡点

4. 性能分析

4.1 推理效率

模型上下文长度相对速度
Transformer128K1x
Jamba-1.5-Large256K3-10x (长上下文)
Nemotron-H-56B128K3x (65K输入)

4.2 基准测试

在MMLU-Pro等学术基准上23

模型MMLU-Pro相对速度提升
LLaMA-3.1 70B53.01x
Mistral-Large-254.21x
Jamba-1.5-Large48.31.5-3x
Nemotron-H-56B~523x

4.3 长上下文能力

在RULER基准上的256K有效长度表现2

模型256K得分有效长度
GPT-4-1106-64K
LLaMA-3.1 70B-64K
Jamba-1.5-Large93.9256K
Jamba-1.5-Mini86.1256K

5. 实践指南

5.1 何时选择混合架构

场景推荐架构
长上下文应用(>32K)Jamba/Nemotron-H
资源受限部署混合架构 + 量化
短上下文,高精度纯Transformer
需要快速推理混合架构

5.2 训练注意事项

  1. 激活值监控:使用Activation Loss防止溢出
  2. 专家负载均衡:MoE路由需要辅助损失
  3. 数据混合:长文档数据对Mamba层有益

5.3 推理部署建议

# 使用vLLM部署Jamba
from vllm import LLM
 
model = LLM("ai21labs/Jamba-1.5-Mini")
# 自动利用ExpertsInt8量化

6. 未来展望

  1. 更大规模:Jamba已验证398B规模的可行性
  2. 多模态扩展:Nemotron-H-VLM已展示视觉-语言能力
  3. 推理时缩放:推理时计算量的增加对混合架构更有利
  4. 硬件协同:定制化Kernel进一步提升效率

参考资料

Footnotes

  1. Lieber O, et al. Jamba: A Hybrid Transformer-Mamba Language Model. arXiv:2403.19887, 2024.

  2. AI21 Labs. Jamba-1.5: Hybrid Transformer-Mamba Models at Scale. arXiv:2408.12570, 2024. 2 3 4 5

  3. NVIDIA. Nemotron-H: A Family of Accurate and Efficient Hybrid Mamba-Transformer Models. arXiv:2504.03624, 2025. 2 3

  4. Waleffe R, et al. A Study of Hybrid Transformer-Mamba Language Models. 2024.