LSTM与GRU

长短期记忆的门控机制 —— 从梯度困境到门控革命

一、引言:循环神经网络的困境

循环神经网络(Recurrent Neural Network, RNN)是处理序列数据(如文本、语音、时间序列)的基础架构。其核心思想是通过隐藏状态在时间步之间传递信息,使网络具备"记忆"能力。然而,标准RNN在处理长序列时面临严重的长期依赖问题(Long-Term Dependency Problem)。

长期依赖问题的根源

在反向传播过程中,梯度需要沿着时间步反向传播(Backpropagation Through Time, BPTT)。随着序列长度的增加,梯度会经历多次矩阵乘法,导致:

  • 梯度消失(Vanishing Gradient): 梯度指数级衰减到零,网络无法学习远距离的依赖关系
  • 梯度爆炸(Exploding Gradient): 梯度指数级增长,导致参数更新过大,训练不稳定

为解决这一问题,长短时记忆网络(Long Short-Term Memory, LSTM) 于1997年由 Hochreiter 和 Schmidhuber 提出。LSTM引入门控机制(Gating Mechanism)和细胞状态(Cell State),使得梯度可以沿着时间步以近似恒等的方式传播,从而有效缓解梯度消失问题。

此后,门控循环单元(Gated Recurrent Unit, GRU) 由 Cho 等人在2014年提出,作为LSTM的简化变体,在保持性能的前提下大幅减少了参数量。GRU将遗忘门和输入门合并为更新门(Update Gate),并取消了独立的细胞状态,使得模型更加简洁高效。

"LSTM的设计哲学是:让网络自己学会什么时候记住、什么时候忘记,而不是让架构强加一个固定的记忆窗口。"

二、LSTM核心思想:三扇门的艺术

LSTM的核心创新在于引入了细胞状态(Cell State, 记作 Ct)和三个门控单元。如果把标准RNN比作一个只会"覆盖"旧内容的便签本,LSTM则像一个配有"编辑团队"的文档系统:有专门的编辑决定删除什么(遗忘门)、添加什么(输入门)、输出什么(输出门)。

细胞状态 Ct:记忆的传送带

细胞状态是LSTM的核心组件,它像一条贯穿整个网络的传送带。与隐藏状态 ht 不同,细胞状态在时间步之间的信息传递只经过少量的线性操作(逐元素乘法和加法),而非非线性激活函数。这意味着梯度在沿着细胞状态反向传播时几乎不会衰减,从而有效解决了长期依赖问题。

Ct = 遗忘门 × 旧状态 + 输入门 × 候选新信息

