概述

尽管 Transformer 在大语言模型领域占据主导地位,LSTM 的研究在 2024-2025 年迎来了重要突破。xLSTM 证明了 LSTM 可以扩展到数十亿参数,τ-GRU 引入时滞反馈机制,而 QL-LSTM 实现了参数量的显著压缩。1

这些进展不仅提升了 LSTM 本身的性能,更重要的是揭示了 LSTM 与现代状态空间模型(SSM)之间的深刻联系。


xLSTM:扩展型长短期记忆网络

背景与动机

标准 LSTM 存在两个主要瓶颈:

  1. 可并行性差:LSTM 的门控计算具有顺序依赖性,无法像 Transformer 那样高效并行
  2. 表达能力受限:传统门控机制在处理长序列时仍有局限

核心创新

xLSTM 由 Beck 等人于 2024 年提出,引入两项关键创新:2

1. 指数门控(Exponential Gating)

标准 LSTM 使用 sigmoid 门控,输出范围为 。xLSTM 引入指数门控

其中 是可学习的参数。

优势

  • 遗忘门可以取 的值,实现选择性增强而非仅选择性遗忘
  • 更好地平衡记忆与遗忘

2. 新的内存结构

xLSTM 引入两种新的内存单元:

sLSTM(Scalar LSTM)

  • 每个时间步维护一个标量记忆
  • 通过记忆混合(Memory Mixing) 连接相邻时间步
  • 完全可并行化的门控机制
class sLSTMCell(nn.Module):
    """Scalar LSTM Cell"""
    def __init__(self, d_model):
        super().__init__()
        # 输入门(指数)
        self.alpha_i = nn.Parameter(torch.zeros(d_model))
        self.beta_i = nn.Parameter(torch.zeros(d_model))
        
        # 遗忘门(指数)
        self.alpha_f = nn.Parameter(torch.zeros(d_model))
        self.beta_f = nn.Parameter(torch.zeros(d_model))
        
        # 输出门
        self.gamma_o = nn.Parameter(torch.ones(d_model))
        self.delta_o = nn.Parameter(torch.zeros(d_model))
        
        # 候选值
        self.W_q = nn.Linear(d_model, d_model)
        self.W_h = nn.Linear(d_model, d_model)
        
        # 记忆混合参数
        self.A = nn.Parameter(torch.eye(d_model))  # 邻居连接
    
    def forward(self, x, state):
        m_prev, h_prev = state
        
        # 指数遗忘门(选择性增强/遗忘)
        f = torch.exp(self.alpha_f) * torch.sigmoid(self.beta_f)
        
        # 指数输入门
        i = torch.exp(self.alpha_i) * torch.sigmoid(self.beta_i)
        
        # 候选记忆
        q = self.W_q(x) + self.W_h(h_prev)
        m_tilde = torch.tanh(q)
        
        # 更新记忆(带邻居混合)
        m_t = f * (self.A @ m_prev) + i * m_tilde
        
        # 归一化
        m_t = m_t / (f + i + 1e-6)
        
        # 输出门
        o = torch.sigmoid(self.gamma_o) * torch.sigmoid(self.delta_o)
        h_t = o * torch.tanh(m_t)
        
        return h_t, (m_t, h_t)

mLSTM(Matrix LSTM)

  • 维护一个矩阵记忆
  • 通过协方差更新规则实现并行化
