状态空间模型与Mamba

状态空间模型(State Space Model, SSM)是一类用于建模时序数据的经典方法。2023年,Gu和Dao提出了Mamba,通过选择性机制将SSM扩展为可以与Transformer竞争的高效序列建模架构。

SSM基础理论

连续时间状态空间模型

SSM将输入信号 映射到隐状态 和输出

其中:

  • :状态转移矩阵
  • :输入矩阵
  • :输出矩阵
  • :直接馈通(通常设为0)

离散化

计算机处理离散数据,需要将连续模型离散化。

Zero-Order Hold (ZOH)

离散化参数:

其中 是采样间隔。

S4模型的贡献

S4(Structured State Space Sequence Model)解决了离散SSM的两个关键问题:

  1. 高效计算:通过矩阵分解实现 线性复杂度
  2. 长程依赖:使用HiPPO矩阵初始化保证信息保留

Mamba:选择性状态空间模型

核心洞察

传统SSM(如S4)的参数是输入无关的(input-independent):

Mamba的关键创新是让参数变得输入依赖(input-dependent):

其中 是根据输入 计算得到的。

选择性扫描算法

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
 
class MambaBlock(nn.Module):
    """
    Mamba块的核心实现
    """
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.d_inner = int(expand * d_model)
        
        # 投影输入
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        
        # 卷积层
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
        )
        
        # SSM参数投影
        # x -> (B, dt, C, D)
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        
        # dt投影
        self.dt_proj = nn.Linear(1, self.d_inner)
        
        # A矩阵(初始化为HiPPO风格)
        self.A_log = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float).unsqueeze(0).repeat(self.d_inner, 1))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        
        # 输出投影
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
    
    def selective_scan(self, x, dt, A, B, C, D):
        """
        选择性扫描算法
        x: (batch, seq_len, d_inner)
        dt: (batch, seq_len, d_inner)
        A: (d_inner, d_state)
        B: (batch, seq_len, d_state)
        C: (batch, seq_len, d_state)
        """
        batch, seq_len, d_inner = x.shape
        d_state = A.shape[1]
        
        # 离散化A, B
        # dt = softplus(dt) 确保正数
        dt = F.softplus(dt)
        dA = torch.exp(dt.unsqueeze(-1) * A)  # (batch, seq, d_inner, d_state)
        dB = dt.unsqueeze(-1) * B.unsqueeze(2)  # (batch, seq, d_inner, d_state)
        
        # 扫描
        h = torch.zeros(batch, d_inner, d_state, device=x.device, dtype=x.dtype)
        ys = []
        
        for i in range(seq_len):
            h = dA[:, i] * h + dB[:, i] * x[:, i].unsqueeze(-1)
            y = torch.einsum('bdn,bn->bd', h, C[:, i])
            ys.append(y)
        
        y = torch.stack(ys, dim=1)  # (batch, seq, d_inner)
        y = y + D * x
        
        return y
    
    def forward(self, x):
        # x: (batch, seq_len, d_model)
        batch, seq_len, d_model = x.shape
        
        # 输入投影并分割
        xz = self.in_proj(x)  # (batch, seq, d_inner * 2)
        x_inner, z = xz.chunk(2, dim=-1)  # 各 (batch, seq, d_inner)
        
        # 卷积
        x_conv = self.conv1d(x_inner.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_conv = F.silu(x_conv)
        
        # 计算SSM参数
        x_proj_out = self.x_proj(x_conv)  # (batch, seq, d_state * 2 + 1)
        dt, B, C = x_proj_out.split([1, self.d_state, self.d_state], dim=-1)
        
        # dt归一化
        dt = F.softplus(self.dt_proj(dt))  # (batch, seq, d_inner)
        
        # 选择性扫描
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        y = self.selective_scan(x_conv, dt, A, B, C, self.D)
        
        # 门控
        y = y * F.silu(z)
        
        # 输出投影
        return self.out_proj(y)

选择性机制详解

为什么选择性重要?

考虑序列选择任务:

  • 信息压缩:选择性地将信息从输入压缩到状态
  • RNN行为:隐状态可以无限期地记住重要信息
  • 线性复杂度:选择性的同时保持 计算
# 选择性机制的效果示意
def selective_vs_fixed():
    """
    固定SSM vs 选择性SSM
    """
    # 固定SSM:所有输入同等对待
    # 状态更新:x_{k+1} = Ax_k + Bu_k
    # 参数A, B, C对所有输入相同
    
    # 选择性SSM:根据输入决定如何更新状态
    # x_{k+1} = A_k x_k + B_k u_k
    # 参数A_k, B_k, C_k依赖于输入
    
    """
    例如:处理 "The movie was [good/bad] and I [loved/hated] it"
    
    固定SSM:必须均匀地传播所有信息
    选择性SSM:可以专注于关键词,忽略无关词
    """

Mamba vs Transformer

复杂度对比

特性TransformerMamba
注意力复杂度
推理速度慢(KV cache大)快(5倍+)
序列长度受限于显存可处理更长序列
并行训练高效高效

选择性机制的优势

  1. 内容感知:根据输入内容选择性地处理信息
  2. 遗忘机制:不需要显式的遗忘门
  3. 推理效率:RNN式生成,无KV cache
# Transformer解码(自回归)
class TransformerDecoder:
    def forward(self, prefix_tokens):
        # 需要保存所有KV缓存
        for i in range(generate_len):
            # 每步需要注意力计算
            output = self.attention(q=current_token, 
                                   k=torch.cat([kv_cache]), 
                                   v=torch.cat([kv_cache]))
            kv_cache.append(output)
        # 显存使用: O(seq_len)
 
# Mamba解码
class MambaDecoder:
    def forward(self, prefix_tokens):
        # 隐状态压缩信息
        h = torch.zeros(batch, d_state)
        outputs = []
        for token in prefix_tokens:
            output, h = self.ssm(token, h)
            outputs.append(output)
        # 显存使用: O(d_state),恒定

基准性能

任务TransformerMamba
LRA0.670.70
ImageNet91.0%91.5%
WikiText-10320.5 PPL19.3 PPL
Pile (1B)15.1 PPL14.8 PPL

Mamba-2:SSM与Transformer的统一

状态空间对偶性(SSD)

Mamba-2提出了SSM与注意力机制的理论联系:

核心发现:

  • 注意力可以看作SSM的特殊形式
  • 可以共享参数和计算

Mamba-2架构

class Mamba2Block(nn.Module):
    """
    Mamba-2块
    - 简化的选择机制
    - 更好的并行化
    """
    def __init__(self, d_model, d_state=128, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = expand * d_model
        
        # 输入投影
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        
        # SSM参数
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        
        # 简化的A矩阵(对角化)
        A = torch.arange(1, d_state + 1, dtype=torch.float)
        self.A = nn.Parameter(A.unsqueeze(0).repeat(self.d_inner, 1))
        
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
    
    def forward(self, x):
        batch, seq_len, _ = x.shape
        
        xz = self.in_proj(x)
        x_inner, z = xz.chunk(2, dim=-1)
        
        # 参数投影
        x_proj_out = self.x_proj(F.silu(x_inner))
        dt, B, C = x_proj_out.split([1, self.d_state, self.d_state], dim=-1)
        dt = F.softplus(dt)
        
        # SSD扫描(并行化版本)
        y = self.ssd_scan(x_inner, dt, self.A, B, C, self.D)
        
        y = y * F.silu(z)
        return self.out_proj(y)

使用Mamba

安装

pip install mamba-ssm

预训练模型使用

from mamba_ssm import MambaLMHeadModel
from transformers import AutoTokenizer
 
# 加载模型
model = MambaLMHeadModel.from_pretrained(
    'state-spaces/mamba-1.4b',
    device='cuda',
    dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
 
# 生成
input_ids = tokenizer("The future of AI is", return_tensors="pt")['input_ids'].to('cuda')
 
generated_ids = model.generate(
    input_ids,
    max_length=100,
    temperature=0.9,
    top_p=0.9,
    eos_token_id=tokenizer.eos_token_id
)
 
print(tokenizer.decode(generated_ids[0]))

自定义Mamba层

class MambaLayer(nn.Module):
    def __init__(self, d_model, d_state=16, expand=2):
        super().__init__()
        self.mamba = MambaBlock(d_model, d_state, expand=expand)
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x):
        # Pre-norm架构
        return x + self.mamba(self.norm(x))
 
class MambaClassifier(nn.Module):
    def __init__(self, vocab_size, d_model, num_classes, n_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            MambaLayer(d_model) for _ in range(n_layers)
        ])
        self.classifier = nn.Linear(d_model, num_classes)
    
    def forward(self, x):
        embedded = self.embedding(x)
        for layer in self.layers:
            embedded = layer(embedded)
        
        # 使用最后位置的表示
        return self.classifier(embedded[:, -1])

应用场景

长文本处理

Mamba的 复杂度使其适合处理超长文档:

  • 书籍摘要
  • 视频帧序列
  • 蛋白质序列

时间序列预测

class TimeSeriesMamba(nn.Module):
    def __init__(self, n_features, d_model, n_layers, d_state=16):
        super().__init__()
        self.proj = nn.Linear(n_features, d_model)
        self.layers = nn.ModuleList([
            MambaBlock(d_model, d_state) for _ in range(n_layers)
        ])
        self.fc = nn.Linear(d_model, n_features)
    
    def forward(self, x):
        # x: (batch, seq_len, n_features)
        x = self.proj(x)
        for layer in self.layers:
            x = layer(x)
        return self.fc(x)

音频处理

Mamba在音频建模中也表现出色:

  • 语音识别
  • 音乐生成
  • 音频压缩

参考