MemMamba:状态空间模型中的记忆模式重思考

概述

MemMamba是针对Mamba等选择性状态空间模型(Selective State Space Models)的记忆机制进行深入分析的论文,提出**选择性状态传递(Selective State Transmission)遗忘门控(Gating-for-Oblivion)**机制来改进长程依赖建模。1

核心论文:arXiv:2510.032791
研究机构:中国人民大学、上海财经大学、上海人工智能实验室

核心洞察

  • 分析了Mamba中”记忆丢失”的根本原因
  • 发现现有SSM在状态压缩时存在信息选择性不足
  • 提出记忆增强机制,在不增加计算复杂度的情况下改善长程依赖

1. 背景:SSM的记忆问题

1.1 状态空间模型回顾

标准SSM(Mamba)的核心方程:

其中:

  • :隐藏状态(记忆)
  • :离散化后的状态矩阵和输入矩阵
  • :输出投影

1.2 记忆容量问题

核心问题:隐藏状态 的维度 通常远小于输入序列长度

时,状态空间必须压缩输入信息,导致:

  1. 记忆冲突:不同时间步的信息被压缩到相同状态
  2. 选择性缺失:无法有效区分重要/不重要信息
  3. 灾难性遗忘:新信息覆盖旧信息
# 记忆容量问题示意
class SSMMemoryAnalysis:
    """
    分析SSM的记忆容量限制
    """
    def __init__(self, state_dim, seq_len):
        self.d_state = state_dim
        self.seq_len = seq_len
        
    def memory_requirement(self, info_bits_per_token):
        """
        计算存储所有信息所需的最小状态维度
        
        对于长度为L的序列,每个token携带I位信息:
        总信息量 = L × I bits
        
        状态容量 = d_state × log(V) bits
        其中V是状态值的取值范围
        """
        total_info = self.seq_len * info_bits_per_token
        return total_info
    
    def compression_ratio(self, info_bits=10):
        """
        计算压缩比
        
        当序列很长时:
        - 需要的记忆容量 >> 实际状态维度
        - 必须进行选择性压缩
        """
        required = self.memory_requirement(info_bits)
        capacity = self.d_state * 16  # 假设float16
        return capacity / required

1.3 现有解决方案的局限

方法思路局限
增大 增加记忆容量计算/内存增加
外部记忆KV Cache等需要额外存储
层次SSM多尺度记忆复杂、难以优化

2. MemMamba核心分析

2.1 遗忘动态分析

MemMamba首先分析了Mamba中遗忘的动态过程

定理(遗忘率)

,则状态更新为:

对于第 步的输入 ,其对当前状态 的贡献为:

遗忘因子 决定了过去信息的衰减速度。

2.2 选择性不足问题

问题:Mamba的选择性主要由输入依赖的 控制:

但这种选择性是标量级别的(每个通道共享),无法区分:

  • 长期依赖 vs 短期依赖
  • 重要信息 vs 噪声
  • 不同语义角色的token
# Mamba选择性的局限
class MambaSelectivityLimit:
    """
    演示Mamba中选择性不足的问题
    """
    def analyze_scalar_selection(self, delta_t):
        """
        Mamba的选择性来自delta_t(标量)
        
        问题:
        1. 同一通道内,所有信息以相同速度衰减
        2. 无法根据语义重要性调整
        3. 长序列时,所有token竞争相同状态空间
        """
        # delta_t 控制全局衰减速度
        # 但同一个状态维度无法区分:
        # - 长期依赖(需要保留)
        # - 短期依赖(需要快速衰减)
        pass

2.3 记忆冲突检测

MemMamba定义了记忆冲突的概念:

定义(记忆冲突)

对于两个语义不同的token ,如果它们被映射到相似的状态表示,则发生记忆冲突:

其中 为余弦相似度, 为阈值。

实验观察

  • 当序列长度超过状态维度的10倍时,冲突率显著上升
  • 冲突主要集中在语法相似但语义不同的token

3. 核心方法:选择性状态传递

3.1 记忆增强架构

MemMamba的核心改进是在SSM中引入选择性状态传递机制:

原始Mamba:
x → [Input Projection] → SSM Scan → Output

MemMamba:
x → [Input Projection] → SSM Scan → [Memory Bank] → Output
                              ↓
                        选择性状态传递

3.2 记忆银行(Memory Bank)

