概述

xLSTM(Extended Long Short-Term Memory)是一种扩展版LSTM架构,通过引入指数门控和矩阵记忆来克服传统LSTM的局限,在语言建模任务上展现出与Transformer和SSM相当甚至更好的性能。12

核心创新:

  1. 指数门控(Exponential Gating):替代sigmoid门控
  2. 矩阵记忆(Matrix Memory):sLSTM标量 → mLSTM矩阵
  3. 7B参数模型:首次将LSTM扩展到十亿参数级别
  4. 并行化训练:mLSTM支持类似FlashAttention的高效实现

1. 传统LSTM的局限性

1.1 门控机制的局限

传统LSTM使用sigmoid门控

问题:

  • 表达能力受限:sigmoid输出范围 ,只能”压缩”信息
  • 梯度问题:连乘导致梯度消失或爆炸
  • 缺乏选择性:无法精确控制信息流

1.2 记忆结构的局限

传统LSTM的记忆是标量形式

问题:

  • 容量限制:标量记忆难以存储复杂模式
  • 信息干扰:新旧信息简单叠加

2. 指数门控机制

2.1 数学形式

xLSTM使用指数门控

关键洞察:指数函数可以将任意实数映射到正数域,实现放大或衰减

  • :增强信息
  • :抑制信息
  • :保持不变

2.2 归一化与稳定性

指数门控可能导致数值爆炸,需要归一化技术

