混合SSM-Transformer架构
混合SSM-Transformer架构结合了Transformer的全局注意力机制与状态空间模型(SSM)的高效线性复杂度,旨在兼顾两者的优势:Transformer的表达能力与SSM的推理效率。
1. Jamba架构
1.1 核心设计理念
Jamba是首个生产级别的混合SSM-Transformer模型,由AI21 Labs于2024年3月发布。其核心设计采用blocks-and-layers结构1:
Jamba Block 结构
┌─────────────────────────────────────┐
│ Block (重复 L 次) │
│ ┌─────────────────────────────────┐ │
│ │ Attention Layer (MHA/GQA) │ │
│ │ Mamba Layer × 7 │ │
│ │ [MoE Layer] (每2层) │ │
│ └─────────────────────────────────┘ │
└─────────────────────────────────────┘
1.2 关键架构参数
Jamba-1.5-Large的配置如下2:
| 参数 | 值 |
|---|---|
| 总参数量 | 398B |
| 激活参数量 | 94B |
| 每Block层数 | 8 |
| Attention : Mamba 比例 | 1 : 7 |
| MoE专家数 | 16 |
| Top-K路由 | 2 |
| 隐藏维度 | 8192 |
| 注意力头数 | 64 |
| KV头数 | 8 |
1.3 256K上下文窗口
Jamba的核心优势之一是支持长达256K tokens的上下文窗口,这得益于:
- Mamba层的选择性机制:SSM可以高效处理长序列
- Transformer层的补充:保留全局建模能力
- KV缓存优化:仅需4GB(256K上下文),而同等规模Transformer需56GB+2
| 模型 | KV缓存 (256K, FP16) |
|---|---|
| LLaMA-3.1 70B | 80GB |
| Mixtral 8x22B | 56GB |
| Jamba-1.5-Large | 9GB |
| Jamba-1.5-Mini | 4GB |
1.4 ExpertsInt8量化
为支持大模型高效推理,Jamba-1.5开发了ExpertsInt8量化技术2:
- 观测到85%+权重位于MoE层
- 将MoE和MLP权重量化至INT8
- 在融合核内即时反量化至BF16
- 优势:
- 无需校准(calibration-free)
- 延迟与FP8相当
- 支持A100 GPU(FP8仅H100可用)
1.5 激活值稳定性问题
训练过程中发现某些激活值会逐渐增大至 量级,可能导致FP16溢出。解决方案是引入激活损失(Activation Loss):
其中 可有效将激活值控制在2K-3K范围内。
2. Nemotron-H架构
2.1 NVIDIA的混合架构方案
Nemotron-H是NVIDIA提出的混合Mamba-Transformer模型系列,包含8B和56B两种规模3:
| 模型 | 层数 | Attention层 | Mamba-2层 | FFN层 |
|---|---|---|---|---|
| Nemotron-H-8B | 52 | 4 (8%) | 24 | 24 |
| Nemotron-H-56B | 118 | 10 (8%) | 54 | 54 |
2.2 关键设计原则
- 注意力层比例:约8%的总层数为自注意力层
- 均匀分散:注意力层在模型中均匀分布
- 首尾约束:
- 首层必须是Mamba-2层
- 最后一层必须是FFN层
- Attention层总是紧跟在FFN层之前
2.3 Mamba-2集成
Nemotron-H采用Mamba-2而非Mamba-1:
# Mamba-2配置
head_dim = 64
expansion_factor = 2
conv_kernel_size = 4
state_dim_8B = 128
state_dim_56B = 256为何Mamba-2-Attention组合:实验表明,在混合架构中Mamba-1-Attention表现优于Mamba-2-Attention,因为注意力层可以补充Mamba-2缺失的部分能力。
2.4 FP8训练配方
Nemotron-H-56B采用FP8训练3:
- per-tensor动态量化:整个张量使用单一缩放因子
- 首尾保留:首尾各4层保持BF16以保证稳定性
- 混合精度:E4M3用于权重/激活,E5M2用于梯度
- 性能:训练损失差距<0.1%
2.5 MiniPuzzle剪枝蒸馏
通过MiniPuzzle技术将56B模型压缩至47B:
- 使用63B训练tokens
- FP8训练
- 精度损失极小
- 可部署于单卡RTX 5090(32GiB)
3. 混合策略对比
3.1 层间混合 vs 层内混合
有两种主要的混合策略4:
层间混合 (Inter-layer):
┌────┐ ┌────┐ ┌────┐ ┌────┐
│Attn│ │Mamba│ │Mamba│ │Mamba│
└────┘ └────┘ └────┘ └────┘
层内混合 (Intra-layer):
┌────────────────────────────┐
│ Half Attn │ Half Mamba │
└────────────────────────────┘
| 策略 | 优势 | 劣势 |
|---|---|---|
| 层间混合 | 实现简单,灵活性高 | 需要更多Mamba层才能匹配 |
| 层内混合 | 更紧凑 | 实现复杂度高 |
3.2 注意力层分散策略
实验表明,均匀分散注意力层优于集中放置:
- 集中注意力:短上下文好,长上下文差
- 均匀分散:长短上下文均优
- 比例选择:8%是性能和效率的良好平衡点
4. 性能分析
4.1 推理效率
| 模型 | 上下文长度 | 相对速度 |
|---|---|---|
| Transformer | 128K | 1x |
| Jamba-1.5-Large | 256K | 3-10x (长上下文) |
| Nemotron-H-56B | 128K | 3x (65K输入) |
4.2 基准测试
| 模型 | MMLU-Pro | 相对速度提升 |
|---|---|---|
| LLaMA-3.1 70B | 53.0 | 1x |
| Mistral-Large-2 | 54.2 | 1x |
| Jamba-1.5-Large | 48.3 | 1.5-3x |
| Nemotron-H-56B | ~52 | 3x |
4.3 长上下文能力
在RULER基准上的256K有效长度表现2:
| 模型 | 256K得分 | 有效长度 |
|---|---|---|
| GPT-4-1106 | - | 64K |
| LLaMA-3.1 70B | - | 64K |
| Jamba-1.5-Large | 93.9 | 256K |
| Jamba-1.5-Mini | 86.1 | 256K |
5. 实践指南
5.1 何时选择混合架构
| 场景 | 推荐架构 |
|---|---|
| 长上下文应用(>32K) | Jamba/Nemotron-H |
| 资源受限部署 | 混合架构 + 量化 |
| 短上下文,高精度 | 纯Transformer |
| 需要快速推理 | 混合架构 |
5.2 训练注意事项
- 激活值监控:使用Activation Loss防止溢出
- 专家负载均衡:MoE路由需要辅助损失
- 数据混合:长文档数据对Mamba层有益
5.3 推理部署建议
# 使用vLLM部署Jamba
from vllm import LLM
model = LLM("ai21labs/Jamba-1.5-Mini")
# 自动利用ExpertsInt8量化6. 未来展望
- 更大规模:Jamba已验证398B规模的可行性
- 多模态扩展:Nemotron-H-VLM已展示视觉-语言能力
- 推理时缩放:推理时计算量的增加对混合架构更有利
- 硬件协同:定制化Kernel进一步提升效率
参考资料
Footnotes
-
Lieber O, et al. Jamba: A Hybrid Transformer-Mamba Language Model. arXiv:2403.19887, 2024. ↩
-
AI21 Labs. Jamba-1.5: Hybrid Transformer-Mamba Models at Scale. arXiv:2408.12570, 2024. ↩ ↩2 ↩3 ↩4 ↩5
-
NVIDIA. Nemotron-H: A Family of Accurate and Efficient Hybrid Mamba-Transformer Models. arXiv:2504.03624, 2025. ↩ ↩2 ↩3
-
Waleffe R, et al. A Study of Hybrid Transformer-Mamba Language Models. 2024. ↩