RNN 为什么记不住:梯度消失
普通 RNN 理论上能记住很久以前的信息,实际却常常“记不住”。问题出在训练:要让早期输入影响最终结果,误差信号得沿着时间一步步往回传,每回传一步就乘上一次循环权重 w。于是传回 k 步后,信号大约变成 wᵏ——只要 w 稍小于 1,传回几十步就衰减到几乎为 0(梯度消失,早期输入学不动、被“遗忘”);稍大于 1 又会爆炸(梯度爆炸,训练发散)。只有 w≈1 的窄缝才稳定,却极难凑到。拖动 w,看误差信号沿时间回传时怎样消失或爆炸——这正是 LSTM 用“门控细胞”要解决的问题。
每根条是某个时间步的输入对最终结果的影响强度(梯度),右边是最近一步、左边是最久远一步。看 w 怎样决定久远信息是被“记住”还是“消失/爆炸”。
w<1:消失
梯度按 wᵏ 指数衰减,回传几十步就≈0,早期输入学不到——长程记忆丢失。
w>1:爆炸
梯度指数放大,数值溢出、训练发散。需要梯度裁剪等手段救急。
LSTM 的解法
用“细胞状态+门”让信息近乎原样直传(≈乘 1),绕开 wᵏ 衰减,记得更久。