概述

SAMBA是由微软研究院提出的混合状态空间-注意力架构,创新性地将Mamba(选择性SSM)与滑动窗口注意力结合,在保持语言模型能力的同时实现线性时间复杂度和极低内存占用1

核心设计理念:

  • Mamba层:高效压缩历史信息到递归隐藏状态
  • 滑动窗口注意力层:精确回忆近期记忆
  • 交替堆叠:在效率和精确性间取得平衡

核心设计

架构概览

输入Token序列
    ↓
Embedding
    ↓
┌─────────────────────────────────────┐
│  SAMBA Block (重复L次)              │
│  ┌───────────────────────────────┐  │
│  │  Mamba Layer                  │  │
│  │  - 选择性状态空间扫描         │  │
│  │  - 压缩历史到隐藏状态         │  │
│  │  - O(N) 时间复杂度            │  │
│  └───────────────────────────────┘  │
│              ↓                      │
│  ┌───────────────────────────────┐  │
│  │  SwiGLU MLP                  │  │
│  │  - 门控线性单元               │  │
│  └───────────────────────────────┘  │
│              ↓                      │
│  ┌───────────────────────────────┐  │
│  │  Sliding Window Attention    │  │
│  │  - 固定窗口大小W              │  │
│  │  - 精确回忆近期上下文         │  │
│  └───────────────────────────────┘  │
│              ↓                      │
│  ┌───────────────────────────────┐  │
│  │  SwiGLU MLP                  │  │
│  └───────────────────────────────┘  │
└─────────────────────────────────────┘
    ↓
输出Logits

层配置

层类型层数占比作用
Mamba~40%长期依赖压缩
Sliding Window Attention~40%精确局部回忆
MLP~20%非线性变换

消融实验:层配比

配置Mamba:SWAMMLU平均困惑度
全部Mamba12:064.26.8
全部SWA0:1266.16.4
SAMBA6:667.85.9

SSM与注意力的互补性

状态空间模型的优势与局限

方面SSM (Mamba)注意力
时间复杂度O(N)O(N²)
空间复杂度O(N) (固定隐藏状态)O(N²) (KV缓存)
记忆压缩✓ 高效压缩✗ 完整存储
精确回忆✗ 马尔可夫假设✓ 任意位置精确访问
长度外推中等
检索任务较弱

SAMBA的解决思路

历史信息流:
    
Token序列: [t₁, t₂, ..., t₁₀₀₀, ..., t₁₀₀₀₀]
              ↓            ↓           ↓
            Mamba      滑动窗口    全部历史
           压缩存储    精确回忆     (隐藏状态)

Mamba的压缩能力

  • 将整个历史压缩到固定大小的隐藏状态
  • 适合捕获长期语义依赖
  • 适合时间序列预测

滑动窗口注意力的精确回忆

  • 精确检索最近W个token
  • 捕获局部语法结构
  • 处理需要精确匹配的任务

数学框架

Mamba层

选择性状态空间模型的前向传播:

其中选择性参数由输入动态生成:

滑动窗口注意力

其中掩码矩阵 (窗口大小),否则

融合机制

SAMBA通过交替堆叠融合两种机制:


实验结果

基准性能

模型参数量MMLUHumanEvalGSM8K平均
Samba-3.8B3.8B67.876.294.1SOTA
Phi-3-mini3.8B66.473.489.8-
LLaMA-3-3B3.8B65.372.185.4-

吞吐量对比

场景基准SAMBA加速比
128K上下文推理TransformerSAMBA3.73×
64K生成长度TransformerSAMBA3.64×
1M上下文OOMSAMBA可运行

长度外推能力

上下文长度困惑度Passkey检索
训练长度 (4K)5.9100%
32K6.299.8%
128K6.898.5%
1M8.194.2%

记忆任务分析

任务类型SSM单独SWA单独SAMBA
长期依赖 (10K+)✓ 85%✗ 12%92%
精确检索 (Phonebook)✗ 23%✓ 98%99%
语法结构中等✓ 95%96%

长上下文处理

KV缓存效率

传统Transformer的KV缓存在长上下文下成为瓶颈:

模型128K上下文KV缓存内存占用
LLaMA-3-8B8×128K×d_k×layers~16GB
Samba-3.8B~4K×d_k×layers + SSM状态~0.5GB

推理效率

# SAMBA推理伪代码
class SambaLM:
    def __init__(self, config):
        self.mamba_layers = [MambaLayer() for _ in range(config.n_mamba)]
        self.swa_layers = [SWALayer() for _ in range(config.n_swa)]
    
    def forward(self, x):
        for mamba, swa in zip(self.mamba_layers, self.swa_layers):
            x = mamba(x)      # O(N) 压缩历史
            x = mlp(x)        # 非线性
            x = swa(x)        # O(W) 精确回忆
            x = mlp(x)
        return x
    
    def generate(self, prompt, max_len):
        # 预填充阶段
        cache = self.forward(prompt)
        
        # 解码阶段 - 增量计算
        for _ in range(max_len):
            # Mamba: 常数时间状态更新
            cache = cache.mamba_step()
            # SWA: 只关注窗口内
            cache = cache.swa_step()
            yield cache

与其他混合架构对比

架构混合方式SSM层注意力层特点
Jamba层间MambaFull Attn交替堆叠
Mamba-2SSM-Attention融合SSDPartial对偶性
SAMBA层间+窗口MambaSWA互补优势

设计哲学对比

架构解决什么问题方法
Jamba长上下文效率增加SSM减少Attn
Mamba-2SSM-Attn统一数学对偶性
SAMBA精确回忆SWA补充SSM

实现细节

模型配置

参数Samba-0.4BSamba-1.8BSamba-3.8B
隐藏维度102420483072
SSM状态维度161616
层数243648
注意力头162424
滑动窗口51210242048
训练数据0.5T1.2T3.2T

训练配置

# SAMBA训练超参数
config = {
    "optimizer": "AdamW",
    "learning_rate": 3e-4,
    "weight_decay": 0.1,
    "beta": (0.9, 0.95),
    "warmup_steps": 2000,
    "context_length": 4096,
    "batch_size": 4,  # per device
    "gradient_accumulation": 4,
    "max_seq_len": 4096,
    "activation": "swish",  # SiLU
    "norm": "RMSNorm",
}

总结

SAMBA的核心贡献:

  1. 创新的混合架构,结合Mamba和滑动窗口注意力的互补优势
  2. 突破性的效率提升,128K上下文下3.73×加速
  3. 优秀的长度外推,零样本扩展至1M token
  4. 精确回忆机制,通过SWA补充SSM的马尔可夫局限
  5. 统一的语言模型能力,在多项基准上达到SOTA

设计启示

SAMBA的成功表明:

  • 互补设计比单一架构更有效
  • 压缩与精确可以共存
  • 滑动窗口是处理精确回忆的轻量级方案

参考文献


相关主题

Footnotes

  1. Wang, P., et al. (2024). SAMBA: Simple Stateful State Space Model for Efficient Language Modeling. arXiv:2406.07522. https://arxiv.org/abs/2406.07522