设计:维护一个固定大小的记忆银行

其中 是记忆槽位数。

选择机制

对于每个时间步 ,计算与记忆银行的关联度

更新规则

其中 是更新率, 是选择阈值。

3.3 遗忘门控

动机:解决”重要信息被遗忘”的问题。

遗忘门控机制

其中:

  • :遗忘门控向量(逐元素)
  • :最相关的记忆槽
  • :逐元素乘法

门控学习

class MemMambaBlock(nn.Module):
    """MemMamba块实现"""
    def __init__(self, d_model, d_state=16, n_memory=8):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.n_memory = n_memory
        
        # 标准Mamba组件
        self.x_proj = nn.Linear(d_model, d_state * 2 + 1, bias=False)
        self.A_log = nn.Parameter(torch.randn(d_model, d_state))
        self.D = nn.Parameter(torch.ones(d_model))
        
        # 记忆银行
        self.memory_bank = nn.Parameter(torch.randn(n_memory, d_model))
        
        # 遗忘门控
        self.gate_proj = nn.Linear(d_model * 2, d_model)
        
        # 记忆选择器
        self.memory_selector = nn.Linear(d_model, n_memory)
        
    def forward(self, x):
        B, L, D = x.shape
        
        # 标准Mamba前向
        h = self._ssm_forward(x)
        
        # 记忆选择
        memory_scores = self.memory_selector(h)  # (B, L, n_memory)
        memory_weights = F.softmax(memory_scores, dim=-1)
        
        # 获取最相关记忆
        top_memory_idx = memory_weights.argmax(dim=-1)  # (B, L)
        top_memory = self.memory_bank[top_memory_idx]  # (B, L, D)
        
        # 遗忘门控
        gate_input = torch.cat([h, top_memory], dim=-1)
        gate = torch.sigmoid(self.gate_proj(gate_input))
        
        # 选择性状态传递
        h_enhanced = (1 - gate) * h + gate * top_memory
        
        # 记忆更新(简化版)
        # 实际实现中需要更复杂的选择和更新逻辑
        
        return h_enhanced

3.4 复杂度分析

组件时间复杂度空间复杂度
原始Mamba
MemMamba

其中 是记忆银行大小(通常 )。


4. 遗忘门控的数学分析

4.1 信息保留保证

定理(遗忘门控的信息保留)

表示随机变量 之间的互信息。则:

其中 是平均遗忘率。

意义:遗忘门控不会完全丢失过去的信息,而是通过记忆银行进行补充。

4.2 长期依赖改善

命题

对于长期依赖 ,其有效路径长度为:

这意味着通过选择性状态传递,可以有效延长依赖追踪的”有效长度”。


5. 实验结果

5.1 长序列基准测试

Passkey Retrieval任务(需要长程依赖):

模型长度=1K长度=4K长度=16K
Mamba98.2%87.3%54.1%
Mamba + KV Cache99.1%96.8%93.2%
MemMamba98.9%95.4%89.7%

5.2 记忆冲突分析

模型冲突率@4K冲突率@16K
Mamba23.4%61.2%
MemMamba12.1%38.7%

6. 完整PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class MemoryBank:
    """记忆银行实现"""
    def __init__(self, n_slots, d_model, memory_init='random'):
        self.n_slots = n_slots
        self.d_model = d_model
        
        if memory_init == 'zeros':
            self.memory = nn.Parameter(torch.zeros(n_slots, d_model))
        else:
            self.memory = nn.Parameter(torch.randn(n_slots, d_model) * 0.02)
        
        # 记忆重要性分数
        self.importance = nn.Parameter(torch.zeros(n_slots))
        
    def select(self, query, temperature=1.0):
        """根据查询选择最相关的记忆"""
        # query: (B, L, D)
        # 返回: (B, L, D) - 最相关的记忆
        scores = torch.matmul(query, self.memory.T) / math.sqrt(self.d_model)
        weights = F.softmax(scores / temperature, dim=-1)  # (B, L, n_slots)
        
        # 加权求和获取最相关记忆
        selected_memory = torch.matmul(weights, self.memory)  # (B, L, D)
        
        return selected_memory, weights
    
    def update(self, hidden_states, weights, lr=0.01):
        """更新记忆银行"""
        # 选择性更新:只更新高权重的记忆槽
        with torch.no_grad():
            # 计算每个记忆槽的平均激活权重
            slot_importance = weights.mean(dim=(0, 1))  # (n_slots,)
            
            # 只更新重要性超过阈值的槽
            update_mask = (slot_importance > 0.1).float().unsqueeze(0).unsqueeze(0)
            
            # 计算更新量
            update = lr * torch.matmul(weights.mean(dim=1), hidden_states)
            
            # 应用更新
            self.memory.data = (1 - update_mask) * self.memory.data + update_mask * update
 
 