class mLSTMCell(nn.Module):
    """Matrix LSTM Cell"""
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        
        # 输入投影
        self.W_z = nn.Linear(d_model, d_model)
        
        # 门控参数(指数)
        self.gate = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.Parameter(torch.zeros(4 * d_model))  # 指数偏移
        )
    
    def forward(self, x, state):
        M_prev, h_prev, c_prev = state
        
        # 门控计算(可并行)
        gate_out = self.gate(x)
        i, f, o, g = gate_out.chunk(4, dim=-1)
        
        # 指数门控
        i = torch.exp(i)
        f = torch.exp(f)
        
        # 外部输入
        z = self.W_z(x)
        
        # mLSTM 核心:协方差更新
        # v_t = f * (M_prev @ z) / sqrt(d) + i * x
        v_t = (f.unsqueeze(-1) * (M_prev @ z.unsqueeze(-1)).squeeze(-1)) / np.sqrt(self.d_model) + i * x
        
        # 更新记忆矩阵
        M_t = f.unsqueeze(-1) * M_prev + i.unsqueeze(-1) * torch.outer(z, z)
        
        # 归一化
        M_t = M_t / (f.unsqueeze(-1) + i.unsqueeze(-1) + 1e-6)
        
        # 输出
        o = torch.sigmoid(o)
        h_t = o * (M_t @ z) / np.sqrt(self.d_model)
        
        return h_t, (M_t, h_t, v_t)

架构设计

xLSTM 使用残差堆叠的 block 结构:

class xLSTMBlock(nn.Module):
    """xLSTM Block"""
    def __init__(self, d_model, num_heads=8, expand=2):
        super().__init__()
        d_inner = d_model * expand
        
        self.norm = nn.LayerNorm(d_model)
        
        # sLSTM 或 mLSTM
        self.xlstm = nn.ModuleList([
            sLSTMCell(d_model) for _ in range(num_heads)
        ])
        
        # 输出投影
        self.proj = nn.Linear(d_model * num_heads, d_model)
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.GELU(),
            nn.Linear(d_inner, d_model)
        )
    
    def forward(self, x):
        residual = x
        x = self.norm(x)
        
        # xLSTM 处理
        outs = []
        for cell in self.xlstm:
            h, _ = cell(x, self.init_state())
            outs.append(h)
        
        x = torch.cat(outs, dim=-1)
        x = self.proj(x)
        x = x + residual
        
        # FFN
        x = x + self.ffn(x)
        
        return x

实验结果

模型参数量WikiText-103 PPLMMLUHumanEval
LSTM (baseline)1B23.528.115.2
xLSTM-1B1B19.831.522.4
xLSTM-7B7B15.258.345.1
Transformer-1B1B18.530.220.8

关键发现

  • xLSTM-1B 在多项任务上超越 Transformer-1B
  • mLSTM 变体在语言建模任务上表现最佳
  • xLSTM 展现出与 Transformer 类似的缩放定律

τ-GRU:时滞门控循环单元

核心思想

传统 RNN 假设状态更新是瞬时的,但真实系统往往存在传播延迟。τ-GRU 将时滞微分方程引入门控机制。3

数学模型

标准 GRU

其中

τ-GRU:引入时滞反馈

其中 是可学习的延迟参数。

物理意义

时滞反馈可以用微分方程描述:

这将 RNN 从离散时间系统转化为连续时间系统的离散化。

PyTorch 实现

class TauGRUCell(nn.Module):
    """τ-GRU Cell"""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 更新门
        self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
        
        # 候选隐藏状态
        self.W_h = nn.Linear(input_size + hidden_size, hidden_size)
        
        # 延迟参数(可学习)
        self.tau = nn.Parameter(torch.ones(hidden_size))
    
    def forward(self, x, h_prev, h_delayed=None):
        """
        Args:
            x: 当前输入
            h_prev: 上一个时间步的隐藏状态
            h_delayed: 延迟 τ 步的隐藏状态
        """
        if h_delayed is None:
            h_delayed = h_prev
        
        combined = torch.cat([x, h_prev], dim=-1)
        
        # 更新门
        z = torch.sigmoid(self.W_z(combined))
        
        # 候选状态
        h_tilde = torch.tanh(self.W_h(combined))
        
        # τ-GRU 更新
        tau = torch.clamp(self.tau, min=1.0, max=10.0)
        
        # 指数衰减的时滞反馈
        decay = torch.exp(-1.0 / tau)
        h_t = (1 - z) * (decay * h_delayed) + z * h_tilde
        
        return h_t
 
 