2.1 遗忘门(Forget Gate, ft

遗忘门决定从细胞状态中丢弃哪些信息。它通过查看当前的输入 xt 和上一时刻的隐藏状态 ht-1,为细胞状态 Ct-1 中的每个元素输出一个介于0和1之间的值。

ft = σ( Wf · [ht-1, xt] + bf )

2.2 输入门(Input Gate, it

输入门决定哪些新信息需要存入细胞状态。它由两个部分组成:

it = σ( Wi · [ht-1, xt] + bi )
t = tanh( WC · [ht-1, xt] + bC )

2.3 细胞状态更新

在获得遗忘门和输入门的输出后,将旧细胞状态更新为新细胞状态:

Ct = ftCt-1 + itt

这个公式是LSTM的精髓所在

2.4 输出门(Output Gate, ot

输出门决定当前时刻的隐藏状态 ht 应该输出什么。它将细胞状态经过 tanh 压缩到 (-1, 1) 后,再通过输出门进行过滤:

ot = σ( Wo · [ht-1, xt] + bo )
ht = ot ⊗ tanh( Ct )
Ct-1 ——————————— Ct
▸ ⊗ ◂ ←→ ▸ ⊕ ◂ ←→ ▸ ⊗ ◂
ft      itt     ot
                            ht
[遗忘门]   [输入门 + 候选]   [输出门]

三、LSTM完整前向传播与PyTorch实现

综合上述三个门控机制,LSTM在时间步 t 的完整前向传播过程如下:

遗忘门:   ft = σ( Wf · [ht-1, xt] + bf )
输入门:   it = σ( Wi · [ht-1, xt] + bi )
候选状态:   C̃t = tanh( WC · [ht-1, xt] + bC )
细胞状态更新:   Ct = ft ⊗ Ct-1 + it ⊗ C̃t
输出门:   ot = σ( Wo · [ht-1, xt] + bo )
隐藏状态:   ht = ot ⊗ tanh( Ct )

下面是用 PyTorch 从零实现 LSTM 前向传播的代码:

# LSTM 前向传播的 NumPy 实现(不依赖 PyTorch 自动微分) import numpy as np def lstm_forward(x, h_prev, C_prev, params): """ LSTM 单步前向传播 Args: x: 当前时间步输入, shape (input_size,) h_prev: 上一时间步隐藏状态, shape (hidden_size,) C_prev: 上一时间步细胞状态, shape (hidden_size,) params: 权重字典 {Wf, Wi, Wc, Wo, bf, bi, bc, bo} Returns: h_next: 当前时间步隐藏状态 C_next: 当前时间步细胞状态 cache: 缓存用于反向传播 """ Wf, Wi, Wc, Wo = params['Wf'], params['Wi'], params['Wc'], params['Wo'] bf, bi, bc, bo = params['bf'], params['bi'], params['bc'], params['bo'] # 拼接输入: [h_prev, x] concat = np.concatenate([h_prev, x]) # 遗忘门 f_t = 1.0 / (1.0 + np.exp(-(Wf @ concat + bf))) # 输入门 i_t = 1.0 / (1.0 + np.exp(-(Wi @ concat + bi))) # 候选细胞状态 C_tilde = np.tanh(Wc @ concat + bc) # 细胞状态更新 C_next = f_t * C_prev + i_t * C_tilde # 输出门 o_t = 1.0 / (1.0 + np.exp(-(Wo @ concat + bo))) # 隐藏状态 h_next = o_t * np.tanh(C_next) return h_next, C_next, (f_t, i_t, C_tilde, C_next, o_t, h_next)

使用 PyTorch 的 nn.LSTM 模块则更加简洁:

import torch import torch.nn as nn # 定义 LSTM 层 input_size = 10 # 输入特征维度 hidden_size = 20 # 隐藏状态维度 num_layers = 2 # 堆叠层数 lstm = nn.LSTM( input_size = input_size, hidden_size = hidden_size, num_layers = num_layers, batch_first = True, # 输入形状为 (batch, seq, feature) dropout = 0.2, # 多层时层间 dropout bidirectional = False ) # 前向传播 batch_size, seq_len = 32, 15 x = torch.randn(batch_size, seq_len, input_size) # 初始化隐藏状态和细胞状态(默认全零) h0 = torch.zeros(num_layers, batch_size, hidden_size) C0 = torch.zeros(num_layers, batch_size, hidden_size) output, (hn, Cn) = lstm(x, (h0, C0)) print(f"输出形状: {output.shape}") # (32, 15, 20) print(f"最终隐藏状态: {hn.shape}") # (2, 32, 20) print(f"最终细胞状态: {Cn.shape}") # (2, 32, 20)

batch_first 参数说明

PyTorch 的 LSTM 默认输入形状为 (seq_len, batch, feature)。设置 batch_first=True 后,输入形状变为 (batch, seq_len, feature),更符合直觉。此时输出的形状也相应变为 (batch, seq_len, hidden_size)

此外需要注意:

  • h0C0 的形状为 (num_layers * num_directions, batch, hidden_size)
  • 双向时 num_directions = 2,此时 hidden_size 通常取原始的一半以保证输出维度一致

四、LSTM变体与演进

自LSTM提出以来,研究者们提出了多种变体,旨在改进性能、减少参数或适应特定任务。以下是几个代表性变体:

4.1 Peephole Connections(窥视孔连接)

由 Gers 和 Schmidhuber(2000)提出,允许门控单元直接"窥视"细胞状态。在原始LSTM中,门控只依赖于 ht-1 和 xt,但 Peephole 连接将 Ct-1(或 Ct)也引入门控计算:

ft = σ( Wf · [Ct-1, ht-1, xt] + bf )
it = σ( Wi · [Ct-1, ht-1, xt] + bi )
ot = σ( Wo · [Ct, ht-1, xt] + bo )

这种设计让门控可以直接感知细胞状态,理论上可以更精确地控制信息流。但在实际应用中,Peephole 连接的收益并不显著,因此现代框架默认不启用此功能。

4.2 Coupled Forget and Input Gate(耦合遗忘与输入门)

将遗忘门和输入门耦合,使得 it = 1 - ft。这样细胞状态的更新公式变为:

Ct = ft ⊗ Ct-1 + (1 - ft) ⊗ C̃t

这种耦合意味着:忘记多少旧信息,就添加多少新信息。参数减少的同时,强制网络在遗忘和记忆之间做出权衡,在某些任务上可以防止过拟合。

4.3 GRU —— 门控循环单元

GRU(Gated Recurrent Unit)由 Cho 等人在2014年提出,是LSTM最成功的简化变体。它将遗忘门和输入门合并为更新门(Update Gate),并取消了独立的细胞状态,仅保留隐藏状态。下一节将详细讲解GRU的架构。

"GRU 的设计理念是:在保持 LSTM 核心能力的前提下,尽可能简化结构。实践证明,在许多任务上 GRU 与 LSTM 性能相当,但训练速度更快。"

五、GRU门控循环单元详解

GRU将LSTM的三个门简化为两个门重置门(Reset Gate, rt)和更新门(Update Gate, zt),并去除了独立的细胞状态。GRU的整体参数量约为LSTM的75%,但在许多序列建模任务上表现几乎相同甚至更好。

GRU的核心变革

  • 合并门控: 遗忘门 + 输入门 → 更新门 zt
  • 取消细胞状态: 只有隐藏状态 ht,不再有 Ct
  • 引入重置门: 控制历史信息对候选状态的影响程度
  • 线性自更新: 更新门同时控制遗忘和记忆,ht = zt ⊗ ht-1 + (1-zt) ⊗ h̃t

5.1 重置门(Reset Gate, rt

重置门决定忽略过去多少信息。它作用于候选隐藏状态的计算:

rt = σ( Wr · [ht-1, xt] + br )
t = tanh( Wh · [rt ⊗ ht-1, xt] + bh )

5.2 更新门(Update Gate, zt

更新门决定保留多少旧状态、加入多少新信息,同时承担了LSTM中遗忘门和输入门的功能:

zt = σ( Wz · [ht-1, xt] + bz )
ht = zt ⊗ ht-1 + (1 - zt) ⊗ h̃t

5.3 GRU 的 PyTorch 实现

# GRU 前向传播的 NumPy 实现 def gru_forward(x, h_prev, params): """ GRU 单步前向传播 Args: x: 当前输入, shape (input_size,) h_prev: 上一隐藏状态, shape (hidden_size,) params: {Wr, Wz, Wh, br, bz, bh} Returns: h_next: 当前隐藏状态 """ Wr, Wz, Wh = params['Wr'], params['Wz'], params['Wh'] br, bz, bh = params['br'], params['bz'], params['bh'] concat = np.concatenate([h_prev, x]) # 重置门 r_t = 1.0 / (1.0 + np.exp(-(Wr @ concat + br))) # 更新门 z_t = 1.0 / (1.0 + np.exp(-(Wz @ concat + bz))) # 候选隐藏状态(重置门控制历史信息流入) h_tilde = np.tanh(Wh @ np.concatenate([r_t * h_prev, x]) + bh) # 最终隐藏状态(更新门控制新旧比例) h_next = z_t * h_prev + (1 - z_t) * h_tilde return h_next, (r_t, z_t, h_tilde, h_next) # 使用 PyTorch nn.GRU gru = nn.GRU( input_size = 10, hidden_size = 20, num_layers = 2, batch_first = True, dropout = 0.2 ) x = torch.randn(32, 15, 10) # (batch, seq, feature) h0 = torch.zeros(2, 32, 20) output, hn = gru(x, h0) print(f"GRU 输出形状: {output.shape}") # (32, 15, 20) print(f"GRU 最终隐藏状态: {hn.shape}") # (2, 32, 20)

GRU与LSTM的关键差异

  • 参数量: GRU 约 3 × (hidden_size × (hidden_size + input_size)),LSTM 约 4 × 该值
  • 状态数量: LSTM 维护两个状态(ht, Ct),GRU 只维护一个状态(ht
  • 门控独立度: LSTM 的遗忘和记忆由独立门控制,GRU 的更新门强制权衡
  • 梯度传播: GRU 通过线性插值更新隐藏状态,同样有助于缓解梯度消失

六、LSTM vs GRU vs RNN 综合对比

为了更直观地理解三种架构的差异,我们从多个维度进行系统比较:

6.1 架构特性对比

对比维度 标准RNN LSTM GRU
提出时间 1986 1997 2014
门控数量 0 3(f, i, o) 2(r, z)
内部状态 ht 唯一状态 ht + Ct 双状态 ht 唯一状态
参数量 最少 最多 中等
缓解梯度消失
长期记忆能力 优秀 优秀
训练速度 最快 最慢 中等
小数据量表现 一般 较好 较好
大数据量扩展性 优秀 优秀

6.2 梯度流路径分析

三种架构在梯度传播上的根本差异在于:

"LSTM 和 GRU 之所以成功,不是因为它们更'聪明',而是因为它们为梯度提供了更'顺畅'的传播路径。加法比乘法更友善,这是深度学习的朴素真理。"

6.3 何时选择哪种架构?

  • 优先选 GRU:
    • 数据量有限(< 10万样本)
    • 计算资源受限(移动端、嵌入式)
    • 需要快速迭代和实验
    • 任务本身对长期依赖要求不高(如情感分类、短文本)
  • 优先选 LSTM:
    • 数据量充足,计算资源充裕
    • 任务需要精确建模长期依赖(如机器翻译、语音识别)
    • 需要更精细地控制信息流
    • 处理多模态序列数据(如视频描述生成)
  • 选 BiLSTM / BiGRU:
    • 任务可以访问整个序列(如文本分类、命名实体识别、阅读理解)
    • 双向上下文对理解当前时间步至关重要

七、双向LSTM(BiLSTM)

在许多序列标注任务中,当前时间步的输出不仅依赖过去的信息,也依赖未来的信息。例如,在命名实体识别(NER)中,"苹果"这个词在"苹果很好吃"和"苹果发布了新手机"中的实体类型不同,需要同时看到上下文才能准确判断。

双向LSTM(BiLSTM)通过两个独立的LSTM层分别从前向和后向处理序列,然后将两个方向的隐藏状态拼接起来:

ht = [ h→t ; h←t ]
# PyTorch BiLSTM 实现 bilstm = nn.LSTM( input_size = 10, hidden_size = 128, # 注意:双向时每方向维度为 hidden_size // 2 num_layers = 2, batch_first = True, bidirectional = True # 启用双向 ) x = torch.randn(32, 50, 10) output, (hn, Cn) = bilstm(x) print(f"BiLSTM 输出形状: {output.shape}") # (32, 50, 256) print(f"hn 形状: {hn.shape}") # (4, 32, 128) # 2层 x 2方向 # 分离前向和后向的最终隐藏状态 h_forward = hn[-2, :, :] # 最后一层前向 h_backward = hn[-1, :, :] # 最后一层后向 h_combined = torch.cat([h_forward, h_backward], dim=-1) print(f"拼接后的最终表示: {h_combined.shape}") # (32, 256)

BiLSTM 实践要点

  • hidden_size 的含义: 在双向LSTM中,hidden_size 指的是每个方向的输出维度。因此最终输出维度是 2 × hidden_size。若需要输出维度为 D,应设置 hidden_size = D // 2
  • 反向传播的不同: 前向LSTM从左到右,后向LSTM从右到左,两者参数独立训练
  • 适用场景: NER、POS Tagging、文本分类、机器翻译(Encoder端)等可以访问完整序列的任务
  • 不适用场景: 实时流式处理、在线预测等无法获取未来信息的任务

八、PyTorch 高级实践

8.1 多层堆叠 LSTM

通过堆叠多个LSTM层,可以构建更深的序列模型,捕获更抽象的时间特征。每层的隐藏状态 ht 作为下一层的输入。

class StackedLSTM(nn.Module): def __init__(self, input_size, hidden_size, num_layers, num_classes): super().__init__() self.hidden_size = hidden_size self.num_layers = num_layers self.lstm = nn.LSTM( input_size, hidden_size, num_layers, batch_first=True, dropout=0.3 ) self.fc = nn.Linear(hidden_size, num_classes) def forward(self, x): # x: (batch, seq_len, input_size) h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device) C0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device) out, _ = self.lstm(x, (h0, C0)) # out: (batch, seq_len, hidden_size) # 取最后一个时间步的输出 last_out = out[:, -1, :] # (batch, hidden_size) logits = self.fc(last_out) return logits

8.2 LSTM 文本分类完整示例

import torch import torch.nn as nn import torch.optim as optim class LSTMClassifier(nn.Module): def __init__(self, vocab_size, embed_dim, hidden_size, num_layers, num_classes, dropout=0.5): super().__init__() self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0) self.lstm = nn.LSTM( embed_dim, hidden_size, num_layers, batch_first=True, bidirectional=False, dropout=dropout if num_layers > 1 else 0 ) self.dropout = nn.Dropout(dropout) self.fc = nn.Linear(hidden_size, num_classes) def forward(self, x): # x: (batch, seq_len) token ids emb = self.embedding(x) # (batch, seq_len, embed_dim) out, (hn, Cn) = self.lstm(emb) # out: (batch, seq_len, hidden_size) # 方法1: 取最后时间步输出 last_out = out[:, -1, :] # (batch, hidden_size) # 方法2: 所有时间步平均池化(更鲁棒) # mask = (x != 0).unsqueeze(-1).float() # avg_out = (out * mask).sum(dim=1) / mask.sum(dim=1) logits = self.fc(self.dropout(last_out)) return logits # 训练代码 model = LSTMClassifier( vocab_size=30000, embed_dim=100, hidden_size=128, num_layers=2, num_classes=5 # 5分类 ) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=1e-3) # 模拟训练循环 for epoch in range(10): for x_batch, y_batch in dataloader: # 假设有 dataloader logits = model(x_batch) loss = criterion(logits, y_batch) optimizer.zero_grad() loss.backward() # 梯度裁剪——防止梯度爆炸 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0) optimizer.step() print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

梯度裁剪(Gradient Clipping)

即使使用LSTM,梯度爆炸仍可能发生在序列的开始部分或训练初期。clip_grad_norm_ 将梯度的 L2 范数裁剪到指定阈值以内,是训练 RNN 系列模型的标准操作

推荐阈值范围通常在 1.0 ~ 5.0 之间。设置过小会导致训练缓慢,设置过大则失去保护作用。

8.3 nn.LSTM 与 nn.GRU 的 PackedSequence 处理

在处理变长序列时(如NLP中的不同长度句子),使用 pack_padded_sequencepad_packed_sequence 可以避免填充 token 参与计算,提升效率和准确性:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence # 变长序列处理 # sequences: 已填充的批量序列 (batch, max_seq_len) # lengths: 每个序列的实际长度 (batch,) packed = pack_padded_sequence( sequences, lengths, batch_first=True, enforce_sorted=False # 无需预先按长度排序 ) # 通过 LSTM packed_out, (hn, Cn) = lstm(packed) # 解包回标准 tensor output, output_lengths = pad_packed_sequence( packed_out, batch_first=True ) # output: (batch, max_seq_len, hidden_size) — 非填充位置包含有效输出

九、核心要点总结

LSTM 与 GRU 知识体系总结

  • 核心问题: 标准RNN因梯度消失/爆炸无法处理长序列依赖,LSTM和GRU通过门控机制和线性状态更新路径解决这一问题
  • LSTM三要素: 细胞状态 Ct(记忆传送带)+ 三扇门 ft/it/ot(信息控制器),分别控制遗忘、输入和输出
  • LSTM前向传播: Ct = ft ⊗ Ct-1 + it ⊗ C̃t,加法使梯度线性传播,有效缓解消失
  • GRU简化: 三个门合并为两个(重置门 rt、更新门 zt),取消细胞状态,参数量减少约25%
  • GRU核心公式: ht = zt ⊗ ht-1 + (1-zt) ⊗ h̃t,更新门同时控制遗忘和记忆
  • BiLSTM: 双向处理序列,拼接前向和后向隐藏状态,适用于标注类任务
  • 实践建议: 数据少/资源受限选GRU,数据足/长依赖任务选LSTM,序列标注任务选BiLSTM/BiGRU
  • 训练技巧: 梯度裁剪(clip_grad_norm_)、PackedSequence处理变长序列、合理的dropout设置
  • 现代替代: 在大规模NLP任务中,Transformer已逐步取代LSTM成为主流,但在小规模数据、时间序列、流式任务中LSTM/GRU仍具优势

十、进一步思考

LSTM 和 GRU 是序列建模领域的重要里程碑,但深度学习技术在快速演进。以下几个方向值得持续关注:

延伸学习方向

  • Transformer 与 Attention: 自注意力机制彻底改变了序列建模范式,但在处理超长序列时存在 O(n²) 的计算复杂度问题
  • Efficient Transformers: Longformer、BigBird、Reformer 等架构试图在 Transformer 效率与 RNN 类线性复杂度之间取得平衡
  • State Space Models: Mamba(S6)等新架构结合了 RNN 的线性复杂度与 Transformer 的表达能力,是当前的研究热点
  • 时间序列 Foundation Model: TimesFM、Lag-Llama 等预训练时间序列模型正在改变传统 LSTM 在时序预测领域的统治地位
  • 多模态序列建模: 视频理解、语音-文本联合建模等场景仍然是 LSTM 及其变体的活跃应用领域

实践建议:从理论到代码

  • 用 PyTorch 从零实现 LSTM 和 GRU 的前向传播(不借助 nn.LSTM),加深对门控机制的理解
  • 在真实数据集(如 IMDB 情感分类、Penn TreeBank 语言模型)上对比 LSTM vs GRU 的性能和训练速度
  • 尝试不同的隐藏层维度、层数、dropout 率,观察对收敛速度和最终性能的影响
  • 将学习笔记与 PyTorch 官方文档(nn.LSTMnn.GRU)配合阅读,掌握参数细节和最佳实践