class ForgetGate(nn.Module):
    """遗忘门控实现"""
    def __init__(self, d_model):
        super().__init__()
        self.gate_net = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.Sigmoid()
        )
        
    def forward(self, h, memory):
        """
        计算遗忘门控
        
        h: 当前隐藏状态 (B, L, D)
        memory: 选中的记忆 (B, L, D)
        """
        combined = torch.cat([h, memory], dim=-1)
        gate = self.gate_net(combined)  # (B, L, D)
        return gate
 
 
class MemMambaLayer(nn.Module):
    """MemMamba层完整实现"""
    def __init__(self, d_model, d_state=16, n_heads=4, n_memory=8, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.n_heads = n_heads
        
        # 输入投影
        self.x_proj = nn.Linear(d_model, d_state * 2 + d_state, bias=False)
        
        # SSM参数
        self.A_log = nn.Parameter(torch.randn(d_model, d_state))
        self.D = nn.Parameter(torch.ones(d_model))
        
        # 记忆组件
        self.memory_bank = MemoryBank(n_memory, d_model)
        self.forget_gate = ForgetGate(d_model)
        
        # 输出投影
        self.out_proj = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def ssm_forward(self, x):
        """标准Mamba SSM前向"""
        B, L, D = x.shape
        
        # 输入投影
        x_proj = self.x_proj(x)
        dt, B_state, C_state = x_proj.split([D, self.d_state, self.d_state], dim=-1)
        
        # dt投影
        dt = F.softplus(dt)
        
        # 离散化
        A = -torch.exp(self.A_log)
        dA = torch.exp(dt.unsqueeze(-1) * A)  # (B, L, D, d_state)
        dB = dt.unsqueeze(-1) * B_state.unsqueeze(-1)  # (B, L, d_state, d_state)
        
        # 简化的并行扫描
        h = torch.zeros(B, D, self.d_state, device=x.device)
        outputs = []
        
        for i in range(L):
            h = dA[:, i] * h + dB[:, i] * x[:, i:i+1].unsqueeze(-1)
            y = torch.einsum('bdn,bn->bd', h, C_state[:, i])
            outputs.append(y)
        
        return torch.stack(outputs, dim=1)
    
    def forward(self, x):
        # SSM前向
        h = self.ssm_forward(x)
        
        # 记忆选择
        memory, weights = self.memory_bank.select(h)
        
        # 遗忘门控
        gate = self.forget_gate(h, memory)
        
        # 选择性状态传递
        h_enhanced = (1 - gate) * h + gate * memory
        
        # 输出投影
        out = self.out_proj(h_enhanced)
        out = self.norm(out + x)  # 残差连接
        
        return self.dropout(out)
 
 
class MemMambaModel(nn.Module):
    """完整的MemMamba模型"""
    def __init__(self, vocab_size, d_model, n_layers, d_state=16, n_memory=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            MemMambaLayer(d_model, d_state, n_memory=n_memory)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        
    def forward(self, input_ids):
        x = self.embedding(input_ids)
        
        for layer in self.layers:
            x = layer(x)
        
        x = self.norm(x)
        return self.lm_head(x)

7. 总结

核心贡献

  1. 系统分析:深入分析了SSM中记忆冲突和选择性不足的问题
  2. 记忆银行:引入外部记忆机制缓解状态容量限制
  3. 遗忘门控:选择性决定保留当前状态还是从记忆中补充

关键洞察

状态空间模型的”选择性”不应仅来自输入依赖的衰减率,而应包括显式的记忆选择和补充机制。

局限性

  1. 记忆银行需要额外的内存
  2. 选择性机制引入的计算开销
  3. 记忆更新的策略需要进一步优化

参考资料

Footnotes

  1. Wang et al. (2025). MemMamba: Rethinking Memory Patterns in State Space Model. arXiv:2510.03279. 2