Overview

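The Gated Recurrent Unit (GRU), proposed by Cho et al. (2014), is a simplified variant of the LSTM. This document systematically covers:

  1. GRU architecture in detail: reset gate, update gate, candidate activation
  2. GRU vs LSTM: parameter counts, performance, long-range dependency capability
  3. Theoretical analysis
    • Can et al. 2020: gating creates slow modes and controls phase-space complexity
    • Hilgert & Schwung 2025: Lyapunov stability analysis
    • Livi 2025: learnable-window theory
  4. Modern applications of GRU
    • LLM unlearning (Wang et al. 2025; in that paper "GRU" stands for Gradient Rectified Unlearning)
    • Irregular time series
  5. GRU variants: MGU (Minimal Gated Unit) and others
  6. Comparison with modern RNN alternatives: Mamba, RWKV

GRU is a "lightweight" version of the LSTM: on many tasks it performs comparably while using fewer parameters.[1]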


I. GRU Architecture in Detail

1.1 Basic Equations

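The GRU controls information flow with two gates.

Reset gate

$$r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r)$$

Update gate

$$z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z)$$

Candidate activation

$$\tilde{h}_t = \tanh\big(W x_t + U (r_t \odot h_{t-1}) + b\big)$$

Final hidden state

$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$

where:

  • $\sigma$: the sigmoid function
  • $\tanh$: the hyperbolic tangent activation
  • $\odot$: element-wise (Hadamard) product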
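As a dependency-free sketch, the four equations can be evaluated directly on Python lists. `gru_step`, `affine` and the parameter keys (`W_r`, `U_r`, `b_r`, ...) are hypothetical names chosen to mirror the symbols above; real models would use the PyTorch modules in 1.4.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def matvec(M, v):
    """Matrix-vector product on nested lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def affine(W, x, U, h, b):
    """W x + U h + b, element-wise over the hidden dimension."""
    return [a + c + bb for a, c, bb in zip(matvec(W, x), matvec(U, h), b)]


def gru_step(x, h_prev, p):
    """One GRU step following the four equations above (z -> 1 overwrites).
    p: dict of hypothetical parameters W_r, U_r, b_r, W_z, U_z, b_z, W, U, b."""
    r = [sigmoid(v) for v in affine(p["W_r"], x, p["U_r"], h_prev, p["b_r"])]
    z = [sigmoid(v) for v in affine(p["W_z"], x, p["U_z"], h_prev, p["b_z"])]
    rh = [ri * hi for ri, hi in zip(r, h_prev)]  # r_t ⊙ h_{t-1}
    h_hat = [math.tanh(v) for v in affine(p["W"], x, p["U"], rh, p["b"])]
    # h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h_hat_t
    return [(1 - zi) * hi + zi * ci for zi, hi, ci in zip(z, h_prev, h_hat)]


# With all weights and biases zero: r = z = 0.5 and h_hat = 0, so h_t = 0.5 * h_{t-1}
zero_w = [[0.0, 0.0], [0.0, 0.0]]
params = {k: zero_w for k in ("W_r", "U_r", "W_z", "U_z", "W", "U")}
params.update({k: [0.0, 0.0] for k in ("b_r", "b_z", "b")})
print(gru_step([1.0, 2.0], [0.4, -0.6], params))  # [0.2, -0.3]
```

The all-zero case is a handy sanity check: both gates sit at 0.5 and the candidate is 0, so the state is simply halved.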

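1.2 Intuition Behind the Gates

Reset gate $r_t$

Controls how much past information to ignore when forming the candidate $\tilde{h}_t$:

  • $r_t \to 1$: keep all past information (default)
  • $r_t \to 0$: ignore the past entirely and look only at the current input

Update gate $z_t$

Controls how much of the past state to keep:

  • $z_t \to 0$: $h_t \approx h_{t-1}$ (the current time step is skipped)
  • $z_t \to 1$: $h_t \approx \tilde{h}_t$ (the state is overwritten by the current input)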
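The two limits of the update gate can be checked numerically with a scalar sketch of the final interpolation. The values 0.8 and -0.5 and the ±10 gate pre-activations are arbitrary illustrations, and `interpolate` is a hypothetical helper name.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def interpolate(h_prev, h_hat, z_preact):
    """h_t = (1 - z) * h_prev + z * h_hat with z = sigmoid(z_preact) (scalar)."""
    z = sigmoid(z_preact)
    return (1 - z) * h_prev + z * h_hat


h_prev, h_hat = 0.8, -0.5
print(interpolate(h_prev, h_hat, -10.0))  # z ≈ 4.5e-5: stays close to h_prev = 0.8
print(interpolate(h_prev, h_hat, 10.0))   # z ≈ 0.99995: moves close to h_hat = -0.5
```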

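1.3 Correspondence with LSTM

| Concept | GRU | LSTM |
|---|---|---|
| Input gate | - | $i_t$ |
| Forget gate | - | $f_t$ |
| Output gate | - | $o_t$ |
| Update gate | $z_t$ | roughly a combination of $i_t$ and $f_t$ ($i_t \approx z_t$, $f_t \approx 1 - z_t$) |
| Reset gate | $r_t$ | - |
| Candidate state | $\tilde{h}_t$ | $\tilde{c}_t$ |
| Hidden state | $h_t$ | $h_t$ (plus cell state $c_t$) |

Parameter counts (hidden dimension $h$, input dimension $d$, one bias vector per gate):

  • GRU: $3(dh + h^2 + h)$
  • LSTM: $4(dh + h^2 + h)$

GRU therefore has 3/4 = 75% of the LSTM's parameters. PyTorch's nn.GRU and nn.LSTM keep two bias vectors per gate, giving $3(dh + h^2 + 2h)$ and $4(dh + h^2 + 2h)$; the ratio is still exactly 75%.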
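The counts can be checked with plain arithmetic. `gru_params` and `lstm_params` are hypothetical helpers; with `biases_per_gate=2` they should reproduce PyTorch's per-layer counts, e.g. for the first layer of the 3.2 example (d = 100, h = 256).

```python
def gru_params(d, h, biases_per_gate=1):
    """3 gate blocks, each with input weights (d*h), recurrent weights (h*h) and biases."""
    return 3 * (d * h + h * h + biases_per_gate * h)


def lstm_params(d, h, biases_per_gate=1):
    """4 gate blocks with the same shapes."""
    return 4 * (d * h + h * h + biases_per_gate * h)


# One layer with d = 100, h = 256 and PyTorch-style two bias vectors per gate
g, l = gru_params(100, 256, 2), lstm_params(100, 256, 2)
print(g, l, g / l)  # 274944 366592 0.75
```

For a stacked model, later layers take the hidden size as their input size, so a two-layer GRU totals `gru_params(100, 256, 2) + gru_params(256, 256, 2)`.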

1.4 PyTorch Implementation

import torch
import torch.nn as nn
 
 
class GRUCellManual(nn.Module):
    """手动实现的GRU单元(理解用)"""
    
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        
        # Reset-gate parameters
        self.W_r = nn.Linear(input_dim, hidden_dim, bias=False)
        self.U_r = nn.Linear(hidden_dim, hidden_dim)
        
        # Update-gate parameters
        self.W_z = nn.Linear(input_dim, hidden_dim, bias=False)
        self.U_z = nn.Linear(hidden_dim, hidden_dim)
        
        # Candidate-activation parameters
        self.W = nn.Linear(input_dim, hidden_dim, bias=False)
        self.U = nn.Linear(hidden_dim, hidden_dim)
        
        self._init_weights()
    
    def _init_weights(self):
        """正交初始化"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    def forward(self, x, h_prev):
        """
        x: (batch, input_dim)
        h_prev: (batch, hidden_dim)
        """
        # Reset gate
        r = torch.sigmoid(self.W_r(x) + self.U_r(h_prev))
        
        # Update gate
        z = torch.sigmoid(self.W_z(x) + self.U_z(h_prev))
        
        # Candidate hidden state (reset gate applied to h_prev before U)
        h_hat = torch.tanh(self.W(x) + self.U(r * h_prev))
        
        # New hidden state: z -> 1 overwrites, z -> 0 keeps h_prev
        h = (1 - z) * h_prev + z * h_hat
        
        return h
 
 
class GRUModel(nn.Module):
    """完整GRU序列模型"""
    
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, 
                 dropout=0.0, bidirectional=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        
        self.gru = nn.GRU(
            input_dim, hidden_dim, num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )
        
        # Output layer
        out_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.fc = nn.Linear(out_dim, output_dim)
    
    def forward(self, x, h0=None):
        """
        x: (batch, seq_len, input_dim)
        returns: output (batch, seq_len, output_dim), h_n (...)
        """
        batch_size = x.size(0)
        if h0 is None:
            num_dirs = 2 if self.bidirectional else 1
            h0 = torch.zeros(
                self.num_layers * num_dirs, batch_size, self.hidden_dim,
                device=x.device
            )
        
        output, h_n = self.gru(x, h0)
        output = self.fc(output)
        return output, h_n
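GRUCellManual applies the reset gate before the recurrent matrix, $U(r_t \odot h_{t-1})$, as in Cho et al. (2014). PyTorch's documented nn.GRU equations differ in two ways: the reset gate multiplies the recurrent product, $r_t \odot (W_{hn} h_{t-1} + b_{hn})$, and the interpolation is written $h_t = (1 - z_t) \odot n_t + z_t \odot h_{t-1}$, so its $z_t$ plays the role of $1 - z_t$ in 1.1 (Cho et al. use the same orientation as PyTorch). The second difference is only a relabeling; the first changes the computed function whenever $r_t$ is not all ones. The sketch below shows this with an arbitrary 2×2 matrix; biases are omitted and the helper names are hypothetical.

```python
def matvec(M, v):
    """Matrix-vector product on nested lists."""
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def reset_before(U, r, h):
    """GRUCellManual / Cho et al. placement: U (r ⊙ h)."""
    return matvec(U, [ri * hi for ri, hi in zip(r, h)])


def reset_after(U, r, h):
    """nn.GRU documented placement: r ⊙ (U h), bias omitted."""
    return [ri * v for ri, v in zip(r, matvec(U, h))]


U = [[0.5, -1.0], [2.0, 0.3]]  # arbitrary recurrent weights
h = [0.6, -0.4]
print(reset_before(U, [1.0, 0.0], h))  # [0.3, 1.2]
print(reset_after(U, [1.0, 0.0], h))   # approximately [0.7, 0.0]
print(reset_before(U, [1.0, 1.0], h) == reset_after(U, [1.0, 1.0], h))  # True: identical when r = 1
```

Weights trained with one placement therefore cannot be copied into the other formulation unchanged.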

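II. Theoretical Analysis of GRU

2.1 The "Slow Mode" Theory of Can et al.

Can, Krishnamurthy & Schwab (PMLR 2020), "Gating creates slow modes and controls phase-space complexity in GRUs and LSTMs", develop a dynamical theory of GRUs and LSTMs.[2]

Core question

How does gating shape the network's phase-space structure? In particular:

  • Why can gating learn long-range dependencies?
  • How does gating create slow time scales?

2.2 Emergence of Slow Modes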

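Result (Can, Krishnamurthy & Schwab 2020):

A GRU develops slow modes by learning weight matrices dominated by the identity:

$$U \approx I + \Delta$$

where $I$ is the identity matrix and $\Delta$ is a perturbation.

Slow time constant: with frozen gates and a zero candidate, each unit keeps a fraction $1 - z$ of its state per step, so its state decays as $(1 - z)^t$ with time constant

$$\tau = -\frac{1}{\ln(1 - z)} \approx \frac{1}{z} \qquad (z \ll 1)$$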
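A quick numeric check of the time constant, assuming frozen gates and a zero candidate so the unit only forgets. `time_constant` and `decay` are hypothetical helper names. After roughly τ steps the state should be near 1/e of its initial value, and τ approaches 1/z as z shrinks.

```python
import math


def time_constant(z):
    """tau = -1 / ln(1 - z): steps for a frozen-gate unit to decay to 1/e."""
    return -1.0 / math.log(1.0 - z)


def decay(h0, z, steps):
    """Iterate h <- (1 - z) * h, i.e. the GRU update with candidate h_hat = 0."""
    h = h0
    for _ in range(steps):
        h = (1.0 - z) * h
    return h


for z in (0.5, 0.1, 0.01):
    tau = time_constant(z)
    remaining = decay(1.0, z, round(tau))
    print(f"z={z}: tau={tau:.2f}, 1/z={1 / z:.2f}, h after round(tau) steps={remaining:.3f}")
```

A small update gate (for example z = 0.01) thus gives a memory of about 100 steps, which is the sense in which the gate "creates" slow modes.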

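2.3 Spectral Analysis

Near a fixed point, treating the gates as frozen (dropping the terms that come from differentiating $z_t$ and $r_t$), the GRU Jacobian is

$$J = \frac{\partial h_t}{\partial h_{t-1}} \approx \operatorname{diag}(1 - z_t) + \operatorname{diag}(z_t)\, \operatorname{diag}(1 - \tilde{h}_t^2)\, U \operatorname{diag}(r_t)$$

Key observation: as $z_t \to 0$, $J \to I$ (eigenvalues near 1: slow modes); as $z_t \to 1$, $J$ is governed by the candidate activation (fast modes).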
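For a scalar GRU, the Jacobian formula can be compared against a finite-difference derivative of the full step. The parameter values and the names `step`, `jacobian_frozen` and `jacobian_fd` are hypothetical. The frozen-gate formula is exact when the gates do not depend on $h_{t-1}$ (`uz = ur = 0`), and pushing the update gate toward 0 drives the derivative toward 1, the slow-mode limit.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def step(h0, x, p):
    """One scalar GRU step (convention of 1.1); returns h1 and the gate values."""
    r = sigmoid(p["wr"] * x + p["ur"] * h0)
    z = sigmoid(p["wz"] * x + p["uz"] * h0)
    n = math.tanh(p["w"] * x + p["u"] * (r * h0))
    return (1 - z) * h0 + z * n, r, z, n


def jacobian_frozen(h0, x, p):
    """Scalar version of the 2.3 formula: (1 - z) + z (1 - n^2) u r."""
    _, r, z, n = step(h0, x, p)
    return (1 - z) + z * (1 - n * n) * p["u"] * r


def jacobian_fd(h0, x, p, eps=1e-6):
    """Central finite difference of the full step, gate dependence included."""
    return (step(h0 + eps, x, p)[0] - step(h0 - eps, x, p)[0]) / (2 * eps)


p = {"wr": 0.3, "ur": 0.0, "wz": -1.0, "uz": 0.0, "w": 0.8, "u": 1.5}
print(jacobian_frozen(0.2, 0.5, p), jacobian_fd(0.2, 0.5, p))  # agree: gates do not depend on h

slow = dict(p, wz=-20.0)
print(jacobian_fd(0.2, 1.0, slow))  # z = sigmoid(-20) ≈ 2e-9, so the derivative is ≈ 1
```

With nonzero `uz` or `ur` the two values drift apart, which measures how much the omitted gate-derivative terms contribute.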

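2.4 Phase-Space Complexity

Gating controls the effective dimension $d_{\text{eff}}$ of the dynamics.

Empirical findings:

  • During training, GRUs tend to find the minimal $d_{\text{eff}}$ the task requires
  • LSTMs tend toward a higher $d_{\text{eff}}$, which is harder to control
  • This partly explains the GRU's parameter efficiency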
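The text does not say how $d_{\text{eff}}$ is measured. One common choice, assumed here and not necessarily the one used by Can et al., is the participation ratio of the hidden-state covariance $C$: $\mathrm{PR} = (\operatorname{tr} C)^2 / \operatorname{tr}(C^2)$. It equals the number of directions when variance is spread evenly over them and 1 when all variance lies along a single direction. The sketch computes it from recorded hidden states without an eigendecomposition; `participation_ratio` is a hypothetical name.

```python
def participation_ratio(states):
    """PR = (tr C)^2 / tr(C^2), with C the covariance of the hidden states.
    states: list of equal-length vectors, one per time step."""
    n, d = len(states), len(states[0])
    mean = [sum(s[j] for s in states) / n for j in range(d)]
    centered = [[s[j] - mean[j] for j in range(d)] for s in states]
    cov = [[sum(c[i] * c[j] for c in centered) / n for j in range(d)] for i in range(d)]
    trace = sum(cov[i][i] for i in range(d))
    trace_sq = sum(cov[i][j] * cov[j][i] for i in range(d) for j in range(d))  # tr(C^2)
    return trace * trace / trace_sq


line = [[t, 2 * t, -t] for t in range(10)]  # all variance along one direction
plane = [[1, 0], [-1, 0], [0, 1], [0, -1]]  # variance split evenly over two directions
print(participation_ratio(line), participation_ratio(plane))  # ≈ 1.0 and 2.0
```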

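2.5 Hilgert's Lyapunov Stability Analysis

Hilgert & Schwung (arXiv 2505.11539, 2025) analyze the absolute stability of GRUs/LSTMs with Lure-Postnikov Lyapunov functions.[3]

Key result

When the GRU/LSTM gates are close to constant ($z_t \approx z^*$), the Lyapunov function takes the general Lure-Postnikov form, a quadratic term plus integrals of the sector-bounded nonlinearity (here $\tanh$):

$$V(h) = h^\top P h + 2 \sum_i q_i \int_0^{(U h)_i} \tanh(s)\, ds, \qquad q_i \ge 0$$

Condition (quadratic part, discrete time): there exists $P \succ 0$ such that $A^\top P A - P \prec 0$, where $A = (1 - z^*) I + z^* U$ is the linearization at the origin with $r_t = 1$.

Implications for closed-loop control

When a GRU/LSTM serves as a virtual sensor, its gates must satisfy the absolute-stability condition. If the gates vary sharply, the Lyapunov condition may be violated, leading to instability.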

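import torch


def check_lure_postnikov_condition(gru_cell, P, z_star=0.5):
    """
    Simplified check of the quadratic (linear) part of the Lure-Postnikov
    condition for a GRUCellManual linearized at the origin.

    Assumptions: gates frozen (z = z_star, r = 1), tanh'(0) = 1, biases ignored, so
        h_{t+1} ≈ A h_t,  A = (1 - z*) I + z* U
    P: symmetric positive-definite matrix, shape (hidden_dim, hidden_dim)
    Returns (is_stable, M) with M = A^T P A - P; the discrete-time Lyapunov
    condition requires M to be negative definite.
    """
    with torch.no_grad():
        U = gru_cell.U.weight  # recurrent weight of the candidate activation
        I = torch.eye(U.size(0), dtype=U.dtype, device=U.device)

        # Linearized state matrix with a constant update gate z*
        A = (1 - z_star) * I + z_star * U

        # Discrete-time Lyapunov test: A^T P A - P < 0 (all eigenvalues negative)
        M = A.T @ P @ A - P
        is_stable = bool(torch.all(torch.linalg.eigvalsh(M) < 0))
    return is_stable, M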

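2.6 Livi's Learnable-Window Theory

See the related discussion in vanilla-rnn-deep-theory.md. Livi (2025) derives the learnable window of a GRU as a function of $\bar{z}$, the average effective value of the GRU gates.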


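III. GRU vs LSTM: Empirical Comparison

3.1 Classic Benchmarks (Chung et al. 2014)

| Task | GRU | LSTM |
|---|---|---|
| Speech signal modeling | Comparable | Comparable |
| Music modeling | Slightly better on some datasets | Slightly better on others |
| Natural language processing | Comparable | Comparable |

3.2 Parameter Efficiency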

import torch.nn as nn


def count_params(gru, lstm):
    """Compare GRU and LSTM parameter counts"""
    return {
        'GRU': sum(p.numel() for p in gru.parameters()),
        'LSTM': sum(p.numel() for p in lstm.parameters())
    }


input_dim = 100
hidden_dim = 256

gru = nn.GRU(input_dim, hidden_dim, num_layers=2)
lstm = nn.LSTM(input_dim, hidden_dim, num_layers=2)

params = count_params(gru, lstm)
print(f"GRU params: {params['GRU']:,}")              # GRU params: 669,696
print(f"LSTM params: {params['LSTM']:,}")            # LSTM params: 892,928
print(f"Ratio: {params['GRU']/params['LSTM']:.2%}")  # Ratio: 75.00%

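3.3 Training Speed

Rules of thumb (they depend on hardware and implementation):

  • GRU trains roughly 1.3-1.5× faster than LSTM
  • Memory use is roughly 75-80% of LSTM's

3.4 When to Choose Which

Choose GRU when:

  • Data is limited
  • Model size is a constraint
  • Training time is limited
  • The sequence task is simple

Choose LSTM when:

  • Sequences are very long (> 1000 steps)
  • A separate cell state and hidden state are needed
  • Existing infrastructure is built on LSTMs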

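IV. GRU Variants

4.1 MGU (Minimal Gated Unit)

The minimal gated unit has a single gate, the forget gate $f_t$, which takes over the roles of both the update gate and the reset gate.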

class MGUCell(nn.Module):
    """Minimal gated unit"""
    
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.W_f = nn.Linear(input_dim, hidden_dim)
        self.U_f = nn.Linear(hidden_dim, hidden_dim)
        self.W = nn.Linear(input_dim, hidden_dim)
        self.U = nn.Linear(hidden_dim, hidden_dim)
    
    def forward(self, x, h_prev):
        # The single forget gate
        f = torch.sigmoid(self.W_f(x) + self.U_f(h_prev))
        # f acts as the reset gate on h_prev inside the candidate...
        h_hat = torch.tanh(self.W(x) + self.U(f * h_prev))
        # ...and as the update gate in the interpolation
        h = (1 - f) * h_prev + f * h_hat
        return h

4.2 Bidirectional GRU

class BiGRUClassifier(nn.Module):
    """双向GRU分类器"""
    
    def __init__(self, input_dim, hidden_dim, num_classes, num_layers=2):
        super().__init__()
        self.gru = nn.GRU(
            input_dim, hidden_dim, num_layers,
            batch_first=True, bidirectional=True,
            dropout=0.2
        )
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x):
        # x: (batch, seq_len, input_dim)
        output, h_n = self.gru(x)
        
        # Concatenate the final forward and backward states
        h_fwd = h_n[-2]  # last layer, forward direction
        h_bwd = h_n[-1]  # last layer, backward direction
        h_combined = torch.cat([h_fwd, h_bwd], dim=-1)
        
        return self.classifier(h_combined)

4.3 Stacked GRU

class StackedGRU(nn.Module):
    """多层GRU"""
    
    def __init__(self, input_dim, hidden_dim, num_layers=3, dropout=0.3):
        super().__init__()
        self.gru = nn.GRU(
            input_dim, hidden_dim, num_layers,
            batch_first=True,
            dropout=dropout
        )
    
    def forward(self, x):
        # x: (batch, seq_len, input_dim)
        # Layers run in sequence; each consumes the previous layer's outputs
        output, h_n = self.gru(x)
        return output, h_n

4.4 Peephole GRU

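Let the gates also read the previous state, analogous to LSTM peepholes that read the cell state (a GRU has no separate cell state).

class PeepholeGRUCell(nn.Module):
    """Peephole GRU (illustrative)"""
    
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Besides input and hidden state, the gates also see the previous "state".
        # A GRU has no cell state, so the peephole V_* reads h_{t-1} as well;
        # being linear in h_{t-1}, it could be absorbed into U_* (U_* + V_*).
        self.W_r = nn.Linear(input_dim, hidden_dim)
        self.U_r = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.V_r = nn.Linear(hidden_dim, hidden_dim, bias=False)  # peephole
        self.W_z = nn.Linear(input_dim, hidden_dim)
        self.U_z = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.V_z = nn.Linear(hidden_dim, hidden_dim, bias=False)  # peephole
        self.W = nn.Linear(input_dim, hidden_dim)
        self.U = nn.Linear(hidden_dim, hidden_dim, bias=False)
    
    def forward(self, x, h_prev):
        r = torch.sigmoid(self.W_r(x) + self.U_r(h_prev) + self.V_r(h_prev))
        z = torch.sigmoid(self.W_z(x) + self.U_z(h_prev) + self.V_z(h_prev))
        h_hat = torch.tanh(self.W(x) + self.U(r * h_prev))
        return (1 - z) * h_prev + z * h_hat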

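V. GRU in LLM Unlearning (2025)

5.1 Background

Wang et al. (arXiv 2503.09117, 2025), "GRU: Mitigating the Trade-off between Unlearning and Retention for LLMs".[4]

Challenges of LLM unlearning:

  • Complete unlearning damages the model's general capabilities
  • Partial unlearning may leave residual privacy leakage
  • The degree of forgetting must be controlled precisely

5.2 The GRU Method

Core idea: in this paper "GRU" stands for Gradient Rectified Unlearning, not a gated recurrent unit. Instead of adding gates, it rectifies the unlearning gradient: the component that conflicts with retention is removed by projecting the unlearning gradient onto the orthogonal space of the gradient that would harm retention.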

import torch


def rectify_unlearning_gradient(g_unlearn, g_retain, eps=1e-12):
    """
    Simplified sketch of gradient rectification: a single projection step.
    The paper's full procedure may differ in details, such as how g_retain is estimated.

    g_unlearn: flattened gradient of the unlearning loss
    g_retain:  flattened gradient of the retention loss
    """
    dot = torch.dot(g_unlearn, g_retain)
    if dot < 0:
        # Conflict: a descent step along -g_unlearn would raise the retention
        # loss to first order, so remove the component of g_unlearn along g_retain
        g_unlearn = g_unlearn - dot / (g_retain.dot(g_retain) + eps) * g_retain
    return g_unlearn

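5.3 Experimental Results

Unlearning-retention trade-off:

  • Complete unlearning: 95% forgetting rate with a 30% performance drop
  • GRU method: 93% forgetting rate with only an 8% performance drop

Conclusion: constraining the direction of the unlearning update can substantially improve the precision of unlearning.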


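VI. Modeling Irregular Time Series

6.1 Problem

Joshi & Hauskrecht (TMLR 2026), "Still Competitive: Revisiting Recurrent Models for Irregular Time Series Prediction", re-examine what RNNs can do on irregular time series.[5]

Problem definition:

  • Observation timestamps are irregular
  • Values may be missing
  • Different features may be sampled at different rates

6.2 Adapting the GRU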

class GRUWithTimeGap(nn.Module):
    """处理时间间隔的GRU"""
    
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.gru_cell = nn.GRUCell(input_dim, hidden_dim)
        self.time_encoder = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, hidden_dim)
        )
    
    def forward(self, x, timestamps):
        """
        x: (batch, seq_len, input_dim)
        timestamps: (batch, seq_len) observation times
        """
        batch_size, seq_len, _ = x.shape
        h = torch.zeros(batch_size, self.gru_cell.hidden_size, device=x.device)
        
        prev_time = torch.zeros(batch_size, 1, device=x.device)
        outputs = []
        
        for t in range(seq_len):
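            # Time gap since the previous observation
            dt = timestamps[:, t:t+1] - prev_time
            
            # Time encoding of the gap
            time_encoding = self.time_encoder(dt)
            
            # Modulate the hidden state with the time encoding (decay)
            decay = torch.exp(-0.1 * dt)  # fixed decay rate 0.1; longer gaps decay more
            h = h * decay + time_encoding * (1 - decay)
            
            # GRU update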
            h = self.gru_cell(x[:, t, :], h)
            outputs.append(h)
            
            prev_time = timestamps[:, t:t+1]
        
        return torch.stack(outputs, dim=1), h
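The decay step in GRUWithTimeGap interpolates the hidden state toward the time encoding with weight $e^{-0.1\,\Delta t}$. A scalar version makes the two limits explicit: no elapsed time leaves the state unchanged, while a long gap replaces it with the target. The rate 0.1 is the constant hard-coded in the class; `decay_toward` is a hypothetical helper, and a learnable rate would be a natural alternative.

```python
import math


def decay_toward(h, target, dt, rate=0.1):
    """Scalar version of the update in GRUWithTimeGap:
    h <- h * exp(-rate * dt) + target * (1 - exp(-rate * dt))."""
    gamma = math.exp(-rate * dt)
    return h * gamma + target * (1 - gamma)


for dt in (0.0, 1.0, 10.0, 100.0):
    print(dt, round(decay_toward(1.0, 0.0, dt), 4))  # fraction of h that survives the gap
```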

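6.3 Conclusion

Joshi & Hauskrecht show empirically that, with simple modifications, GRUs remain competitive on irregular time series and even outperform some Transformer variants.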


VII. Modern Alternatives to GRU

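7.1 Performance Comparison

The comparison covers GRU, LSTM, Transformer, Mamba and RWKV along four axes: training speed, inference speed, long-sequence handling, and memory efficiency.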

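7.2 Mamba vs GRU

Mamba (state space model):

  • Selective state space: the parameters depend on the input
  • Parallel training + fast inference
  • Clear advantage on long sequences

GRU:

  • Input-dependent gates, but the recurrence is nonlinear in $h_{t-1}$, so training runs sequentially over time
  • Simple, stable, and easy to deploy
  • Sufficient for medium-length sequences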
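The reason Mamba-style models can train in parallel while a GRU cannot shows up on a scalar linear recurrence $h_t = a_t h_{t-1} + b_t$. Composing two such affine steps gives another affine step, and that composition is associative, so all steps can be combined in a log-depth tree (a parallel scan). This is a simplified sketch of the idea only, not Mamba's actual kernel, and the helper names are hypothetical. A GRU step does not have this form, because $z_t$, $r_t$ and the tanh all depend on $h_{t-1}$.

```python
def combine(first, second):
    """Compose affine maps h -> a*h + b: apply `first`, then `second`."""
    a1, b1 = first
    a2, b2 = second
    return (a2 * a1, a2 * b1 + b2)


def run_sequential(h0, steps):
    """Plain recurrence h_t = a_t * h_{t-1} + b_t, one step at a time."""
    h, out = h0, []
    for a, b in steps:
        h = a * h + b
        out.append(h)
    return out


def compose_tree(steps):
    """Compose all steps by recursive halving. Because `combine` is associative,
    the two halves could be evaluated in parallel (log depth)."""
    if len(steps) == 1:
        return steps[0]
    mid = len(steps) // 2
    return combine(compose_tree(steps[:mid]), compose_tree(steps[mid:]))


# (a_t, b_t) may depend on the input x_t (as in a selective SSM), but not on h_{t-1}
steps = [(0.9, 0.1), (0.5, -0.2), (0.99, 0.3), (0.7, 0.05)]
a, b = compose_tree(steps)
print(run_sequential(1.0, steps)[-1], a * 1.0 + b)  # same final state
```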

7.3 Selection Guidelines

def select_sequence_model(seq_len, data_size, latency_requirement):
    """Pick a sequence model from task requirements (heuristic thresholds)"""
    
    if latency_requirement == 'real-time':
        if seq_len < 100:
            return 'GRU'  # fastest on short sequences
        else:
            return 'Mamba'
    elif data_size < 10000:
        return 'GRU'  # clear advantage when data is scarce
    elif seq_len < 500:
        return 'GRU'  # medium-length sequences
    elif seq_len < 4000:
        return 'Mamba'  # long sequences
    else:
        return 'Mamba-2 or RWKV'  # very long sequences

VIII. GRU Training Tips

8.1 Learning Rate Scheduling

def get_gru_training_config():
    """GRU训练推荐配置"""
    return {
        'optimizer': 'adam',
        'learning_rate': 1e-3,
        'gradient_clip': 1.0,
        'batch_size': 64,
        'weight_decay': 1e-5,
        'dropout': 0.2,
        'num_layers': 2,
        'hidden_dim': 256,
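        'bidirectional': True,  # helps on most tasks, but only when the whole sequence is available (not for causal forecasting or streaming)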
        'num_epochs': 50,
    }
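The configuration above keeps a fixed learning rate of 1e-3. One common way to schedule it, assumed here rather than taken from the original recommendation, is to cut the rate when validation loss stops improving; PyTorch provides this as torch.optim.lr_scheduler.ReduceLROnPlateau. The dependency-free sketch below mirrors only its core rule (reduce once the count of non-improving epochs exceeds `patience`) and omits its threshold and cooldown options; `PlateauScheduler` is a hypothetical name.

```python
class PlateauScheduler:
    """Reduce the learning rate when validation loss stops improving
    (simplified ReduceLROnPlateau rule: no threshold, no cooldown)."""

    def __init__(self, lr=1e-3, factor=0.5, patience=3, min_lr=1e-6):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Call once per epoch; returns the learning rate for the next epoch."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.bad_epochs = 0
        return self.lr


sched = PlateauScheduler(lr=1e-3, patience=2)
for loss in [1.0, 0.9, 0.95, 0.95, 0.95, 0.8]:
    print(loss, sched.step(loss))
```

In a PyTorch loop the returned value would be written into each entry of `optimizer.param_groups` under the `'lr'` key.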

8.2 Sequence Padding Strategy

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class GRUPadded(nn.Module):
    """GRU for variable-length sequences"""
    
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x, lengths):
        """
        x: (batch, max_seq_len, input_dim)
        lengths: (batch,) actual lengths (integer tensor)
        """
        # Pack the variable-length sequences so padding steps are skipped
        packed = pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        output, h_n = self.gru(packed)
        
        # Unpack back to a padded tensor
        output, _ = pad_packed_sequence(output, batch_first=True)
        
        # Take the last valid time step of each sequence
        # (for this single-layer unidirectional GRU it equals h_n[-1])
        idx = (lengths.to(output.device) - 1).view(-1, 1, 1).expand(-1, 1, output.size(2))
        last_output = output.gather(1, idx).squeeze(1)
        
        return self.fc(last_output)

8.3 Teacher Forcing

import torch


class GRUTrainerWithTF:
    """GRU training with teacher forcing

    Assumes the model exposes gru_cell (nn.GRUCell), fc (nn.Linear) and hidden_dim,
    and that input_dim == output_dim, since targets are fed back as inputs.
    """
    
    def __init__(self, model, teacher_forcing_ratio=0.5):
        self.model = model
        self.tfr = teacher_forcing_ratio
    
    def train_step(self, x, y, optimizer, criterion):
        """
        x: (batch, seq_len, input_dim); only x[:, 0] seeds the rollout
        y: (batch, seq_len, output_dim); y[:, t] is the target for step t
        """
        self.model.train()
        batch_size, seq_len, _ = x.shape
        
        # Initialization
        h = torch.zeros(batch_size, self.model.hidden_dim, device=x.device)
        inp = x[:, 0, :]
        outputs = []
        
        for t in range(seq_len):
            h = self.model.gru_cell(inp, h)
            output = self.model.fc(h)
            outputs.append(output)
            
            # Teacher forcing: with probability tfr the next input is the
            # ground truth y[:, t]; otherwise it is the model's own prediction
            if torch.rand(1).item() < self.tfr:
                inp = y[:, t, :]
            else:
                inp = output
        
        loss = criterion(torch.stack(outputs, dim=1), y)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        optimizer.step()
        
        return loss.item()
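A fixed ratio of 0.5 is only one option. Scheduled sampling (Bengio et al., 2015) instead decays the teacher-forcing probability over training, for example with the inverse-sigmoid schedule $\epsilon_i = k / (k + e^{i/k})$. Counting `i` in epochs and choosing k = 10 are assumptions of this sketch, and `inverse_sigmoid_tf_ratio` is a hypothetical name; the returned probability could be assigned to `trainer.tfr` of a GRUTrainerWithTF instance at the start of each epoch.

```python
import math


def inverse_sigmoid_tf_ratio(epoch, k=10.0):
    """Teacher-forcing probability eps = k / (k + exp(epoch / k)), with k >= 1.
    Starts near 1 and decays toward 0 as training progresses."""
    return k / (k + math.exp(epoch / k))


for epoch in (0, 10, 20, 50, 100):
    print(epoch, round(inverse_sigmoid_tf_ratio(epoch), 3))
```

Larger k keeps the model on ground-truth inputs for longer before it has to rely on its own predictions.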

8.4 Curriculum Learning

import torch


class GRUCurriculumTraining:
    """Curriculum learning for GRUs: start with short sequences"""
    
    def __init__(self, model, max_seq_len):
        self.model = model  # expected to return (output, h_n), like GRUModel in 1.4
        self.max_seq_len = max_seq_len
    
    def train_epoch(self, train_data, optimizer, criterion, epoch, total_epochs):
        # Sequence length for this epoch (at least 1 step)
        current_len = max(1, min(
            self.max_seq_len,
            int(self.max_seq_len * (epoch + 1) / total_epochs)
        ))
        
        # Crop each batch to current_len
        truncated_data = []
        for x, y in train_data:
            if x.size(1) > current_len:
                # Random start position
                start = torch.randint(0, x.size(1) - current_len + 1, (1,)).item()
                x = x[:, start:start+current_len]
                y = y[:, start:start+current_len]
            truncated_data.append((x, y))
        
        # Standard training loop
        self.model.train()
        total_loss = 0
        for x, y in truncated_data:
            optimizer.zero_grad()
            output, _ = self.model(x)
            loss = criterion(output, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
        
        return total_loss / len(truncated_data)

IX. Practical Application Examples

9.1 Text Classification

import torch.nn as nn
import torch.nn.functional as F


class GRUTextClassifier(nn.Module):
    """GRU text classifier with attention pooling"""
    
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes, 
                 num_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.gru = nn.GRU(
            embedding_dim, hidden_dim, num_layers,
            batch_first=True, dropout=dropout, bidirectional=True
        )
        self.attention = nn.Linear(2 * hidden_dim, 1)
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x, mask=None):
        # x: (batch, seq_len) token ids; mask: (batch, seq_len), 0 marks padding
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)
        output, _ = self.gru(embedded)  # (batch, seq_len, 2*hidden)
        
        # Attention pooling over time steps
        attn_weights = self.attention(output).squeeze(-1)  # (batch, seq_len)
        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(attn_weights, dim=-1)
        
        context = (output * attn_weights.unsqueeze(-1)).sum(dim=1)
        return self.classifier(context)
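The masking step in GRUTextClassifier can be traced by hand. This list-based sketch scores each time step, pushes padded positions to -1e9 before the softmax (as the model does), and returns the weighted sum. The states and scores are made-up values and `masked_attention_pool` is a hypothetical name.

```python
import math


def masked_attention_pool(states, scores, mask):
    """Attention pooling for one sequence, as in GRUTextClassifier.
    states: T vectors; scores: T attention logits; mask: T values, 0 = padding."""
    masked = [s if m != 0 else -1e9 for s, m in zip(scores, mask)]
    top = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in masked]
    total = sum(exps)
    weights = [e / total for e in exps]  # softmax over time steps
    dim = len(states[0])
    context = [sum(w * st[j] for w, st in zip(weights, states)) for j in range(dim)]
    return context, weights


states = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]  # last step is padding
context, weights = masked_attention_pool(states, [0.0, 0.0, 10.0], [1, 1, 0])
print(weights, context)  # [0.5, 0.5, 0.0] [0.5, 0.5]
```

Even though the padded step has the highest raw score, the mask removes it from the softmax, so it contributes nothing to the pooled vector.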

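9.2 Speech Recognition

import torch.nn as nn
import torch.nn.functional as F


class GRUSpeechRecognizer(nn.Module):
    """GRU speech recognizer trained with CTC"""
    
    def __init__(self, input_dim, hidden_dim, num_classes, num_layers=5):
        super().__init__()
        # Stack of GRU layers
        self.gru = nn.GRU(
            input_dim, hidden_dim, num_layers,
            batch_first=True, dropout=0.2
        )
        # Per-frame classifier for the CTC loss (num_classes includes the blank label)
        self.classifier = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x):
        # x: (batch, seq_len, input_dim), e.g. MFCC features
        output, _ = self.gru(x)
        logits = self.classifier(output)
        # CTC expects log-probabilities; nn.CTCLoss takes them as
        # (seq_len, batch, num_classes), so pass log_probs.transpose(0, 1)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs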
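The model outputs per-frame log-probabilities; turning them into a label sequence needs a decoder. The simplest is greedy CTC decoding: take the best label per frame, merge consecutive repeats, then drop blanks. Blank index 0 is assumed here, matching nn.CTCLoss's default `blank=0`; the helper names are hypothetical, and beam search usually decodes better.

```python
def argmax_frames(log_probs):
    """Best label index per frame from a T x C nested list of log-probabilities."""
    return [max(range(len(row)), key=row.__getitem__) for row in log_probs]


def ctc_greedy_decode(frame_labels, blank=0):
    """Merge consecutive repeated labels, then drop blanks."""
    decoded, prev = [], None
    for label in frame_labels:
        if label != prev and label != blank:
            decoded.append(label)
        prev = label
    return decoded


print(ctc_greedy_decode([0, 3, 3, 0, 3, 5, 5, 0]))  # [3, 3, 5]: the blank separates the two 3s
```

In practice `argmax_frames` would be applied to one utterance of the model's output (converted to nested lists), followed by `ctc_greedy_decode`.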

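9.3 Financial Forecasting

import torch
import torch.nn as nn


class GRUForecaster(nn.Module):
    """GRU encoder-decoder for multi-step time-series forecasting"""
    
    def __init__(self, input_dim, hidden_dim, forecast_horizon):
        super().__init__()
        self.forecast_horizon = forecast_horizon
        self.encoder = nn.GRU(input_dim, hidden_dim, num_layers=2, batch_first=True)
        self.decoder = nn.GRU(input_dim, hidden_dim, num_layers=2, batch_first=True)
        # Project each decoder state back to the input feature space so the
        # prediction can be fed in as the next decoder input
        self.fc = nn.Linear(hidden_dim, input_dim)
    
    def forward(self, x, forecast_steps=None):
        """
        x: (batch, seq_len, input_dim) - history
        returns: (batch, forecast_steps, input_dim) - forecasts
        """
        forecast_steps = forecast_steps or self.forecast_horizon
        
        # Encode the history
        _, h = self.encoder(x)
        
        # Autoregressive decoding
        decoder_input = x[:, -1:, :]  # last observed step seeds the decoder
        preds = []
        for _ in range(forecast_steps):
            output, h = self.decoder(decoder_input, h)
            pred = self.fc(output)  # (batch, 1, input_dim)
            preds.append(pred)
            decoder_input = pred    # feed the prediction back in
        
        return torch.cat(preds, dim=1)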

X. References


Last updated: 2026-06-21

Footnotes

  1. Cho, K., et al. (2014). Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation. EMNLP 2014.

  2. Can, T., Krishnamurthy, K., & Schwab, D.J. (2020). Gating creates slow modes and controls phase-space complexity in GRUs and LSTMs. PMLR 107:476-511. http://proceedings.mlr.press/v107/can20a/can20a.pdf

  3. Hilgert, E. & Schwung, A. (2025). Lure-Postnikov Stability Analysis of Closed-Loop Control Systems with Gated Recurrent Neural Network-based Virtual Sensors. arXiv:2505.11539. https://www.arxiv.org/pdf/2505.11539

  4. Wang, Y., et al. (2025). GRU: Mitigating the Trade-off between Unlearning and Retention for LLMs. arXiv:2503.09117. https://arxiv.org/html/2503.09117v3

  5. Joshi, A. & Hauskrecht, M. (2026). Still Competitive: Revisiting Recurrent Models for Irregular Time Series Prediction. TMLR 01/2026. https://arxiv.org/pdf/2510.16161