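Overview
The Gated Recurrent Unit (GRU), proposed by Cho et al. (2014), is a simplified variant of the LSTM. This document covers:
- GRU architecture: reset gate, update gate, candidate activation
- GRU vs LSTM: parameter counts, performance, ability to capture long-range dependencies
- Theoretical analysis:
  - Can et al. 2020: gating creates slow modes and controls phase-space complexity
  - Hilgert 2025: Lyapunov stability analysis
  - Livi 2025: learnable-window theory
- Modern applications of the GRU:
  - LLM unlearning (Wang et al. 2025)
  - Irregular time series
- GRU variants: MGU (Minimal Gated Unit) and others
- Comparison with modern RNN alternatives: Mamba, RWKV
The GRU is a "lightweight" version of the LSTM: on many tasks it performs comparably while using fewer parameters.1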
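1. GRU Architecture in Detail
1.1 Basic Equations
The GRU controls information flow with two gates.
Reset gate:
$$r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r)$$
Update gate:
$$z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z)$$
Candidate activation:
$$\tilde{h}_t = \tanh\big(W x_t + U (r_t \odot h_{t-1}) + b\big)$$
Final hidden state:
$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$
where:
- $\sigma$: the sigmoid function
- $\tanh$: the hyperbolic tangent activation
- $\odot$: the element-wise (Hadamard) product
- $W_\ast$, $U_\ast$, $b_\ast$: input weights, recurrent weights, and biases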
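The equations above use the convention $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$. PyTorch's built-in `nn.GRUCell` (per its documentation) differs in three details: it keeps separate input and hidden biases, it applies the reset gate after the recurrent projection, i.e. $\tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn}))$, and its $z$ multiplies the previous state, so it plays the role of $1 - z_t$ here. The sketch below recomputes `nn.GRUCell`'s output by hand from its parameters to make the mapping concrete; the helper name `gru_cell_pytorch_convention` is ours.

```python
import torch
import torch.nn as nn

def gru_cell_pytorch_convention(cell: nn.GRUCell, x, h_prev):
    """Recompute nn.GRUCell's output by hand from its parameters."""
    # PyTorch stacks the (r, z, n) blocks along dim 0 of each weight/bias
    W_ir, W_iz, W_in = cell.weight_ih.chunk(3, dim=0)
    W_hr, W_hz, W_hn = cell.weight_hh.chunk(3, dim=0)
    b_ir, b_iz, b_in = cell.bias_ih.chunk(3)
    b_hr, b_hz, b_hn = cell.bias_hh.chunk(3)
    r = torch.sigmoid(x @ W_ir.T + b_ir + h_prev @ W_hr.T + b_hr)
    z = torch.sigmoid(x @ W_iz.T + b_iz + h_prev @ W_hz.T + b_hz)
    # r scales the already-projected hidden term (including its bias)
    n = torch.tanh(x @ W_in.T + b_in + r * (h_prev @ W_hn.T + b_hn))
    # PyTorch's z multiplies h_prev, so it corresponds to 1 - z_t in 1.1
    return (1 - z) * n + z * h_prev

torch.manual_seed(0)
cell = nn.GRUCell(4, 3)
x, h_prev = torch.randn(2, 4), torch.randn(2, 3)
with torch.no_grad():
    manual = gru_cell_pytorch_convention(cell, x, h_prev)
    builtin = cell(x, h_prev)
print(torch.allclose(manual, builtin, atol=1e-6))
```

When porting weights between the equations in 1.1 and `nn.GRU`, the update gate therefore has to be flipped, and the reset gate placement differs slightly.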
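1.2 Intuition Behind the Gates
Reset gate $r_t$:
Controls how much of the past state is ignored when forming the candidate $\tilde{h}_t$.
- $r_t \to 1$: all past information enters the candidate (the default, ordinary RNN behavior)
- $r_t \to 0$: the past is ignored entirely; the candidate depends only on the current input
Update gate $z_t$:
Controls how much of the previous state is carried over.
- $z_t \to 0$: $h_t \approx h_{t-1}$ (the current time step is effectively skipped)
- $z_t \to 1$: $h_t \approx \tilde{h}_t$ (the state is overwritten by the candidate computed from the current input)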
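The limiting cases can be checked numerically with a minimal sketch of the equations from 1.1. Biases are omitted, the random weights are placeholders, and the helper names `candidate` and `update` are ours:

```python
import torch

def candidate(x, h_prev, r, W, U):
    """Candidate activation h~ = tanh(W x + U (r * h_prev)), biases omitted."""
    return torch.tanh(x @ W.T + (r * h_prev) @ U.T)

def update(h_prev, h_hat, z):
    """New state h = (1 - z) * h_prev + z * h_hat."""
    return (1 - z) * h_prev + z * h_hat

torch.manual_seed(0)
W, U = torch.randn(3, 2), torch.randn(3, 3)
x = torch.randn(1, 2)
h_a, h_b = torch.randn(1, 3), torch.randn(1, 3)

# r -> 0: the candidate no longer depends on the previous state
r0 = torch.zeros(1, 3)
same_candidate = torch.allclose(candidate(x, h_a, r0, W, U),
                                candidate(x, h_b, r0, W, U))

# z -> 0 keeps h_prev (step skipped); z -> 1 replaces it with h_hat
h_hat = candidate(x, h_a, torch.ones(1, 3), W, U)
kept = torch.equal(update(h_a, h_hat, torch.zeros(1, 3)), h_a)
overwritten = torch.equal(update(h_a, h_hat, torch.ones(1, 3)), h_hat)
print(same_candidate, kept, overwritten)
```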
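1.3 Correspondence with LSTM
| Concept | GRU | LSTM |
|---|---|---|
| Input gate | - | $i_t$ |
| Forget gate | - | $f_t$ |
| Output gate | - | $o_t$ |
| Update gate | $z_t$ (acts like a coupled forget/input pair) | roughly $f_t \leftrightarrow 1 - z_t$, $i_t \leftrightarrow z_t$ |
| Reset gate | $r_t$ | - |
| Candidate state | $\tilde{h}_t$ | $\tilde{c}_t$ |
| Hidden state | $h_t$ (single state) | $h_t = o_t \odot \tanh(c_t)$, plus a separate cell state $c_t$ |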
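Parameter comparison (hidden dimension $n$, input dimension $m$, one bias vector per gate):
- GRU: $3(n^2 + nm + n)$
- LSTM: $4(n^2 + nm + n)$
The GRU therefore has about 75% (3/4) as many parameters as an LSTM of the same size.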
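A quick way to check these formulas against PyTorch (a sketch; the helper names are ours): `nn.GRU` and `nn.LSTM` keep two bias vectors per gate block (`bias_ih` and `bias_hh`), so the per-layer counts become $3(n^2 + nm + 2n)$ and $4(n^2 + nm + 2n)$, and the 3:4 ratio is unchanged.

```python
import torch.nn as nn

def rnn_param_count(n_gates, m, n, num_bias_vectors=1):
    """Parameters of one layer: n_gates * (n*m + n*n + num_bias_vectors*n)."""
    return n_gates * (n * m + n * n + num_bias_vectors * n)

def count(module):
    return sum(p.numel() for p in module.parameters())

m, n = 100, 256
# PyTorch keeps two bias vectors per gate block (b_ih and b_hh)
gru_formula = rnn_param_count(3, m, n, num_bias_vectors=2)
lstm_formula = rnn_param_count(4, m, n, num_bias_vectors=2)
print(gru_formula == count(nn.GRU(m, n)))
print(lstm_formula == count(nn.LSTM(m, n)))
print(gru_formula / lstm_formula)
```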
1.4 PyTorch Implementation
```python
import torch
import torch.nn as nn


class GRUCellManual(nn.Module):
    """Hand-written GRU cell (for understanding), following the equations in 1.1"""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Reset-gate parameters
        self.W_r = nn.Linear(input_dim, hidden_dim, bias=False)
        self.U_r = nn.Linear(hidden_dim, hidden_dim)
        # Update-gate parameters
        self.W_z = nn.Linear(input_dim, hidden_dim, bias=False)
        self.U_z = nn.Linear(hidden_dim, hidden_dim)
        # Candidate-activation parameters
        self.W = nn.Linear(input_dim, hidden_dim, bias=False)
        self.U = nn.Linear(hidden_dim, hidden_dim)
        self._init_weights()

    def _init_weights(self):
        """Orthogonal weight initialization, zero biases"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, h_prev):
        """
        x: (batch, input_dim)
        h_prev: (batch, hidden_dim)
        """
        # Reset gate
        r = torch.sigmoid(self.W_r(x) + self.U_r(h_prev))
        # Update gate
        z = torch.sigmoid(self.W_z(x) + self.U_z(h_prev))
        # Candidate hidden state
        h_hat = torch.tanh(self.W(x) + self.U(r * h_prev))
        # New hidden state
        h = (1 - z) * h_prev + z * h_hat
        return h


class GRUModel(nn.Module):
    """Full GRU sequence model"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1,
                 dropout=0.0, bidirectional=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.gru = nn.GRU(
            input_dim, hidden_dim, num_layers,
            batch_first=True,
            # inter-layer dropout only makes sense with more than one layer
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )
        # Output layer
        out_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.fc = nn.Linear(out_dim, output_dim)

    def forward(self, x, h0=None):
        """
        x: (batch, seq_len, input_dim)
        returns: output (batch, seq_len, output_dim), h_n (...)
        """
        batch_size = x.size(0)
        if h0 is None:
            num_dirs = 2 if self.bidirectional else 1
            h0 = torch.zeros(
                self.num_layers * num_dirs, batch_size, self.hidden_dim,
                device=x.device
            )
        output, h_n = self.gru(x, h0)
        output = self.fc(output)
        return output, h_n
```

2. Theoretical Analysis of the GRU
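2.1 Can et al.'s "Slow Mode" Theory
Can, Krishnamurthy, and Schwab (PMLR 2020), "Gating creates slow modes and controls phase-space complexity in GRUs and LSTMs", develop a dynamical theory of GRUs and LSTMs.2
Core questions:
How does gating shape the network's phase-space structure? In particular:
- Why can gating learn long-range dependencies?
- How does gating create slow time scales?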
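2.2 Emergence of Slow Modes
Result (Can, Krishnamurthy & Schwab 2020):
A GRU naturally develops slow modes by learning recurrent weight matrices dominated by the identity:
$$U \approx I + \Delta$$
where $I$ is the identity matrix and $\Delta$ is a small perturbation.
Slow time constant (with the update gate held at a small constant value $z$, each step keeps a fraction $1 - z$ of the state):
$$\tau = -\frac{1}{\ln(1 - z)} \approx \frac{1}{z} \quad (z \ll 1)$$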
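The time constant can be checked with a toy recurrence. The sketch assumes a constant gate and a zero candidate, so only the retained fraction $(1 - z)$ of the state survives each step; after about $\tau$ steps roughly $1/e \approx 0.37$ of the initial state remains. The helper name `retention` is ours:

```python
import math

def retention(z, steps):
    """Fraction of h_0 left after `steps` updates with constant z and h~ = 0."""
    h = 1.0
    for _ in range(steps):
        h = (1 - z) * h  # h_t = (1 - z) h_{t-1} + z * 0
    return h

for z in (0.5, 0.1, 0.01):
    tau = -1.0 / math.log(1 - z)
    print(f"z={z:<5} tau={tau:7.2f}  1/z={1 / z:6.1f}  "
          f"retention after round(tau) steps={retention(z, round(tau)):.3f}")
```

The smaller the update gate, the closer $\tau$ gets to $1/z$ and the longer information persists.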
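2.3 Spectral Analysis
Near a fixed point, treating the gates as locally constant (gate derivatives neglected), the GRU Jacobian is:
$$J = \frac{\partial h_t}{\partial h_{t-1}} \approx \operatorname{diag}(1 - z_t) + \operatorname{diag}\big(z_t \odot (1 - \tilde{h}_t^2)\big)\, U \operatorname{diag}(r_t)$$
Key observation: when $z_t \to 0$, $J \approx I$ and its eigenvalues approach 1 (slow modes); when $z_t \to 1$, $J$ is determined by the candidate-activation term (fast modes).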
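The Jacobian expression can be checked against autograd in a simplified setting chosen for this sketch: the gates depend only on $x_t$ (no $U_r$, $U_z$ terms), which makes the formula exact; in a full GRU, gate derivatives add further terms. A strongly negative update-gate bias (`b_z = -8`, our choice) pushes $z_t$ toward 0, and the eigenvalues of $J$ then cluster near 1:

```python
import torch

torch.manual_seed(0)
n, m = 4, 3
W_r, W = torch.randn(n, m), torch.randn(n, m)
W_z = 0.1 * torch.randn(n, m)
U = 0.5 * torch.randn(n, n)
b_z = torch.full((n,), -8.0)  # pushes z toward 0 (slow regime)
x = torch.randn(m)

def step(h_prev):
    # Gates depend only on x here, so the Jacobian formula below is exact
    r = torch.sigmoid(W_r @ x)
    z = torch.sigmoid(W_z @ x + b_z)
    h_hat = torch.tanh(W @ x + U @ (r * h_prev))
    return (1 - z) * h_prev + z * h_hat

h = torch.randn(n)
J_auto = torch.autograd.functional.jacobian(step, h)

r = torch.sigmoid(W_r @ x)
z = torch.sigmoid(W_z @ x + b_z)
h_hat = torch.tanh(W @ x + U @ (r * h))
J_formula = torch.diag(1 - z) + torch.diag(z * (1 - h_hat ** 2)) @ U @ torch.diag(r)

print(torch.allclose(J_auto, J_formula, atol=1e-5))
print(torch.linalg.eigvals(J_formula).abs())  # all close to 1 because z is tiny
```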
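2.4 Phase-Space Complexity
Gating controls the effective dimension $d_{\text{eff}}$ of the dynamics.
Empirical findings:
- During training, the GRU tends to settle on the minimal $d_{\text{eff}}$ the task requires
- The LSTM tends toward a higher $d_{\text{eff}}$ that is harder to control
- This partly explains the GRU's parameter efficiency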
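The section does not specify how $d_{\text{eff}}$ is measured. One common choice, assumed here, is the participation ratio $d_{\text{eff}} = (\sum_i \lambda_i)^2 / \sum_i \lambda_i^2$ over the eigenvalues $\lambda_i$ of the hidden-state covariance. A sketch on an untrained GRU driven by random inputs, so the resulting value says nothing about the empirical claims above; `participation_ratio` is a hypothetical helper name:

```python
import torch
import torch.nn as nn

def participation_ratio(states):
    """d_eff = (sum λ)^2 / sum λ^2 over eigenvalues λ of the state covariance.

    states: (num_samples, hidden_dim)
    """
    centered = states - states.mean(dim=0, keepdim=True)
    cov = centered.T @ centered / (states.size(0) - 1)
    eig = torch.linalg.eigvalsh(cov).clamp(min=0)  # drop tiny negative round-off
    return (eig.sum() ** 2 / (eig ** 2).sum()).item()

torch.manual_seed(0)
gru = nn.GRU(input_size=8, hidden_size=64, batch_first=True)
with torch.no_grad():
    out, _ = gru(torch.randn(16, 200, 8))  # (batch, seq_len, hidden)
d_eff = participation_ratio(out.reshape(-1, 64))
print(f"d_eff ≈ {d_eff:.1f} out of 64 hidden units")
```

The participation ratio ranges from 1 (all variance in one direction) to the hidden size (variance spread evenly), which makes it a convenient scalar for tracking $d_{\text{eff}}$ during training.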
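2.5 Hilgert's Lyapunov Stability Analysis
Hilgert & Schwung (arXiv 2505.11539, 2025) analyze the absolute stability of GRUs/LSTMs with Lure-Postnikov Lyapunov functions.3
Key result:
When the GRU/LSTM gates are nearly constant ($z_t \approx z^*$), the system admits a quadratic Lyapunov function
$$V(h) = h^\top P h, \quad P \succ 0.$$
Condition: there exists $P \succ 0$ such that $A^\top P A - P \prec 0$, where $A = (1 - z^*) I + z^* U$ is the linearized state matrix. This is the discrete-time form of the condition, since the GRU update is a discrete-time map.
Implications for closed-loop control:
When a GRU/LSTM is used as a virtual sensor, its gates must satisfy the absolute-stability condition. If the gates vary sharply, the Lyapunov condition can be violated, leading to instability.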
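For the frozen-gate linearization $h_t = A h_{t-1}$ with $A = (1 - z^*) I + z^* U$, a matrix $P \succ 0$ with $A^\top P A - P \prec 0$ exists exactly when every eigenvalue of $A$ lies strictly inside the unit circle, and one such $P$ solves the discrete Lyapunov equation $A^\top P A - P = -Q$ for any $Q \succ 0$. The sketch below computes that $P$ with a Kronecker-product linear solve (our choice, fine for small hidden sizes; it is not the method of Hilgert & Schwung), and `lyapunov_certificate` is a hypothetical helper name:

```python
import torch

def lyapunov_certificate(A, Q=None):
    """Solve A^T P A - P = -Q for P; return P if it is positive definite, else None.

    For the discrete-time map h_{t+1} = A h_t, a positive-definite solution exists
    iff every eigenvalue of A lies strictly inside the unit circle.
    """
    n = A.shape[0]
    Q = torch.eye(n, dtype=A.dtype) if Q is None else Q
    # Row-major vectorization: vec(A^T P A) = kron(A^T, A^T) vec(P)
    K = torch.kron(A.T, A.T)
    p = torch.linalg.solve(torch.eye(n * n, dtype=A.dtype) - K, Q.reshape(-1))
    P = p.reshape(n, n)
    P = (P + P.T) / 2  # symmetrize against round-off
    return P if bool(torch.all(torch.linalg.eigvalsh(P) > 0)) else None

torch.manual_seed(0)
n, z_star = 6, 0.5
U = 0.2 * torch.randn(n, n, dtype=torch.float64)
A = (1 - z_star) * torch.eye(n, dtype=torch.float64) + z_star * U  # frozen-gate linearization
P = lyapunov_certificate(A)
print("spectral radius:", torch.linalg.eigvals(A).abs().max().item())
print("certificate found:", P is not None)
```

If the gate value or the recurrent weights push an eigenvalue of $A$ outside the unit circle, the solve still returns a matrix, but it is no longer positive definite and the helper returns `None`.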
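```python
import torch


def check_lure_postnikov_condition(gru_cell, P, z_star=0.5):
    """
    Simplified Lyapunov check for a GRUCellManual (section 1.4).

    Assumes the update gate is frozen at z_star (an average gate value) and the
    candidate path is linearized (tanh' ~ 1, r ~ 1), so h_t ~ A h_{t-1} with
    A = (1 - z*) I + z* U.
    P: symmetric positive-definite matrix, (hidden_dim, hidden_dim)
    Returns (is_stable, M) with M = A^T P A - P.
    """
    # Recurrent weight of the candidate activation
    U = gru_cell.U.weight.detach()
    n = U.size(0)
    # Linearized state matrix at the equilibrium (constant-gate case)
    A = (1 - z_star) * torch.eye(n, dtype=U.dtype) + z_star * U
    # The GRU update is a discrete-time map, so V(h) = h^T P h decreases
    # along trajectories iff A^T P A - P is negative definite
    M = A.t() @ P @ A - P
    is_stable = bool(torch.all(torch.linalg.eigvalsh(M) < 0))
    return is_stable, M
```

2.6 Livi's Learnable Window Theory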
See the related discussion in vanilla-rnn-deep-theory.md. Livi (2025) derives the learnable window of a GRU as a function of the average effective value of its gates.
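3. GRU vs LSTM: Empirical Comparison
3.1 Classic Benchmarks (Chung et al. 2014)
| Task | GRU | LSTM |
|---|---|---|
| Speech signal modeling | comparable | comparable |
| Music modeling | comparable (the better model varies by dataset) | comparable (the better model varies by dataset) |
| Natural language processing | comparable | comparable |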
3.2 Parameter Efficiency
```python
import torch.nn as nn


def count_params(gru, lstm):
    """Compare the parameter counts of a GRU and an LSTM"""
    return {
        'GRU': sum(p.numel() for p in gru.parameters()),
        'LSTM': sum(p.numel() for p in lstm.parameters())
    }


input_dim = 100
hidden_dim = 256
gru = nn.GRU(input_dim, hidden_dim, num_layers=2)
lstm = nn.LSTM(input_dim, hidden_dim, num_layers=2)
params = count_params(gru, lstm)
print(f"GRU params: {params['GRU']:,}")              # GRU params: 669,696
print(f"LSTM params: {params['LSTM']:,}")            # LSTM params: 892,928
print(f"Ratio: {params['GRU']/params['LSTM']:.2%}")  # Ratio: 75.00%
```

3.3 Training Speed Comparison
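Rules of thumb (they depend on hardware, sequence length, and kernel implementation):
- A GRU trains roughly 1.3-1.5× as fast as an LSTM
- A GRU uses roughly 75-80% of an LSTM's memory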
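Because these ratios vary with hardware, sequence length, and whether cuDNN kernels are used, it is worth measuring them on your own setup. A minimal timing sketch; the sizes are arbitrary placeholders and `time_forward_backward` is a hypothetical helper name:

```python
import time
import torch
import torch.nn as nn

def time_forward_backward(rnn, x, repeats=5):
    """Average wall-clock seconds for one forward + backward pass."""
    rnn(x)[0].sum().backward()  # warm-up (kernel selection, allocations)
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        rnn.zero_grad()
        rnn(x)[0].sum().backward()
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    return (time.perf_counter() - start) / repeats

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(32, 100, 64, device=device)  # (batch, seq_len, input_dim)
gru = nn.GRU(64, 256, batch_first=True).to(device)
lstm = nn.LSTM(64, 256, batch_first=True).to(device)
t_gru, t_lstm = time_forward_backward(gru, x), time_forward_backward(lstm, x)
print(f"GRU {t_gru * 1e3:.1f} ms, LSTM {t_lstm * 1e3:.1f} ms, ratio {t_lstm / t_gru:.2f}x")
```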
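3.4 When to Choose Which
Choose a GRU when:
- data is limited
- model size is a constraint
- training time is limited
- the sequence task is simple
Choose an LSTM when:
- sequences are very long (> 1000 steps)
- separate cell and hidden states are needed
- LSTM-based infrastructure already exists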
4. GRU Variants
4.1 MGU (Minimal Gated Unit)
The minimal gated unit keeps a single gate (a forget gate) that also takes over the role of the reset gate.
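In the notation of 1.1, the MGU (Zhou et al., 2016) can be written as follows. The single gate $f_t$ both scales the previous state inside the candidate (the reset role) and interpolates the new state (the update role), so a layer needs two weight blocks instead of the GRU's three:

$$
\begin{aligned}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
\tilde{h}_t &= \tanh\big(W x_t + U (f_t \odot h_{t-1}) + b\big) \\
h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t
\end{aligned}
$$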
```python
import torch
import torch.nn as nn


class MGUCell(nn.Module):
    """Minimal gated unit"""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.W_f = nn.Linear(input_dim, hidden_dim)
        self.U_f = nn.Linear(hidden_dim, hidden_dim)
        self.W = nn.Linear(input_dim, hidden_dim)
        self.U = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, h_prev):
        # The single forget gate
        f = torch.sigmoid(self.W_f(x) + self.U_f(h_prev))
        # f also acts as the reset gate inside the candidate
        h_hat = torch.tanh(self.W(x) + self.U(f * h_prev))
        h = (1 - f) * h_prev + f * h_hat
        return h
```

4.2 Bidirectional GRU
```python
import torch
import torch.nn as nn


class BiGRUClassifier(nn.Module):
    """Bidirectional GRU classifier"""

    def __init__(self, input_dim, hidden_dim, num_classes, num_layers=2):
        super().__init__()
        self.gru = nn.GRU(
            input_dim, hidden_dim, num_layers,
            batch_first=True, bidirectional=True,
            dropout=0.2 if num_layers > 1 else 0.0  # inter-layer dropout only
        )
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        # x: (batch, seq_len, input_dim)
        output, h_n = self.gru(x)
        # h_n: (num_layers * 2, batch, hidden_dim); the last two rows are the
        # top layer's final forward and backward states
        h_fwd = h_n[-2]  # top layer, forward direction
        h_bwd = h_n[-1]  # top layer, backward direction
        h_combined = torch.cat([h_fwd, h_bwd], dim=-1)
        return self.classifier(h_combined)
```

4.3 Stacked GRU
```python
import torch.nn as nn


class StackedGRU(nn.Module):
    """Multi-layer GRU"""

    def __init__(self, input_dim, hidden_dim, num_layers=3, dropout=0.3):
        super().__init__()
        self.gru = nn.GRU(
            input_dim, hidden_dim, num_layers,
            batch_first=True,
            dropout=dropout  # applied between layers, not after the last one
        )

    def forward(self, x):
        # x: (batch, seq_len, input_dim)
        # nn.GRU runs the stacked layers one after another internally
        output, h_n = self.gru(x)
        return output, h_n
```

4.4 Peephole GRU
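Let the gates also see the state (a GRU has no separate cell state, so the peephole looks at the previous hidden state):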
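```python
import torch
import torch.nn as nn
class PeepholeGRUCell(nn.Module):
    """Peephole GRU: the gates get an extra "peephole" view of h_prev"""
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # No cell state in a GRU, so V_* reads h_prev (linear, hence absorbable into U_*)
        self.W_r = nn.Linear(input_dim, hidden_dim, bias=False)
        self.U_r = nn.Linear(hidden_dim, hidden_dim)
        self.V_r = nn.Linear(hidden_dim, hidden_dim, bias=False)  # peephole
        self.W_z = nn.Linear(input_dim, hidden_dim, bias=False)
        self.U_z = nn.Linear(hidden_dim, hidden_dim)
        self.V_z = nn.Linear(hidden_dim, hidden_dim, bias=False)  # peephole
        self.W = nn.Linear(input_dim, hidden_dim, bias=False)
        self.U = nn.Linear(hidden_dim, hidden_dim)
    def forward(self, x, h_prev):
        r = torch.sigmoid(self.W_r(x) + self.U_r(h_prev) + self.V_r(h_prev))
        z = torch.sigmoid(self.W_z(x) + self.U_z(h_prev) + self.V_z(h_prev))
        h_hat = torch.tanh(self.W(x) + self.U(r * h_prev))
        return (1 - z) * h_prev + z * h_hat
```

5. GRU in LLM Unlearning (2025)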
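5.1 Background
Wang et al. (arXiv 2503.09117, 2025), "GRU: Mitigating the Trade-off between Unlearning and Retention for LLMs".4 Note that in this paper "GRU" stands for Gradient Rectified Unlearning, not the gated recurrent unit.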
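Challenges of LLM unlearning:
- Complete unlearning damages the model's general capabilities
- Partial unlearning may still leak private information
- The degree of forgetting must be controlled precisely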
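5.2 The GRU Method
Core idea: control the unlearning update so that it does not damage retention. The paper does this by rectifying the unlearning gradient, projecting it onto the space orthogonal to the gradient direction that would harm retention. The gating module below is only an illustrative sketch of fine-grained control over forgetting, not the paper's algorithm.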
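Wang et al. describe their method as projecting the unlearning gradient away from the direction that harms retention. A generic single-step version of such a projection, on flattened gradient vectors, is sketched below. It is our simplified illustration (similar in spirit to PCGrad), not the paper's exact closed-form solution, and `rectify_unlearning_grad` is a hypothetical name:

```python
import torch

def rectify_unlearning_grad(g_unlearn, g_retain, eps=1e-12):
    """Remove the component of the unlearning gradient that conflicts with retention.

    A step theta -= lr * g changes the retention loss by about -lr * <g, g_retain>,
    so a negative inner product means the step would increase the retention loss;
    in that case project g onto the hyperplane orthogonal to g_retain.
    """
    dot = torch.dot(g_unlearn, g_retain)
    if dot >= 0:
        return g_unlearn  # no conflict: leave the gradient unchanged
    return g_unlearn - dot / (g_retain.dot(g_retain) + eps) * g_retain

g_retain = torch.tensor([1.0, 0.0])
conflicting = torch.tensor([-1.0, 1.0])  # would increase the retention loss
compatible = torch.tensor([1.0, 1.0])    # also lowers the retention loss
print(rectify_unlearning_grad(conflicting, g_retain))
print(rectify_unlearning_grad(compatible, g_retain))
```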
```python
import torch
import torch.nn as nn


class UnlearningGate(nn.Module):
    """Illustrative unlearning gate"""

    def __init__(self, hidden_dim, target_modules):
        super().__init__()
        # One gate logit per target module
        self.gate = nn.Linear(hidden_dim, len(target_modules))
        self.target_modules = target_modules

    def forward(self, x, unlearn_signal):
        """
        x: hidden states, (batch, hidden_dim)
        unlearn_signal: unlearning strength (0-1)
        """
        # Per-module forgetting weight in (0, 1), (batch, num_modules)
        weights = torch.sigmoid(self.gate(x) + unlearn_signal)
        # Apply every module's forgetting weight multiplicatively
        output = x * torch.prod(1 - weights, dim=-1, keepdim=True)
        return output, weights
```

5.3 Experimental Results
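Unlearning-retention trade-off (as reported):
- Full unlearning: 95% forgetting rate with a 30% performance drop
- GRU method: 93% forgetting rate with only an 8% performance drop
Conclusion: controlling how the unlearning update is applied can markedly improve the precision of unlearning.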
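6. Irregular Time Series Modeling
6.1 Problem
Joshi & Hauskrecht (TMLR 2026), "Still Competitive: Revisiting Recurrent Models for Irregular Time Series Prediction", revisit how well RNNs handle irregular time series.5
Problem setting:
- observation timestamps are irregular
- values may be missing
- different features may be sampled at different rates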
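A common way to expose these irregularities to a recurrent model is to feed it, alongside the values and an observation mask, the time elapsed since each feature was last observed. This is the $\delta$ input of GRU-D (Che et al., 2018), not something taken from Joshi & Hauskrecht. A minimal sketch; the helper name `time_since_last_observation` is ours:

```python
import torch

def time_since_last_observation(timestamps, mask):
    """Per-feature time since the last observation (the delta input of GRU-D).

    timestamps: (batch, seq_len) observation times, increasing along seq_len
    mask:       (batch, seq_len, num_features), 1 = observed, 0 = missing
    returns:    (batch, seq_len, num_features)
    """
    batch, seq_len, num_features = mask.shape
    delta = torch.zeros(batch, seq_len, num_features, dtype=timestamps.dtype)
    for t in range(1, seq_len):
        gap = (timestamps[:, t] - timestamps[:, t - 1]).unsqueeze(-1)  # (batch, 1)
        # Restart from the gap after an observation, otherwise keep accumulating
        delta[:, t] = gap + (1 - mask[:, t - 1]) * delta[:, t - 1]
    return delta

timestamps = torch.tensor([[0.0, 1.0, 3.0]])
mask = torch.tensor([[[1.0, 1.0], [0.0, 1.0], [1.0, 1.0]]])  # feature 0 missing at t=1
print(time_since_last_observation(timestamps, mask)[0])
```

Feature 0 is missing at the second step, so at the third step its delta accumulates both gaps (1 + 2), while the always-observed feature 1 only sees the latest gap.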
6.2 Adapting the GRU
```python
import torch
import torch.nn as nn


class GRUWithTimeGap(nn.Module):
    """GRU that accounts for the time gaps between observations"""

    def __init__(self, input_dim, hidden_dim, decay_rate=0.1):
        super().__init__()
        self.gru_cell = nn.GRUCell(input_dim, hidden_dim)
        self.time_encoder = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, hidden_dim)
        )
        self.decay_rate = decay_rate

    def forward(self, x, timestamps):
        """
        x: (batch, seq_len, input_dim)
        timestamps: (batch, seq_len) observation times (float)
        """
        batch_size, seq_len, _ = x.shape
        h = torch.zeros(batch_size, self.gru_cell.hidden_size, device=x.device)
        # Start from the first timestamp so the first gap is 0, not the absolute time
        prev_time = timestamps[:, :1]
        outputs = []
        for t in range(seq_len):
            # Time gap since the previous observation, (batch, 1)
            dt = timestamps[:, t:t+1] - prev_time
            # Encode the gap
            time_encoding = self.time_encoder(dt)
            # Decay the hidden state toward the time encoding as the gap grows
            decay = torch.exp(-self.decay_rate * dt)
            h = h * decay + time_encoding * (1 - decay)
            # Standard GRU update
            h = self.gru_cell(x[:, t, :], h)
            outputs.append(h)
            prev_time = timestamps[:, t:t+1]
        return torch.stack(outputs, dim=1), h
```

6.3 Conclusion
Joshi & Hauskrecht show that, with simple modifications, GRUs remain competitive on irregular time series and can even outperform some Transformer variants.
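7. Modern Alternatives to the GRU
7.1 Performance Comparison
| Model | Training speed | Inference speed | Long sequences | Memory efficiency |
|---|---|---|---|---|
| GRU | fast | fast | medium | high |
| LSTM | slow | slow | medium | low |
| Transformer | medium | slow | strong | low |
| Mamba | fast | fast | strong | high |
| RWKV | fast | fast | medium | high |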
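7.2 Mamba vs GRU
Mamba (state space model):
- selective state space with input-dependent parameters
- parallel training + fast inference
- clear advantage on long sequences
GRU:
- input-dependent gates, but a recurrence that is nonlinear in $h_{t-1}$, so training must step through time sequentially
- simple, stable, deployment-friendly
- sufficient for medium-length sequences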
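The parallel-training point hinges on linearity: a recurrence that is linear in the state, $h_t = a_t \odot h_{t-1} + b_t$, can be unrolled in closed form with cumulative products and sums instead of a step-by-step loop, whereas the GRU's $\tanh(U(r_t \odot h_{t-1}))$ term rules this out. A simplified diagonal sketch; real Mamba kernels use a hardware-aware parallel scan, and the cumprod/cumsum form here is only for illustration (it is numerically fragile on long sequences):

```python
import torch

def linear_scan_sequential(a, b):
    """h_t = a_t * h_{t-1} + b_t with h_{-1} = 0, one step at a time."""
    h, out = torch.zeros_like(b[0]), []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return torch.stack(out)

def linear_scan_closed_form(a, b):
    """Same recurrence without a Python loop over time:
    h_t = P_t * sum_{k<=t} b_k / P_k, with P_t = a_0 * ... * a_t."""
    P = torch.cumprod(a, dim=0)
    return P * torch.cumsum(b / P, dim=0)

torch.manual_seed(0)
a = torch.rand(20, 4, dtype=torch.float64) * 0.5 + 0.5  # decay factors in [0.5, 1)
b = torch.randn(20, 4, dtype=torch.float64)
print(torch.allclose(linear_scan_sequential(a, b), linear_scan_closed_form(a, b)))
```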
7.3 Selection Guidelines
```python
def select_sequence_model(seq_len, data_size, latency_requirement):
    """Pick a sequence model from coarse requirements (rules of thumb)"""
    if latency_requirement == 'real-time':
        if seq_len < 100:
            return 'GRU'  # fastest
        else:
            return 'Mamba'
    elif data_size < 10000:
        return 'GRU'  # clear advantage when data is scarce
    elif seq_len < 500:
        return 'GRU'  # medium-length sequences
    elif seq_len < 4000:
        return 'Mamba'  # long sequences
    else:
        return 'Mamba-2 or RWKV'  # very long sequences
```

8. GRU Training Tips
8.1 Learning Rate Schedule and Training Configuration
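The configuration below fixes the initial learning rate but no schedule. One common option, assumed here rather than taken from the source, is to halve the Adam learning rate when the validation loss plateaus, using `torch.optim.lr_scheduler.ReduceLROnPlateau`. The loop feeds a constant placeholder loss so the effect is visible:

```python
import torch
import torch.nn as nn

model = nn.GRU(32, 256, num_layers=2, batch_first=True, dropout=0.2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# Halve the learning rate once the validation loss has not improved for more than 3 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min",
                                                       factor=0.5, patience=3)

for epoch in range(10):
    # ... training steps with torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) ...
    val_loss = 1.0  # placeholder: a validation loss that never improves
    scheduler.step(val_loss)

print(optimizer.param_groups[0]["lr"])
```

With a stalled loss the rate drops twice in ten epochs (1e-3 to 5e-4 to 2.5e-4); in real training the scheduler only reacts when the validation loss genuinely stops improving.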
```python
def get_gru_training_config():
    """Recommended GRU training configuration"""
    return {
        'optimizer': 'adam',
        'learning_rate': 1e-3,
        'gradient_clip': 1.0,
        'batch_size': 64,
        'weight_decay': 1e-5,
        'dropout': 0.2,
        'num_layers': 2,
        'hidden_dim': 256,
        'bidirectional': True,  # helps on most tasks where the whole sequence is available
        'num_epochs': 50,
    }
```

8.2 Sequence Padding Strategy
```python
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class GRUPadded(nn.Module):
    """GRU for variable-length sequences"""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, lengths):
        """
        x: (batch, max_seq_len, input_dim)
        lengths: (batch,) actual lengths, int64
        """
        # Pack the variable-length sequences so the GRU skips the padding
        packed = pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        output, h_n = self.gru(packed)
        # Unpack back to a padded tensor
        output, _ = pad_packed_sequence(output, batch_first=True)
        # Gather the output at each sequence's last valid time step
        # (for this single-layer unidirectional GRU it equals h_n[-1])
        idx = (lengths - 1).view(-1, 1).expand(-1, output.size(2)).unsqueeze(1)
        last_output = output.gather(1, idx.to(output.device)).squeeze(1)
        return self.fc(last_output)
```

8.3 Teacher Forcing
```python
import torch


class GRUTrainerWithTF:
    """GRU training with teacher forcing.

    Assumes next-step prediction with output_dim == input_dim and a model
    exposing .gru_cell (nn.GRUCell), .fc and .hidden_dim.
    """

    def __init__(self, model, teacher_forcing_ratio=0.5):
        self.model = model
        self.tfr = teacher_forcing_ratio

    def train_step(self, x, y, optimizer, criterion):
        """
        x: (batch, seq_len, input_dim)
        y: (batch, seq_len, output_dim); y[:, t] is the target after step t
        """
        self.model.train()
        batch_size, seq_len, _ = x.shape
        # Initialization
        h = torch.zeros(batch_size, self.model.hidden_dim, device=x.device)
        inp = x[:, 0, :]  # the first step always sees the real input
        outputs = []
        for t in range(seq_len):
            h = self.model.gru_cell(inp, h)
            out = self.model.fc(h)
            outputs.append(out)
            if t < seq_len - 1:
                if torch.rand(1).item() < self.tfr:
                    inp = y[:, t, :]  # teacher forcing: feed the ground truth
                else:
                    inp = out  # free running: feed the model's own prediction
        outputs = torch.stack(outputs, dim=1)
        loss = criterion(outputs, y)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        optimizer.step()
        return loss.item()
```

8.4 Curriculum Learning
```python
import torch


class GRUCurriculumTraining:
    """Curriculum learning for a GRU: start with short sequences"""

    def __init__(self, model, max_seq_len):
        self.model = model
        self.max_seq_len = max_seq_len

    def train_epoch(self, train_data, optimizer, criterion, epoch, total_epochs):
        # Sequence length for this epoch grows linearly to max_seq_len (at least 1)
        current_len = max(1, min(
            self.max_seq_len,
            int(self.max_seq_len * (epoch + 1) / total_epochs)
        ))
        # Crop every sequence to current_len
        truncated_data = []
        for x, y in train_data:
            if x.size(1) > current_len:
                # Random start position
                start = torch.randint(0, x.size(1) - current_len + 1, (1,)).item()
                x = x[:, start:start+current_len]
                y = y[:, start:start+current_len]
            truncated_data.append((x, y))
        # Regular training on the cropped batches
        self.model.train()
        total_loss = 0.0
        for x, y in truncated_data:
            optimizer.zero_grad()
            output, _ = self.model(x)  # model returns (output, h_n), e.g. GRUModel in 1.4
            loss = criterion(output, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
        return total_loss / len(truncated_data)
```

9. Practical Applications of the GRU
9.1 Text Classification
```python
import torch.nn as nn
import torch.nn.functional as F


class GRUTextClassifier(nn.Module):
    """GRU text classifier with attention pooling"""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes,
                 num_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.gru = nn.GRU(
            embedding_dim, hidden_dim, num_layers,
            batch_first=True, dropout=dropout, bidirectional=True
        )
        self.attention = nn.Linear(2 * hidden_dim, 1)
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes)
        )

    def forward(self, x, mask=None):
        # x: (batch, seq_len) token ids; mask: (batch, seq_len), 0 = padding
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)
        output, _ = self.gru(embedded)  # (batch, seq_len, 2*hidden)
        # Attention pooling over time steps
        attn_weights = self.attention(output).squeeze(-1)  # (batch, seq_len)
        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(attn_weights, dim=-1)
        context = (output * attn_weights.unsqueeze(-1)).sum(dim=1)
        return self.classifier(context)
```

9.2 Speech Recognition
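```python
import torch.nn as nn
import torch.nn.functional as F


class GRUSpeechRecognizer(nn.Module):
    """GRU acoustic model for CTC training"""

    def __init__(self, input_dim, hidden_dim, num_classes, num_layers=5):
        super().__init__()
        # Deep stack of GRU layers
        self.gru = nn.GRU(
            input_dim, hidden_dim, num_layers,
            batch_first=True, dropout=0.2
        )
        # Per-frame classifier for CTC (num_classes includes the blank symbol)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, input_dim), e.g. MFCC features
        output, _ = self.gru(x)
        logits = self.classifier(output)
        # CTC loss expects log-probabilities
        log_probs = F.log_softmax(logits, dim=-1)
        # nn.CTCLoss expects (seq_len, batch, num_classes): pass log_probs.transpose(0, 1)
        return log_probs
```

9.3 Financial Forecasting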
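```python
import torch
import torch.nn as nn


class GRUForecaster(nn.Module):
    """GRU encoder-decoder for multi-step time series forecasting"""

    def __init__(self, input_dim, hidden_dim, forecast_horizon):
        super().__init__()
        self.forecast_horizon = forecast_horizon
        self.encoder = nn.GRU(input_dim, hidden_dim, num_layers=2, batch_first=True)
        self.decoder = nn.GRU(input_dim, hidden_dim, num_layers=2, batch_first=True)
        # Map each decoder state back to feature space so it can be fed back in
        self.fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, x, forecast_steps=None):
        """
        x: (batch, seq_len, input_dim) historical sequence
        returns: (batch, forecast_steps, input_dim)
        """
        steps = forecast_steps or self.forecast_horizon
        # Encode the history
        _, h = self.encoder(x)
        # Autoregressive decoding, starting from the last observed step
        decoder_input = x[:, -1:, :]
        preds = []
        for _ in range(steps):
            output, h = self.decoder(decoder_input, h)
            pred = self.fc(output)  # (batch, 1, input_dim)
            preds.append(pred)
            decoder_input = pred  # feed the prediction in as the next input
        return torch.cat(preds, dim=1)
```

10. References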
Last updated: 2026-06-21
Footnotes
1. Cho, K., et al. (2014). Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation. EMNLP 2014.
2. Can, T., Krishnamurthy, K., & Schwab, D.J. (2020). Gating creates slow modes and controls phase-space complexity in GRUs and LSTMs. PMLR 107:476-511. http://proceedings.mlr.press/v107/can20a/can20a.pdf
3. Hilgert, E. & Schwung, A. (2025). Lure-Postnikov Stability Analysis of Closed-Loop Control Systems with Gated Recurrent Neural Network-based Virtual Sensors. arXiv:2505.11539. https://www.arxiv.org/pdf/2505.11539
4. Wang, Y., et al. (2025). GRU: Mitigating the Trade-off between Unlearning and Retention for LLMs. arXiv:2503.09117. https://arxiv.org/html/2503.09117v3
5. Joshi, A. & Hauskrecht, M. (2026). Still Competitive: Revisiting Recurrent Models for Irregular Time Series Prediction. TMLR 01/2026. https://arxiv.org/pdf/2510.16161