class TauGRU(nn.Module):
    """τ-GRU 完整实现,支持可变延迟"""
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            TauGRUCell(
                input_size if i == 0 else hidden_size,
                hidden_size
            ) for i in range(num_layers)
        ])
        
        # 延迟缓冲区
        self.delay_buffers = [
            collections.deque(maxlen=10) for _ in range(num_layers)
        ]
    
    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, input_size)
        Returns:
            output: (batch, seq_len, hidden_size)
        """
        batch_size, seq_len, _ = x.shape
        h = [torch.zeros(batch_size, self.layers[0].hidden_size, device=x.device)
             for _ in range(len(self.layers))]
        
        outputs = []
        
        for t in range(seq_len):
            x_t = x[:, t, :]
            
            for layer_idx, layer in enumerate(self.layers):
                # 获取延迟的隐藏状态
                tau = int(layer.tau.data.clamp(1, 10).item())
                
                if len(self.delay_buffers[layer_idx]) >= tau:
                    h_delayed = self.delay_buffers[layer_idx][-tau]
                else:
                    h_delayed = h[layer_idx]
                
                # 更新
                h_new = layer(x_t, h[layer_idx], h_delayed)
                h[layer_idx] = h_new
                x_t = h_new
                
                # 更新延迟缓冲区
                self.delay_buffers[layer_idx].append(h_new.detach())
            
            outputs.append(h[-1])
        
        return torch.stack(outputs, dim=1)

优势

  1. 更好的长期依赖建模:时滞反馈显式建模信息传播延迟
  2. 更快收敛:实验显示收敛速度提升 2-3 倍
  3. 更好的泛化:在时间序列预测任务上超越标准 GRU

QL-LSTM:量子跃迁 LSTM

参数共享的统一门控

QL-LSTM 引入参数共享统一门控(Parameter-Shared Unified Gating, PSUG),将 4 个独立的门控投影合并为 1 个。4

核心创新

标准 LSTM 门控

参数量

QL-LSTM PSUG

使用单一的共享权重矩阵 ,然后通过线性组合得到各个门:

层次化门控与跳跃连接

QL-LSTM 还引入层次化门控(Hierarchical Gating)与加性跳跃连接

class QLLSTMCell(nn.Module):
    """QL-LSTM with PSUG and HGR-ASC"""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 参数共享的投影
        self.W_s = nn.Linear(input_size + hidden_size, 2 * hidden_size)
        
        # 门控权重(少量参数)
        self.gate_weights = nn.Parameter(torch.randn(4, 2, hidden_size, hidden_size))
        
        # 跳跃连接
        self.skip = nn.Linear(input_size, hidden_size, bias=False)
    
    def forward(self, x, h_prev, C_prev):
        combined = torch.cat([x, h_prev], dim=-1)
        
        # 共享投影
        s = self.W_s(combined)  # (batch, 2*hidden)
        sigma_s, tanh_s = s[:, :self.hidden_size], s[:, self.hidden_size:]
        sigma_s = torch.sigmoid(sigma_s)
        tanh_s = torch.tanh(tanh_s)
        
        # 门控组合
        gate_input = torch.stack([sigma_s, tanh_s], dim=0)  # (2, batch, hidden)
        
        i, f, o, g = torch.einsum('gbid,gd...->gb...', 
                                   self.gate_weights, 
                                   gate_input)
        
        # 层次化门控更新
        C_t = f * C_prev + i * g + self.skip(x)  # 加性跳跃连接
        
        # 输出
        h_t = torch.sigmoid(o) * torch.tanh(C_t)
        
        return h_t, C_t

参数量对比

模型参数量压缩比
标准 LSTM
QL-LSTM~48% 减少

LSTM 缩放定律

2025 年研究发现

xLSTM 团队建立了 LSTM 的计算感知缩放定律5

核心发现

给定计算预算 ,最优模型规模满足:

其中 是参数量, 是数据量, 是幂律指数。

缩放曲线

         性能
          ↑
          │    ╭───────── Transformer
          │   ╱
          │  ╱  xLSTM
          │ ╱
          │╱
          └──────────────────→ 计算量

关键洞察

  • xLSTM 的缩放曲线与 Transformer 斜率相近
  • 但在同等计算量下,xLSTM 表现略优
  • 这可能源于 LSTM 的归纳偏置(线性复杂度 vs 二次复杂度)

LSTM 在现代 AI 系统中的应用

1. 金融预测:StockBot 2.0

2026 年研究发现:Vanilla LSTM 在股票预测任务上超越 Transformer6

class StockBotLSTM(nn.Module):
    """股票预测 LSTM"""
    def __init__(self, input_dim=5, hidden_dim=128, num_layers=2):
        super().__init__()
        # 股价特征:开盘价、收盘价、最高价、最低价、成交量
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)  # 预测收益率
    
    def forward(self, x):
        # x: (batch, seq_len, 5)
        out, _ = self.lstm(x)
        return self.fc(out[:, -1, :])

原因分析

  • 金融数据噪声大,LSTM 的归纳偏置更适合
  • Transformer 需要的海量数据在金融领域往往不可得
  • LSTM 的低方差特性在少数据场景更有优势

2. 全球河流预测:AIFL

2026 年研究:使用 LSTM 进行全球 18,588 个流域的日径流预测。

class StreamflowLSTM(nn.Module):
    """流域径流预测模型"""
    def __init__(self, forcing_dim, static_dim, hidden_dim=256):
        super().__init__()
        
        # 动态输入(气象强迫)
        self.forcing_proj = nn.Linear(forcing_dim, hidden_dim)
        
        # 静态特征(地形、土壤)
        self.static_proj = nn.Linear(static_dim, hidden_dim)
        
        # LSTM
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2)
        
        # 输出:每个流域的流量
        self.fc = nn.Linear(hidden_dim, 18588)
    
    def forward(self, forcing, static, hidden=None):
        # 投影
        x = self.forcing_proj(forcing) + self.static_proj(static)
        
        # LSTM
        out, hidden = self.lstm(x, hidden)
        
        # 预测
        return self.fc(out), hidden

3. 能源预测

LSTM 在电力负荷预测、太阳能预测等能源领域表现优异:

class EnergyLSTM(nn.Module):
    """能源消耗预测"""
    def __init__(self, d_model=256, n_heads=8, n_layers=4):
        super().__init__()
        # 可学习的位置编码(针对时间)
        self.embed = nn.Linear(5, d_model)  # 小时、星期、月份、节假日、天气
        
        self.lstm = nn.LSTM(d_model, d_model, n_layers, dropout=0.1)
        self.fc = nn.Linear(d_model, 24)  # 预测未来24小时
    
    def forward(self, x):
        embedded = self.embed(x)
        out, _ = self.lstm(embedded)
        return self.fc(out[:, -1, :])

LSTM vs Transformer:何时选择

场景推荐原因
小数据集LSTM归纳偏置减少过拟合
长序列、计算受限LSTM/SSM线性复杂度
需要全局上下文Transformer完全注意力
实时推理LSTM无需缓存所有 KV
多模态融合Transformer更好的跨模态注意力
时间序列预测LSTM专为时序设计
因果推理SSM状态空间视角

参考


相关阅读

Footnotes

  1. Beck, M., et al. (2024). “xLSTM: Extended Long Short-Term Memory”. arXiv:2405.04517.

  2. Gu, A., & Dao, T. (2024). “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”. arXiv:2312.00752.

  3. Erichson, N. B., et al. (2025). “τ-GRU: A Time-Delay Gating Mechanism for Recurrent Neural Networks”. AISTATS 2025.

  4. Nti, I. K. N., et al. (2025). “QL-LSTM: Quantum-Leap LSTM with Parameter-Shared Unified Gating”. arXiv:2512.06582.

  5. Beck, M., et al. (2025). “Scaling Laws for xLSTM”. arXiv:2510.02228.

  6. StockBot Team (2026). “StockBot 2.0: LSTM Superiority in Stock Prediction”. arXiv:2601.00197.