Overview

Long Short-Term Memory (LSTM) networks, introduced by Hochreiter and Schmidhuber in 1997, are a variant of the RNN designed specifically to address the long-term dependency problem.¹

Why Do We Need LSTM?

Standard RNNs suffer from vanishing gradients on long sequences:

During backpropagation through time, the gradient contains a product of Jacobians $\prod_{k} \frac{\partial h_k}{\partial h_{k-1}}$; when the norm of these factors is less than 1, the gradient decays exponentially with distance, so information from distant time steps cannot be propagated effectively.
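
A minimal sketch of this effect (hyperparameters are illustrative): run a plain nn.RNN over increasingly long sequences and check how much gradient reaches the very first input. The gradient norm typically shrinks sharply as the sequence grows.

import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=8, hidden_size=8, batch_first=True)

for seq_len in [5, 20, 80]:
    x = torch.randn(1, seq_len, 8, requires_grad=True)
    out, _ = rnn(x)
    out[:, -1].sum().backward()              # gradient of the last time step only
    grad_first = x.grad[:, 0].norm().item()  # gradient received by the first input
    print(f"seq_len={seq_len:3d}  ||dL/dx_1|| = {grad_first:.2e}")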

The Core Idea of LSTM

LSTM introduces a cell state and a gating mechanism:

  • Cell state: acts like a conveyor belt, carrying information through the entire sequence without it being gradually forgotten
  • Gates: learn which information to keep and which to forget

LSTM Architecture in Detail

Core Components

┌─────────────────────────────────────────────────────────────┐
│                      LSTM Cell                              │
│                                                              │
│   h_{t-1} ──┬──────────────────┬──────────────────┐        │
│             │                  │                  │        │
│             ▼                  ▼                  ▼        │
│        ┌────────┐        ┌────────┐        ┌────────┐      │
│   x_t ─┤ Forget │        │  Input │        │ Output │      │
│        │  Gate  │        │  Gate  │        │  Gate  │      │
│        └────┬───┘        └────┬───┘        └────┬───┘      │
│             │                  │                  │          │
│             ▼                  ▼                  ▼          │
│      f_t ⊙ C_{t-1}       i_t ⊙ ~C_t        o_t ⊙ tanh(C_t)   │
│             │                  │                  │          │
│             ▼                  ▼                  ▼          │
│        ┌─────────────────────────────────────────┐          │
│        │            Cell State (C_t)              │          │
│        │  C_t = f_t ⊙ C_{t-1} + i_t ⊙ ~C_t      │          │
│        └─────────────────────────────────────────┘          │
│                          │                                  │
│                          ▼                                  │
│                      tanh(C_t)                              │
│                          │                                  │
│                          ▼                                  │
│                       h_t = o_t ⊙ tanh(C_t)                 │
└─────────────────────────────────────────────────────────────┘

Mathematical Formulation

Let the current input be $x_t$, the previous hidden state $h_{t-1}$, and the previous cell state $C_{t-1}$.

1. Forget Gate

Decides what information to discard from the cell state:

$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$

Each element of $f_t$ lies in $[0, 1]$ and gives the fraction of the corresponding information to keep.

2. Input Gate

Decides what new information to store in the cell state:

$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$

3. Candidate Cell State

Creates the vector of new candidate values:

$$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$$

4. Cell State Update

$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$

This lets the network selectively forget old information and add new information.

5. Output Gate

Decides what to output:

$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$
$$h_t = o_t \odot \tanh(C_t)$$

Full Set of Equations

$$
\begin{aligned}
f_t &= \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \\
i_t &= \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \\
\tilde{C}_t &= \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
o_t &= \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \\
h_t &= o_t \odot \tanh(C_t)
\end{aligned}
$$

PyTorch Implementation

Basic LSTM Cell

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class LSTMCell(nn.Module):
    """A hand-written LSTM cell"""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # Concatenated size: input + hidden state
        self.combined_size = input_size + hidden_size
        
        # Weights for the four gates (forget, input, candidate, output)
        self.W = nn.Parameter(torch.randn(4 * hidden_size, self.combined_size))
        self.b = nn.Parameter(torch.zeros(4 * hidden_size))
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        std = math.sqrt(2.0 / (self.combined_size + self.hidden_size))
        nn.init.uniform_(self.W, -std, std)
        nn.init.zeros_(self.b)
        # Set the forget-gate bias to 1 (encourages remembering early in training).
        # Gate order here is f, i, C_tilde, o, so the forget gate is the first quarter.
        n = self.b.size(0)
        self.b.data[:n // 4].fill_(1)
    
    def forward(self, x, state):
        """
        Args:
            x: (batch_size, input_size) current input
            state: tuple of (h, C), the hidden state and cell state
        Returns:
            h: (batch_size, hidden_size) new hidden state
            C: (batch_size, hidden_size) new cell state
        """
        h, C = state
        
        # Concatenate the input and the hidden state
        combined = torch.cat([x, h], dim=1)  # (batch, combined_size)
        
        # Compute all four gates in one matrix multiply
        gates = F.linear(combined, self.W, self.b)  # (batch, 4*hidden)
        gates = gates.chunk(4, dim=1)  # split into the four gates
        
        # Apply the gate nonlinearities
        f = torch.sigmoid(gates[0])      # forget gate
        i = torch.sigmoid(gates[1])      # input gate
        C_tilde = torch.tanh(gates[2])   # candidate cell state
        o = torch.sigmoid(gates[3])      # output gate
        
        # Update the cell state
        C_new = f * C + i * C_tilde
        
        # Compute the new hidden state
        h_new = o * torch.tanh(C_new)
        
        return h_new, C_new
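
A quick usage sketch (dimensions are made up) that runs the hand-written cell step by step over a random sequence, reusing the imports and LSTMCell defined above:

batch_size, seq_len, input_size, hidden_size = 4, 10, 16, 32
cell = LSTMCell(input_size, hidden_size)

x = torch.randn(batch_size, seq_len, input_size)
h = torch.zeros(batch_size, hidden_size)
C = torch.zeros(batch_size, hidden_size)

for t in range(seq_len):
    h, C = cell(x[:, t, :], (h, C))   # one time step per call

print(h.shape, C.shape)  # torch.Size([4, 32]) torch.Size([4, 32])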

Multi-Layer LSTM

class MultiLayerLSTM(nn.Module):
    """Multi-layer LSTM built from LSTMCell"""
    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        
        self.layers = nn.ModuleList([
            LSTMCell(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, states=None):
        """
        Args:
            x: (batch_size, seq_len, input_size)
            states: tuple of (h_0, C_0), lists of per-layer states
        Returns:
            output: (batch_size, seq_len, hidden_size)
            (h_n, C_n): final hidden and cell states
        """
        batch_size, seq_len, _ = x.shape
        
        # Initialize the hidden and cell states
        if states is None:
            h = [torch.zeros(batch_size, self.hidden_size, device=x.device)
                 for _ in range(self.num_layers)]
            C = [torch.zeros(batch_size, self.hidden_size, device=x.device)
                 for _ in range(self.num_layers)]
        else:
            h, C = states
        
        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]
            
            for layer_idx in range(self.num_layers):
                # Input to the current layer: the sequence input for layer 0,
                # otherwise the (already updated) output of the layer below
                if layer_idx == 0:
                    cell_input = x_t
                else:
                    cell_input = h[layer_idx - 1]
                
                # LSTM cell forward pass
                h_new, C_new = self.layers[layer_idx](cell_input, (h[layer_idx], C[layer_idx]))
                h[layer_idx] = h_new
                C[layer_idx] = C_new
                
                # Dropout between layers (not after the top layer)
                if layer_idx < self.num_layers - 1:
                    h[layer_idx] = self.dropout(h[layer_idx])
            
            outputs.append(h[-1])
        
        output = torch.stack(outputs, dim=1)  # (batch, seq_len, hidden)
        return output, (h, C)
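
A small usage sketch with made-up dimensions, reusing the classes defined above:

mlstm = MultiLayerLSTM(input_size=16, hidden_size=32, num_layers=2, dropout=0.1)
x = torch.randn(4, 10, 16)        # (batch, seq_len, input_size)
output, (h, C) = mlstm(x)
print(output.shape)               # torch.Size([4, 10, 32])
print(len(h), h[0].shape)         # 2 torch.Size([4, 32])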

Using PyTorch's Built-in LSTM

class PyTorchLSTM(nn.Module):
    """Model using PyTorch's built-in nn.LSTM"""
    def __init__(self, input_size, hidden_size, num_layers, output_size,
                 dropout=0.0, bidirectional=False):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True,
            bidirectional=bidirectional
        )
        
        # A bidirectional LSTM doubles the output dimension
        self.fc = nn.Linear(hidden_size * (2 if bidirectional else 1), output_size)
    
    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, input_size)
        Returns:
            output: (batch_size, output_size)
        """
        # nn.LSTM returns: (output, (h_n, C_n))
        # h_n: (num_layers * num_directions, batch, hidden)
        output, (h_n, C_n) = self.lstm(x)
        
        # Take the last layer's final hidden state; for a bidirectional LSTM,
        # concatenate the forward and backward directions
        if self.lstm.bidirectional:
            final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)  # (batch, 2*hidden)
        else:
            final_hidden = h_n[-1]  # (batch, hidden)
        
        return self.fc(final_hidden)
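
A usage sketch with made-up dimensions (the bidirectional case relies on the forward pass above concatenating both directions):

model = PyTorchLSTM(input_size=16, hidden_size=32, num_layers=2,
                    output_size=5, dropout=0.1, bidirectional=True)
x = torch.randn(8, 20, 16)   # (batch, seq_len, input_size)
logits = model(x)
print(logits.shape)          # torch.Size([8, 5])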

Backpropagation

Gradient Flow in LSTM

The LSTM's key advantage is that its gradient flow is much more stable.

Because $\frac{\partial C_t}{\partial C_{t-1}} = f_t$ (elementwise), and $f_t$ is learned, the gradient along the cell-state path can be selectively attenuated or preserved.
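
A tiny check of this claim (illustrative sizes): along the additive cell-state update, the gradient that autograd sends back to $C_{t-1}$ is exactly the forget gate:

import torch

hidden_size = 4
C_prev = torch.randn(hidden_size, requires_grad=True)
f = torch.rand(hidden_size)        # pretend these are the gate activations
i = torch.rand(hidden_size)
C_tilde = torch.randn(hidden_size)

C = f * C_prev + i * C_tilde       # the LSTM cell-state update
C.sum().backward()
print(torch.allclose(C_prev.grad, f))   # True: dC_t/dC_{t-1} = f_t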

BPTT Implementation

def lstm_step_backward(dh_next, dC_next, cache, W):
    """
    Backward pass for a single LSTM time step (single sample, no batch dimension).
    
    Args:
        dh_next: (hidden_size,) hidden-state gradient from the next time step
        dC_next: (hidden_size,) cell-state gradient from the next time step
        cache: forward-pass cache (x, h_prev, C_prev, C, f, i, C_tilde, o)
        W: (4*hidden_size, input_size + hidden_size) gate weights, gate order f, i, C_tilde, o
    Returns:
        dx: (input_size,) gradient w.r.t. the input
        dh_prev: (hidden_size,) gradient w.r.t. the previous hidden state
        dC_prev: (hidden_size,) gradient w.r.t. the previous cell state
        grads: dict of weight gradients
    """
    x, h_prev, C_prev, C, f, i, C_tilde, o = cache
    input_size = x.size(0)
    
    # Gradients through h_t = o * tanh(C)
    do = dh_next * torch.tanh(C)
    dC_tanh = dh_next * o * (1 - torch.tanh(C) ** 2)
    
    # Cell-state gradient: contribution from the next step plus the current output
    dC = dC_next + dC_tanh
    
    # Gradient clipping (guards against exploding gradients)
    dC = torch.clamp(dC, -5, 5)
    
    # Gradients w.r.t. the gate activations
    df = dC * C_prev
    di = dC * C_tilde
    dC_tilde = dC * i
    dC_prev = dC * f  # gradient passed back to the previous cell state
    
    # Gradients through the activation functions (pre-activation gradients)
    df_pre = df * f * (1 - f)                      # sigmoid derivative
    di_pre = di * i * (1 - i)                      # sigmoid derivative
    dC_tilde_pre = dC_tilde * (1 - C_tilde ** 2)   # tanh derivative
    do_pre = do * o * (1 - o)                      # sigmoid derivative
    
    # Concatenate the four pre-activation gradients (same order as the forward pass)
    d_gates = torch.cat([df_pre, di_pre, dC_tilde_pre, do_pre], dim=0)
    
    # Weight gradients
    combined = torch.cat([x, h_prev], dim=0)
    dW = torch.outer(d_gates, combined)
    db = d_gates
    
    # Gradients w.r.t. the input and the previous hidden state
    dx = W[:, :input_size].T @ d_gates
    dh_prev = W[:, input_size:].T @ d_gates
    
    return dx, dh_prev, dC_prev, {'W': dW, 'b': db}

LSTM Variants

1. Peephole Connections

Let the gates look directly at the cell state:

class PeepholeLSTMCell(nn.Module):
    """LSTM cell with peephole connections"""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # Peephole weights (elementwise, applied directly to the cell state)
        self.p_f = nn.Parameter(torch.randn(hidden_size))
        self.p_i = nn.Parameter(torch.randn(hidden_size))
        self.p_o = nn.Parameter(torch.randn(hidden_size))
        
        # Standard gate weights (forget, input, candidate, output)
        self.W = nn.Parameter(torch.randn(4 * hidden_size, input_size + hidden_size))
        self.b = nn.Parameter(torch.zeros(4 * hidden_size))
    
    def forward(self, x, h, C):
        # Pre-activations of the standard LSTM gates
        combined = torch.cat([x, h], dim=1)
        gates = F.linear(combined, self.W, self.b).chunk(4, dim=1)
        
        # Forget and input gates peek at the previous cell state
        f = torch.sigmoid(gates[0] + self.p_f * C)
        i = torch.sigmoid(gates[1] + self.p_i * C)
        C_tilde = torch.tanh(gates[2])
        
        # Update the cell state
        C_new = f * C + i * C_tilde
        
        # The output gate peeks at the updated cell state
        o = torch.sigmoid(gates[3] + self.p_o * C_new)
        h_new = o * torch.tanh(C_new)
        
        return h_new, C_new
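
A usage sketch with made-up dimensions, reusing the imports above:

cell = PeepholeLSTMCell(input_size=16, hidden_size=32)
x = torch.randn(4, 16)
h = torch.zeros(4, 32)
C = torch.zeros(4, 32)

h_new, C_new = cell(x, h, C)
print(h_new.shape, C_new.shape)   # torch.Size([4, 32]) torch.Size([4, 32])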

2. Coupled Gates

The input gate is tied to the forget gate:

def lstm_coupled_forward(x, h, C, W_f, W_C, W_o, b_f, b_C, b_o):
    """Coupled-gate LSTM step: the input gate is tied to 1 - f (no separate input gate)"""
    combined = torch.cat([x, h], dim=1)
    
    # Forget gate (its complement acts as the input gate)
    f = torch.sigmoid(F.linear(combined, W_f, b_f))
    
    # Candidate cell state
    C_tilde = torch.tanh(F.linear(combined, W_C, b_C))
    
    # Coupled update: f decides how much to forget, (1 - f) how much to add
    C_new = f * C + (1 - f) * C_tilde
    
    # Output gate
    o = torch.sigmoid(F.linear(combined, W_o, b_o))
    h_new = o * torch.tanh(C_new)
    
    return h_new, C_new
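
A quick smoke test with random weights and made-up sizes, reusing the imports above:

input_size, hidden_size, batch = 16, 32, 4

W_f = torch.randn(hidden_size, input_size + hidden_size)
W_C = torch.randn(hidden_size, input_size + hidden_size)
W_o = torch.randn(hidden_size, input_size + hidden_size)
b_f, b_C, b_o = (torch.zeros(hidden_size) for _ in range(3))

x = torch.randn(batch, input_size)
h = torch.zeros(batch, hidden_size)
C = torch.zeros(batch, hidden_size)

h_new, C_new = lstm_coupled_forward(x, h, C, W_f, W_C, W_o, b_f, b_C, b_o)
print(h_new.shape, C_new.shape)   # torch.Size([4, 32]) torch.Size([4, 32])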

3. Gated Recurrent Unit (GRU)

The GRU is a simplified version of the LSTM with only two gates, an update gate and a reset gate:

$$z_t = \sigma(W_z \cdot [h_{t-1}, x_t])$$
$$r_t = \sigma(W_r \cdot [h_{t-1}, x_t])$$
$$\tilde{h}_t = \tanh(W \cdot [r_t \odot h_{t-1}, x_t])$$
$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$

Advantages of the GRU

  • Fewer parameters (about 25% fewer than an LSTM; see the quick check after the code below)
  • More efficient to compute
  • Comparable performance to the LSTM on many tasks

class GRUCell(nn.Module):
    """GRU cell implementation"""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        combined_size = input_size + hidden_size
        
        # Update gate and reset gate
        self.W_z = nn.Linear(combined_size, hidden_size)
        self.W_r = nn.Linear(combined_size, hidden_size)
        
        # Candidate hidden state
        self.W = nn.Linear(combined_size, hidden_size)
    
    def forward(self, x, h):
        combined = torch.cat([x, h], dim=1)
        
        # Update gate
        z = torch.sigmoid(self.W_z(combined))
        
        # Reset gate
        r = torch.sigmoid(self.W_r(combined))
        
        # Candidate hidden state (the reset gate scales the previous hidden state)
        h_tilde = torch.tanh(self.W(torch.cat([r * h, x], dim=1)))
        
        # New hidden state
        h_new = (1 - z) * h + z * h_tilde
        
        return h_new
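
A quick check of the parameter-count claim above, comparing the built-in nn.GRU and nn.LSTM (sizes are arbitrary):

import torch.nn as nn

gru = nn.GRU(input_size=16, hidden_size=32)
lstm = nn.LSTM(input_size=16, hidden_size=32)

n_gru = sum(p.numel() for p in gru.parameters())
n_lstm = sum(p.numel() for p in lstm.parameters())
print(n_gru, n_lstm, round(n_gru / n_lstm, 2))   # ratio is roughly 0.75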

Applications of LSTM

| Domain | Task | Notes |
| --- | --- | --- |
| Natural language processing | Language modeling | Predict the next word |
| Natural language processing | Machine translation | Sequence-to-sequence modeling |
| Natural language processing | Sentiment analysis | Text classification |
| Speech processing | Speech recognition | Acoustic modeling |
| Speech processing | Speech synthesis | TTS |
| Time series | Stock prediction | Financial forecasting |
| Time series | Weather prediction | Meteorological forecasting |
| Video analysis | Action recognition | Temporal modeling |

Sequence-to-Sequence Models (Seq2Seq)

class Seq2SeqLSTM(nn.Module):
    """LSTM-based Seq2Seq model"""
    def __init__(self, src_vocab_size, tgt_vocab_size, embed_dim, hidden_size):
        super().__init__()
        
        # Encoder
        self.encoder_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        
        # Decoder
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, embed_dim)
        self.decoder = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, tgt_vocab_size)
        
        self.hidden_size = hidden_size
    
    def forward(self, src, tgt):
        """
        Training forward pass with full teacher forcing
        (the ground-truth target sequence is fed to the decoder).
        
        Args:
            src: (batch, src_len) source sequence
            tgt: (batch, tgt_len) target sequence
        """
        # Encode
        src_embed = self.encoder_embedding(src)
        _, (h, C) = self.encoder(src_embed)
        
        # Decode, conditioned on the encoder's final state
        tgt_embed = self.decoder_embedding(tgt)
        decoder_output, _ = self.decoder(tgt_embed, (h, C))
        logits = self.fc(decoder_output)
        
        return logits
    
    def generate(self, src, max_len=50, sos_token=1, eos_token=2):
        """Greedy decoding (expects a single source sequence, batch size 1)"""
        # Encode
        src_embed = self.encoder_embedding(src)
        _, (h, C) = self.encoder(src_embed)
        
        # Decode one token at a time, feeding back the previous prediction
        generated = [sos_token]
        for _ in range(max_len):
            tgt_embed = self.decoder_embedding(
                torch.tensor([[generated[-1]]], device=src.device))
            out, (h, C) = self.decoder(tgt_embed, (h, C))
            logits = self.fc(out)
            next_token = logits.argmax(-1).item()
            generated.append(next_token)
            
            if next_token == eos_token:
                break
        
        return generated
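
A usage sketch with made-up vocabulary sizes and dimensions, showing the teacher-forced forward pass and greedy generation:

model = Seq2SeqLSTM(src_vocab_size=100, tgt_vocab_size=120, embed_dim=32, hidden_size=64)

src = torch.randint(0, 100, (2, 7))     # (batch, src_len)
tgt = torch.randint(0, 120, (2, 9))     # (batch, tgt_len)

logits = model(src, tgt)
print(logits.shape)                      # torch.Size([2, 9, 120])

tokens = model.generate(src[:1], max_len=20)   # greedy decoding, batch size 1
print(tokens)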

LSTM vs. Standard RNN

| Property | Standard RNN | LSTM |
| --- | --- | --- |
| Gradient flow | Decays/explodes over time steps | Selectively passed through gates |
| Long-term dependencies | Hard to learn | Handled well |
| Parameter count | Fewer | More (roughly 4x) |
| Computational cost | Lower | Moderate |
| Gating mechanism | None | Forget / input / output gates |
| Cell state | None | Yes |

Tips for Training LSTMs

1. Weight Initialization

def init_lstm_weights(model):
    """Initialization scheme for PyTorch's built-in nn.LSTM parameters"""
    for name, param in model.named_parameters():
        if 'weight_ih' in name:    # input-to-hidden weights
            nn.init.xavier_uniform_(param)
        elif 'weight_hh' in name:  # hidden-to-hidden weights
            nn.init.orthogonal_(param)
        elif 'bias' in name:
            nn.init.zeros_(param)
            # Initialize the forget-gate bias to 1
            # (nn.LSTM gate order is i, f, g, o, so the forget gate is the second quarter)
            n = param.size(0)
            param.data[n//4:n//2].fill_(1)
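
For example (hypothetical sizes), the scheme can be applied directly to a built-in nn.LSTM, whose parameter names match the patterns above:

lstm = nn.LSTM(input_size=16, hidden_size=32, num_layers=2, batch_first=True)
init_lstm_weights(lstm)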

2. Gradient Clipping

# PyTorch
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
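
A minimal, self-contained sketch (hypothetical model and data) showing where the clipping call sits in a training step: after backward() and before optimizer.step():

import torch
import torch.nn as nn

model = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(4, 10, 8)
target = torch.randn(4, 10, 16)

optimizer.zero_grad()
output, _ = model(x)
loss = nn.functional.mse_loss(output, target)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # clip, then update
optimizer.step()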

3. Residual Connections

class ResidualLSTM(nn.Module):
    """LSTM block with a residual connection and layer normalization"""
    def __init__(self, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.norm = nn.LayerNorm(hidden_size)
    
    def forward(self, x):
        residual = x
        out, state = self.lstm(x)
        return self.norm(out + residual), state
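
A small usage sketch (made-up sizes); note the block returns both the normalized output and the LSTM state:

block = ResidualLSTM(hidden_size=32)
x = torch.randn(4, 10, 32)      # input and output share the same feature size
out, state = block(x)
print(out.shape)                 # torch.Size([4, 10, 32])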

Footnotes

  1. Hochreiter, S., & Schmidhuber, J. (1997). “Long Short-Term Memory”. Neural Computation, 9(8), 1735-1780.