class ExponentialGating(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        
    def forward(self, x, prev_h):
        # 门控参数
        gate = self.gate_proj(x)
        
        # 指数门控(带稳定性截断)
        gate_exp = torch.exp(torch.clamp(gate, min=-10, max=10))
        
        # 归一化(防止数值爆炸)
        gate_normalized = gate_exp / (1 + gate_exp)
        
        # 门控应用
        return gate_normalized * prev_h

2.3 与其他门控的对比

门控类型范围表达能力梯度特性
Sigmoid压缩平稳
Tanh压缩/增强平稳
指数放大/衰减指数

3. xLSTM变体

3.1 sLSTM:标量记忆

sLSTM保持标量记忆,增加内存混合(Memory Mixing)

class sLSTMCell(nn.Module):
    """
    sLSTM: 标量记忆 + 内存混合 + 指数门控
    """
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # 输入门控(指数)
        self.input_gate = nn.Linear(d_model, d_state)
        
        # 遗忘门控(指数)
        self.forget_gate = nn.Linear(d_model, d_state)
        
        # 输出门控
        self.output_gate = nn.Linear(d_model, d_state)
        
        # 候选记忆
        self.candidate = nn.Linear(d_model, d_state)
        
        # 内存混合参数(对角)
        self.A = nn.Parameter(torch.randn(d_state))
        
    def forward(self, x, h_prev, c_prev):
        # 指数门控
        i = torch.exp(torch.clamp(self.input_gate(x), -5, 5))
        f = torch.exp(torch.clamp(self.forget_gate(x), -5, 5))
        o = torch.sigmoid(self.output_gate(x))
        
        # 候选记忆
        g = torch.tanh(self.candidate(x))
        
        # 内存混合(对角线性变换)
        c_mixed = self.A * c_prev  # 对角混合
        
        # 更新记忆
        c_new = f * c_mixed + i * g
        
        # 隐藏状态
        h_new = o * torch.tanh(c_new)
        
        return h_new, c_new

特点

  • 支持内存混合,适合状态追踪任务
  • 无法完全并行化
  • 适合Parity等需要精确记忆的任务

3.2 mLSTM:矩阵记忆

mLSTM使用矩阵记忆完全可并行化

class mLSTMCell(nn.Module):
    """
    mLSTM: 矩阵记忆 + 外积更新 + 指数门控
    """
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # 输入投影(产生q, k, v)
        self.q_proj = nn.Linear(d_model, d_state)
        self.k_proj = nn.Linear(d_model, d_state)
        self.v_proj = nn.Linear(d_model, d_state)
        
        # 指数门控
        self.z_gate = nn.Parameter(torch.randn(d_model))
        
        # 遗忘门控(标量)
        self.f_gate = nn.Parameter(torch.randn(1))
        
    def forward(self, x, h_prev, C_prev):
        B, D = x.shape
        
        # 投影
        q = self.q_proj(x)  # [B, d_state]
        k = self.k_proj(x)  # [B, d_state]
        v = self.v_proj(x)  # [B, d_state]
        
        # 指数门控
        z = torch.exp(torch.clamp(self.z_gate, -5, 5))
        
        # 遗忘门控
        f = torch.exp(self.f_gate)
        
        # 外积更新(协方差规则)
        # ΔC = z · v · k^T
        C_new = f * C_prev + z * torch.outer(v, k)
        
        # 读取(注意力形式)
        s = C_new @ q  # [B, d_state]
        
        # 隐藏状态
        h_new = F.layer_norm(s, (self.d_state,))
        
        return h_new, C_new

特点

  • 矩阵记忆增强存储容量
  • 完全可并行化(类似FlashAttention)
  • 适合需要大规模记忆的任务

3.3 xLSTM架构组合

xLSTM Block可以包含sLSTM和mLSTM的组合:

class xLSTMBlock(nn.Module):
    def __init__(self, d_model, d_state, num_sLSTM=1, num_mLSTM=1):
        super().__init__()
        
        # Pre-LayerNorm
        self.norm = nn.LayerNorm(d_model)
        
        # sLSTM层
        self.sLSTMs = nn.ModuleList([
            sLSTMCell(d_model, d_state) 
            for _ in range(num_sLSTM)
        ])
        
        # mLSTM层
        self.mLSTMs = nn.ModuleList([
            mLSTMCell(d_model, d_state) 
            for _ in range(num_mLSTM)
        ])
        
        # 门控
        self.gate = nn.Parameter(torch.zeros(1))
        
    def forward(self, x, states=None):
        h = x
        
        # 残差连接
        h = h + self.gate * self.norm(h)
        
        # 应用sLSTM
        for slstm in self.sLSTMs:
            h = slstm(h)
        
        # 应用mLSTM
        for mlstm in self.mLSTMs:
            h = mlstm(h)
        
        return h

论文中的配置

  • xLSTM[1:0] = 纯mLSTM
  • xLSTM[0:1] = 纯sLSTM
  • xLSTM[1:1] = 1层sLSTM + 1层mLSTM

4. 7B模型实现

4.1 模型配置

xLSTM 7B的完整配置:

总参数量: 7.35B

架构配置:
├── 隐藏维度: 4096
├── 词嵌入维度: 4096
├── 中间维度: 14336 (FFN)
├── 层数: 48

每层配置:
├── xLSTM Block: 48个
│   ├── sLSTM层: 1
│   ├── mLSTM层: 1
│   └── FFN层: 1

词汇表:
├── 词表大小: 100,288
└── 位置编码: RoPE (YaRN)

4.2 CUDA核优化

xLSTM团队开发了优化的CUDA核:

# mLSTM的融合CUDA核示例
class mLSTMCUDA:
    @staticmethod
    def forward(q, k, v, z, f, C_prev):
        """
        融合的mLSTM前向传播
        
        步骤:
        1. 外积计算: ΔC = z * v @ k^T
        2. 遗忘: C = f * C_prev + ΔC
        3. 读取: s = C @ q
        """
        # 使用Triton实现的融合核
        return fused_mlstm_forward(q, k, v, z, f, C_prev)
    
    @staticmethod
    def backward(grad_h, q, k, v, z, f, C, s):
        """
        融合的mLSTM反向传播
        """
        return fused_mlstm_backward(grad_h, q, k, v, z, f, C, s)

4.3 JAX实现

xLSTM也提供了JAX/TPU优化实现:

# xLSTM-JAX中的并行化实现
def mlstm_parallel_scan(q, k, v, z, f):
    """
    mLSTM的并行扫描实现
    
    使用前缀和算法实现O(1)步时间复杂度
    """
    # 外积计算
    outer = jnp.einsum('bd,be->bde', v, k)  # [B, D, D]
    weighted_outer = z[:, None, None] * outer
    
    # 并行前缀和(类似FlashAttention)
    C = parallel_prefix_sum(f * C_prev + weighted_outer)
    
    # 读取
    s = jnp.einsum('bde,bd->be', C, q)
    
    return s

5. 任务性能分析

5.1 语言建模任务

模型Pile PPLWikiText-103Delta
Transformer8.915.1-
Mamba8.614.8-
RWKV-68.714.9-
xLSTM[1:0]8.414.5-0.5
xLSTM[1:1]8.314.4-0.7

5.2 状态追踪任务

任务TransformerxLSTM[0:1]xLSTM[1:1]
Multi-Query Associative Recall45%82%88%
Parity Task (100步)52%95%98%
Selective Copying67%91%93%

分析:sLSTM的内存混合机制显著提升了状态追踪能力。

5.3 推理效率

模型吞吐量(T=4K)吞吐量(T=16K)相对提升
Transformer1.0x1.0x-
Mamba2.3x3.8x优势明显
xLSTM[1:0]2.1x3.5x接近Mamba
xLSTM[1:1]1.8x2.9x中等提升

6. 与其他架构的对比

6.1 架构分类

类别代表模型时间复杂度状态追踪
TransformerGPT, Llama中等
SSMMamba中等
线性注意力GLA, RetNet中等
RNNxLSTM, RWKV

6.2 门控机制对比

模型门控类型记忆类型并行化
LSTMSigmoid标量困难
GRUSigmoid标量困难
xLSTM指数标量/矩阵mLSTM可并行
RWKVSigmoid标量困难
Mamba选择性向量可并行

7. 使用指南

7.1 HuggingFace Transformers使用

from transformers import xLSTMConfig, xLSTMModel
import torch
 
# 配置
config = xLSTMConfig(
    vocab_size=100288,
    hidden_size=4096,
    num_attention_heads=32,
    num_hidden_layers=48,
    intermediate_size=14336,
)
 
# 加载模型
model = xLSTMModel(config)
 
# 生成
input_ids = torch.randint(0, config.vocab_size, (1, 100))
outputs = model(input_ids)

7.2 自定义xLSTM层

class CustomxLSTMBlock(nn.Module):
    def __init__(self, d_model, d_state, ratio=[1, 1]):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        
        # 可配置的sLSTM/mLSTM比例
        self.xlstm = xLSTMLayer(
            d_model=d_model,
            d_state=d_state,
            num_sLSTM=ratio[0],
            num_mLSTM=ratio[1]
        )
        
    def forward(self, x):
        return x + self.xlstm(self.norm(x))

8. 总结

xLSTM代表了RNN架构的现代化复兴

  1. 指数门控:突破传统sigmoid的表达限制
  2. 矩阵记忆:增强存储容量,适合复杂模式
  3. 并行化训练:mLSTM达到与SSM相当的效率
  4. 7B规模:首次证明LSTM可扩展到十亿参数

xLSTM与Mamba、Transformer形成三足鼎立的局面,各有优劣:

架构语言建模状态追踪推理效率
Transformer★★★★★
Mamba/SSM★★★★★★★★
xLSTM★★★★★★★★

参考资料


相关文档:[[xlstm-extended-memory-lstm]、[lstm-to-ssm-state-space-duality]、[rwkv-model]、[state-space-model]]

Footnotes

  1. Beck, M. et al. (2024). xLSTM: Extended Long Short-Term Memory. NeurIPS 2024.

  2. Beck, M. et al. (2025). xLSTM 7B: A Recurrent LLM for Fast and Efficient Inference. ICLR 2025.