概述
长短期记忆网络(Long Short-Term Memory, LSTM)由 Hochreiter 和 Schmidhuber 于 1997 年提出,是 RNN 的一种变体,专门设计用于解决长期依赖问题。1
为什么需要 LSTM?
标准 RNN 在处理长序列时面临梯度消失问题:
当 时,梯度随距离指数衰减,导致远处信息无法有效传递。
LSTM 的核心思想
LSTM 引入细胞状态(Cell State)和门控机制(Gating Mechanism):
- 细胞状态:像传送带一样,让信息在整个序列中流动而不会被逐渐遗忘
- 门控:学习哪些信息应该记住,哪些应该遗忘
LSTM 架构详解
核心组件
┌─────────────────────────────────────────────────────────────┐
│ LSTM Cell │
│ │
│ h_{t-1} ──┬──────────────────┬──────────────────┐ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌────────┐ ┌────────┐ ┌────────┐ │
│ x_t ─┤ Forget │ │ Input │ │ Output │ │
│ │ Gate │ │ Gate │ │ Gate │ │
│ └────┬───┘ └────┬───┘ └────┬───┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ σ ⊙ h_{t-1} σ ⊙ ~C_t σ ⊙ C_t │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────┐ │
│ │ Cell State (C_t) │ │
│ │ C_t = f_t ⊙ C_{t-1} + i_t ⊙ ~C_t │ │
│ └─────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ tanh(C_t) │
│ │ │
│ ▼ │
│ h_t = o_t ⊙ tanh(C_t) │
└─────────────────────────────────────────────────────────────┘
数学公式
设当前输入为 ,上一隐藏状态为 ,细胞状态为 。
1. 遗忘门(Forget Gate)
决定从细胞状态中丢弃什么信息:
的每个元素取值在 之间,表示保留该信息的比例。
2. 输入门(Input Gate)
决定将什么新信息存储到细胞状态:
3. 候选细胞状态(Candidate Cell State)
创建新的候选值:
4. 更新细胞状态
这允许网络选择性地遗忘旧信息和添加新信息。
5. 输出门(Output Gate)
决定输出什么:
完整公式汇总
PyTorch 实现
基础 LSTM Cell
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class LSTMCell(nn.Module):
"""手写 LSTM Cell"""
def __init__(self, input_size, hidden_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
# 合并输入:输入 + 隐藏状态
selfcombined_size = input_size + hidden_size
# 四个门的权重(遗忘门、输入门、候选、输出门)
self.W = nn.Parameter(torch.randn(4 * hidden_size, combined_size))
self.b = nn.Parameter(torch.zeros(4 * hidden_size))
# 初始化权重
self._init_weights()
def _init_weights(self):
std = math.sqrt(2.0 / (self.combined_size + self.hidden_size))
for name, param in self.named_parameters():
if 'weight' in name:
nn.init.uniform_(param, -std, std)
elif 'bias' in name:
nn.init.zeros_(param)
# 设置遗忘门偏置为 1(帮助记忆)
n = param.size(0)
param.data[n//4:n//2].fill_(1)
def forward(self, x, state):
"""
Args:
x: (batch_size, input_size) 当前输入
state: tuple of (h, C) 隐藏状态和细胞状态
Returns:
h: (batch_size, hidden_size) 新隐藏状态
C: (batch_size, hidden_size) 新细胞状态
"""
h, C = state
# 拼接输入和隐藏状态
combined = torch.cat([x, h], dim=1) # (batch, combined_size)
# 计算四个门
gates = F.linear(combined, self.W, self.b) # (batch, 4*hidden)
gates = gates.chunk(4, dim=1) # 分成4个门
# 提取各个门
f = torch.sigmoid(gates[0]) # 遗忘门
i = torch.sigmoid(gates[1]) # 输入门
C_tilde = torch.tanh(gates[2]) # 候选细胞状态
o = torch.sigmoid(gates[3]) # 输出门
# 更新细胞状态
C_new = f * C + i * C_tilde
# 计算新隐藏状态
h_new = o * torch.tanh(C_new)
return h_new, C_new多层 LSTM
class MultiLayerLSTM(nn.Module):
"""多层 LSTM"""
def __init__(self, input_size, hidden_size, num_layers, dropout=0.0):
super().__init__()
self.num_layers = num_layers
self.hidden_size = hidden_size
self.layers = nn.ModuleList([
LSTMCell(input_size if i == 0 else hidden_size, hidden_size)
for i in range(num_layers)
])
self.dropout = nn.Dropout(dropout)
def forward(self, x, states=None):
"""
Args:
x: (batch_size, seq_len, input_size)
states: tuple of (h_0, C_0)
Returns:
output: (batch_size, seq_len, hidden_size)
(h_n, C_n): 最终隐藏状态
"""
batch_size, seq_len, _ = x.shape
# 初始化隐藏状态
if states is None:
h = [torch.zeros(batch_size, self.hidden_size, device=x.device)
for _ in range(self.num_layers)]
C = [torch.zeros(batch_size, self.hidden_size, device=x.device)
for _ in range(self.num_layers)]
else:
h, C = states
outputs = []
for t in range(seq_len):
x_t = x[:, t, :]
for layer_idx in range(self.num_layers):
# 当前层的输入
if layer_idx == 0:
cell_input = x_t
else:
cell_input = h[layer_idx - 1]
if t == 0:
cell_input = torch.zeros_like(cell_input)
# LSTM Cell 前向
h_new, C_new = self.layers[layer_idx](cell_input, (h[layer_idx], C[layer_idx]))
h[layer_idx] = h_new
C[layer_idx] = C_new
if layer_idx < self.num_layers - 1:
h[layer_idx] = self.dropout(h[layer_idx])
outputs.append(h[-1])
output = torch.stack(outputs, dim=1) # (batch, seq_len, hidden)
return output, (h, C)使用 PyTorch 内置 LSTM
class PyTorchLSTM(nn.Module):
"""使用 PyTorch 内置的 LSTM"""
def __init__(self, input_size, hidden_size, num_layers, dropout=0.0, bidirectional=False):
super().__init__()
self.lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout if num_layers > 1 else 0,
batch_first=True,
bidirectional=bidirectional
)
# 双向LSTM输出维度翻倍
self.fc = nn.Linear(hidden_size * (2 if bidirectional else 1), output_size)
def forward(self, x):
"""
Args:
x: (batch_size, seq_len, input_size)
Returns:
output: (batch_size, output_size)
"""
# LSTM 返回: (output, (h_n, C_n))
output, (h_n, C_n) = self.lstm(x)
# 取最后一个时间步的隐藏状态
# h_n: (num_layers * num_directions, batch, hidden)
final_hidden = h_n[-1] # (batch, hidden)
return self.fc(final_hidden)反向传播
LSTM 的梯度流
LSTM 的关键优势是梯度流更加稳定:
由于 ,且是通过学习得到的,梯度可以选择性地衰减或保持。
BPTT 实现
def lstm_step_backward(dh_next, dC_next, cache):
"""
LSTM 单步反向传播
Args:
dh_next: (hidden_size,) 从下一时间步传来的隐藏状态梯度
dC_next: (hidden_size,) 从下一时间步传来的细胞状态梯度
cache: 前向传播缓存
Returns:
dx: (input_size,) 输入梯度
dh_prev: (hidden_size,) 上一隐藏状态梯度
dC_prev: (hidden_size,) 上一细胞状态梯度
grads: 权重梯度字典
"""
x, h_prev, C_prev, C, f, i, C_tilde, o = cache
# tanh(C) 的梯度
do = dh_next * torch.tanh(C)
dC_tanh = dh_next * o * (1 - torch.tanh(C)**2)
# 细胞状态梯度
dC = dC_next + dC_tanh
# 各个门的梯度
df = dC * C_prev
di = dC * C_tilde
dC_tilde = dC * i
do_local = dh_next * torch.tanh(C)
# 梯度裁剪(防止梯度爆炸)
dC = torch.clamp(dC, -5, 5)
# 激活函数梯度
df_local = df * f * (1 - f) # sigmoid 梯度
di_local = di * i * (1 - i) # sigmoid 梯度
dC_tilde_local = dC_tilde * (1 - C_tilde**2) # tanh 梯度
do_local = do * o * (1 - o) # sigmoid 梯度
# 拼接梯度
d_combined = torch.cat([df_local, di_local, dC_tilde_local, do_local], dim=1)
# 权重梯度
combined = torch.cat([x, h_prev], dim=1)
dW = torch.outer(d_combined, combined)
db = d_combined
# 输入和隐藏状态梯度
dx = d_combined @ self.W[:, :self.input_size].T
dh_prev = d_combined @ self.W[:, self.input_size:].T
return dx, dh_prev, dC_prev, {'W': dW, 'b': db}LSTM 变体
1. 窥视孔连接(Peephole Connections)
允许门直接看到细胞状态:
class PeepholeLSTMCell(nn.Module):
"""带窥视孔的 LSTM Cell"""
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
# 窥视孔权重
self.W_f = nn.Parameter(torch.randn(hidden_size, hidden_size))
self.W_i = nn.Parameter(torch.randn(hidden_size, hidden_size))
self.W_o = nn.Parameter(torch.randn(hidden_size, hidden_size))
# 标准的门权重
self.W = nn.Parameter(torch.randn(3 * hidden_size, input_size + hidden_size))
self.b = nn.Parameter(torch.zeros(3 * hidden_size))
def forward(self, x, h, C):
# 窥视孔贡献
f = torch.sigmoid(self.W_f @ C + self.b[:self.hidden_size])
i = torch.sigmoid(self.W_i @ C + self.b[self.hidden_size:2*self.hidden_size])
o = torch.sigmoid(self.W_o @ C + self.b[2*self.hidden_size:])
# 标准 LSTM 门
combined = torch.cat([x, h], dim=1)
gates = F.linear(combined, self.W, self.b)
gates = gates.chunk(3, dim=1)
# 更新细胞状态
C_new = f * C + torch.sigmoid(gates[0]) * torch.tanh(gates[1])
h_new = o * torch.tanh(C_new)
return h_new, C_new2. 耦合门控(Coupled Gates)
输入门和遗忘门耦合:
def lstm_coupled_forward(x, h, C, W_f, W_i, W_C, W_o, b):
"""耦合门控 LSTM"""
combined = torch.cat([h, x], dim=1)
# 遗忘门(也用作输入门)
f = torch.sigmoid(W_f @ combined + b)
# 候选细胞状态
C_tilde = torch.tanh(W_C @ combined + b)
# 耦合更新:遗忘门决定遗忘多少,输入门决定添加多少(1 - f)
C_new = f * C + (1 - f) * C_tilde
# 输出门
o = torch.sigmoid(W_o @ combined + b)
h_new = o * torch.tanh(C_new)
return h_new, C_new3. 门控循环单元(GRU)
GRU 是 LSTM 的简化版本,只有两个门:
GRU 优势:
- 参数量更少(比 LSTM 少约 25%)
- 计算更高效
- 在许多任务上与 LSTM 性能相当
class GRUCell(nn.Module):
"""GRU Cell 实现"""
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
combined_size = input_size + hidden_size
# 更新门和重置门
self.W_z = nn.Linear(combined_size, hidden_size)
self.W_r = nn.Linear(combined_size, hidden_size)
# 候选隐藏状态
self.W = nn.Linear(combined_size, hidden_size)
def forward(self, x, h):
combined = torch.cat([x, h], dim=1)
# 更新门
z = torch.sigmoid(self.W_z(combined))
# 重置门
r = torch.sigmoid(self.W_r(combined))
# 候选隐藏状态
h_tilde = torch.tanh(self.W(torch.cat([r * h, x], dim=1)))
# 新隐藏状态
h_new = (1 - z) * h + z * h_tilde
return h_newLSTM 的应用
| 领域 | 任务 | 说明 |
|---|---|---|
| 自然语言处理 | 语言模型 | 预测下一个词 |
| 机器翻译 | 序列到序列建模 | |
| 情感分析 | 文本分类 | |
| 语音处理 | 语音识别 | 声学建模 |
| 语音合成 | TTS | |
| 时间序列 | 股票预测 | 金融预测 |
| 天气预测 | 气象预测 | |
| 视频分析 | 动作识别 | 时序建模 |
序列到序列模型(Seq2Seq)
class Seq2SeqLSTM(nn.Module):
"""基于 LSTM 的 Seq2Seq 模型"""
def __init__(self, src_vocab_size, tgt_vocab_size, embed_dim, hidden_size):
super().__init__()
# 编码器
self.encoder_embedding = nn.Embedding(src_vocab_size, embed_dim)
self.encoder = nn.LSTM(embed_dim, hidden_size, batch_first=True)
# 解码器
self.decoder_embedding = nn.Embedding(tgt_vocab_size, embed_dim)
self.decoder = nn.LSTM(embed_dim, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, tgt_vocab_size)
self.hidden_size = hidden_size
def forward(self, src, tgt, teacher_forcing_ratio=0.5):
"""
Args:
src: (batch, src_len) 源序列
tgt: (batch, tgt_len) 目标序列
"""
# 编码
src_embed = self.encoder_embedding(src)
_, (h, C) = self.encoder(src_embed)
# 解码
tgt_embed = self.decoder_embedding(tgt)
decoder_output, _ = self.decoder(tgt_embed, (h, C))
logits = self.fc(decoder_output)
return logits
def generate(self, src, max_len=50, sos_token=1, eos_token=2):
"""贪心解码生成"""
# 编码
src_embed = self.encoder_embedding(src)
_, (h, C) = self.encoder(src_embed)
# 解码
generated = [sos_token]
for _ in range(max_len):
tgt_embed = self.decoder_embedding(torch.tensor([[generated[-1]]], device=src.device))
out, (h, C) = self.decoder(tgt_embed, (h, C))
logits = self.fc(out)
next_token = logits.argmax(-1).item()
generated.append(next_token)
if next_token == eos_token:
break
return generatedLSTM vs 标准 RNN
| 特性 | 标准 RNN | LSTM |
|---|---|---|
| 梯度流 | 沿时间步衰减/爆炸 | 门控选择传递 |
| 长期依赖 | 难以学习 | 较好处理 |
| 参数量 | 较少 | 较多(约4倍) |
| 计算复杂度 | 低 | 中等 |
| 门控机制 | 无 | 遗忘/输入/输出门 |
| 细胞状态 | 无 | 有 |
训练 LSTM 的技巧
1. 权重初始化
def init_lstm_weights(model):
for name, param in model.named_parameters():
if 'weight_ih' in name: # 输入到隐藏
nn.init.xavier_uniform_(param)
elif 'weight_hh' in name: # 隐藏到隐藏
nn.init.orthogonal_(param)
elif 'bias' in name:
nn.init.zeros_(param)
# 遗忘门偏置初始化为 1
n = param.size(0)
param.data[n//4:n//2].fill_(1)2. 梯度裁剪
# PyTorch
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)3. 残差连接
class ResidualLSTM(nn.Module):
"""带残差连接的 LSTM"""
def __init__(self, hidden_size):
super().__init__()
self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
self.norm = nn.LayerNorm(hidden_size)
def forward(self, x):
residual = x
out, state = self.lstm(x)
return self.norm(out + residual), state参考
相关阅读
- RNN 基础 — 循环神经网络入门
- 现代 LSTM 变体 — xLSTM、τ-GRU 等最新进展
- LSTM 与状态空间对偶性 — SSM 如何统一 RNN 与 Transformer
- Transformer 与注意力机制 — 完全基于注意力的序列建模
Footnotes
-
Hochreiter, S., & Schmidhuber, J. (1997). “Long Short-Term Memory”. Neural Computation, 9(8), 1735-1780. ↩