概述
SAMBA是由微软研究院提出的混合状态空间-注意力架构,创新性地将Mamba(选择性SSM)与滑动窗口注意力结合,在保持语言模型能力的同时实现线性时间复杂度和极低内存占用。1
核心设计理念:
- Mamba层:高效压缩历史信息到递归隐藏状态
- 滑动窗口注意力层:精确回忆近期记忆
- 交替堆叠:在效率和精确性间取得平衡
核心设计
架构概览
输入Token序列
↓
Embedding
↓
┌─────────────────────────────────────┐
│ SAMBA Block (重复L次) │
│ ┌───────────────────────────────┐ │
│ │ Mamba Layer │ │
│ │ - 选择性状态空间扫描 │ │
│ │ - 压缩历史到隐藏状态 │ │
│ │ - O(N) 时间复杂度 │ │
│ └───────────────────────────────┘ │
│ ↓ │
│ ┌───────────────────────────────┐ │
│ │ SwiGLU MLP │ │
│ │ - 门控线性单元 │ │
│ └───────────────────────────────┘ │
│ ↓ │
│ ┌───────────────────────────────┐ │
│ │ Sliding Window Attention │ │
│ │ - 固定窗口大小W │ │
│ │ - 精确回忆近期上下文 │ │
│ └───────────────────────────────┘ │
│ ↓ │
│ ┌───────────────────────────────┐ │
│ │ SwiGLU MLP │ │
│ └───────────────────────────────┘ │
└─────────────────────────────────────┘
↓
输出Logits
层配置
| 层类型 | 层数占比 | 作用 |
|---|---|---|
| Mamba | ~40% | 长期依赖压缩 |
| Sliding Window Attention | ~40% | 精确局部回忆 |
| MLP | ~20% | 非线性变换 |
消融实验:层配比
| 配置 | Mamba:SWA | MMLU | 平均困惑度 |
|---|---|---|---|
| 全部Mamba | 12:0 | 64.2 | 6.8 |
| 全部SWA | 0:12 | 66.1 | 6.4 |
| SAMBA | 6:6 | 67.8 | 5.9 |
SSM与注意力的互补性
状态空间模型的优势与局限
| 方面 | SSM (Mamba) | 注意力 |
|---|---|---|
| 时间复杂度 | O(N) | O(N²) |
| 空间复杂度 | O(N) (固定隐藏状态) | O(N²) (KV缓存) |
| 记忆压缩 | ✓ 高效压缩 | ✗ 完整存储 |
| 精确回忆 | ✗ 马尔可夫假设 | ✓ 任意位置精确访问 |
| 长度外推 | 中等 | 好 |
| 检索任务 | 较弱 | 强 |
SAMBA的解决思路
历史信息流:
Token序列: [t₁, t₂, ..., t₁₀₀₀, ..., t₁₀₀₀₀]
↓ ↓ ↓
Mamba 滑动窗口 全部历史
压缩存储 精确回忆 (隐藏状态)
Mamba的压缩能力:
- 将整个历史压缩到固定大小的隐藏状态
- 适合捕获长期语义依赖
- 适合时间序列预测
滑动窗口注意力的精确回忆:
- 精确检索最近W个token
- 捕获局部语法结构
- 处理需要精确匹配的任务
数学框架
Mamba层
选择性状态空间模型的前向传播:
其中选择性参数由输入动态生成:
滑动窗口注意力
其中掩码矩阵 当 (窗口大小),否则
融合机制
SAMBA通过交替堆叠融合两种机制:
实验结果
基准性能
| 模型 | 参数量 | MMLU | HumanEval | GSM8K | 平均 |
|---|---|---|---|---|---|
| Samba-3.8B | 3.8B | 67.8 | 76.2 | 94.1 | SOTA |
| Phi-3-mini | 3.8B | 66.4 | 73.4 | 89.8 | - |
| LLaMA-3-3B | 3.8B | 65.3 | 72.1 | 85.4 | - |
吞吐量对比
| 场景 | 基准 | SAMBA | 加速比 |
|---|---|---|---|
| 128K上下文推理 | Transformer | SAMBA | 3.73× |
| 64K生成长度 | Transformer | SAMBA | 3.64× |
| 1M上下文 | OOM | SAMBA | 可运行 |
长度外推能力
| 上下文长度 | 困惑度 | Passkey检索 |
|---|---|---|
| 训练长度 (4K) | 5.9 | 100% |
| 32K | 6.2 | 99.8% |
| 128K | 6.8 | 98.5% |
| 1M | 8.1 | 94.2% |
记忆任务分析
| 任务类型 | SSM单独 | SWA单独 | SAMBA |
|---|---|---|---|
| 长期依赖 (10K+) | ✓ 85% | ✗ 12% | ✓ 92% |
| 精确检索 (Phonebook) | ✗ 23% | ✓ 98% | ✓ 99% |
| 语法结构 | 中等 | ✓ 95% | ✓ 96% |
长上下文处理
KV缓存效率
传统Transformer的KV缓存在长上下文下成为瓶颈:
| 模型 | 128K上下文KV缓存 | 内存占用 |
|---|---|---|
| LLaMA-3-8B | 8×128K×d_k×layers | ~16GB |
| Samba-3.8B | ~4K×d_k×layers + SSM状态 | ~0.5GB |
推理效率
# SAMBA推理伪代码
class SambaLM:
def __init__(self, config):
self.mamba_layers = [MambaLayer() for _ in range(config.n_mamba)]
self.swa_layers = [SWALayer() for _ in range(config.n_swa)]
def forward(self, x):
for mamba, swa in zip(self.mamba_layers, self.swa_layers):
x = mamba(x) # O(N) 压缩历史
x = mlp(x) # 非线性
x = swa(x) # O(W) 精确回忆
x = mlp(x)
return x
def generate(self, prompt, max_len):
# 预填充阶段
cache = self.forward(prompt)
# 解码阶段 - 增量计算
for _ in range(max_len):
# Mamba: 常数时间状态更新
cache = cache.mamba_step()
# SWA: 只关注窗口内
cache = cache.swa_step()
yield cache与其他混合架构对比
| 架构 | 混合方式 | SSM层 | 注意力层 | 特点 |
|---|---|---|---|---|
| Jamba | 层间 | Mamba | Full Attn | 交替堆叠 |
| Mamba-2 | SSM-Attention融合 | SSD | Partial | 对偶性 |
| SAMBA | 层间+窗口 | Mamba | SWA | 互补优势 |
设计哲学对比
| 架构 | 解决什么问题 | 方法 |
|---|---|---|
| Jamba | 长上下文效率 | 增加SSM减少Attn |
| Mamba-2 | SSM-Attn统一 | 数学对偶性 |
| SAMBA | 精确回忆 | SWA补充SSM |
实现细节
模型配置
| 参数 | Samba-0.4B | Samba-1.8B | Samba-3.8B |
|---|---|---|---|
| 隐藏维度 | 1024 | 2048 | 3072 |
| SSM状态维度 | 16 | 16 | 16 |
| 层数 | 24 | 36 | 48 |
| 注意力头 | 16 | 24 | 24 |
| 滑动窗口 | 512 | 1024 | 2048 |
| 训练数据 | 0.5T | 1.2T | 3.2T |
训练配置
# SAMBA训练超参数
config = {
"optimizer": "AdamW",
"learning_rate": 3e-4,
"weight_decay": 0.1,
"beta": (0.9, 0.95),
"warmup_steps": 2000,
"context_length": 4096,
"batch_size": 4, # per device
"gradient_accumulation": 4,
"max_seq_len": 4096,
"activation": "swish", # SiLU
"norm": "RMSNorm",
}总结
SAMBA的核心贡献:
- 创新的混合架构,结合Mamba和滑动窗口注意力的互补优势
- 突破性的效率提升,128K上下文下3.73×加速
- 优秀的长度外推,零样本扩展至1M token
- 精确回忆机制,通过SWA补充SSM的马尔可夫局限
- 统一的语言模型能力,在多项基准上达到SOTA
设计启示
SAMBA的成功表明:
- 互补设计比单一架构更有效
- 压缩与精确可以共存
- 滑动窗口是处理精确回忆的轻量级方案
参考文献
相关主题
Footnotes
-
Wang, P., et al. (2024). SAMBA: Simple Stateful State Space Model for Efficient Language Modeling. arXiv:2406.07522. https://arxiv.org/abs/2406.07522 ↩