RNN - LSTM - GRU

Posted massquantity

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了RNN - LSTM - GRU相关的知识,希望对你有一定的参考价值。


循环神经网络 (Recurrent Neural Network,RNN) 是一类具有短期记忆能力的神经网络,因而常用于序列建模。本篇先总结 RNN 的基本概念,以及其训练中时常遇到梯度爆炸和梯度消失问题,再引出 RNN 的两个主流变种 —— LSTM 和 GRU。


Vanilla RNN



Vanilla RNN 的主体结构:

技术图片


上图中 (f{X, h, y}) 都是向量,公式如下:
[ % <![CDATA[ egin{align} extbf{h}_{t} &= f_{ extbf{W}}left( extbf{h}_{t-1}, extbf{x}_{t} ight) ag{1} \\textbf{h}_{t} &= fleft( extbf{W}_{hx} extbf{x}_{t} + extbf{W}_{hh} extbf{h}_{t-1} + extbf{b}_{h} ight) ag{2a} \\textbf{h}_{t} &= extbf{tanh}left( extbf{W}_{hx} extbf{x}_{t} + extbf{W}_{hh} extbf{h}_{t-1} + extbf{b}_{h} ight) ag{2b} \\hat{ extbf{y}}_{t} &= extbf{softmax}left( extbf{W}_{yh} extbf{h}_{t} + extbf{b}_{y} ight) ag{3} end{align} %]]> ]
其中 ( extbf{W}_{hx} in mathbb{R}^{h imes x}, ; extbf{W}_{hh} in mathbb{R}^{h imes h}, ; extbf{W}_{yh} in mathbb{R}^{y imes h}, ; extbf{b}_{h} in mathbb{R}^{h}, ; extbf{b}_{y} in mathbb{R}^{y})


((2a)) 式中的两个矩阵 (mathbf{W}) 可以合并:
[ egin{align*} extbf{h}_{t} &= fleft( extbf{W}_{hx} extbf{x}_{t} + extbf{W}_{hh} extbf{h}_{t-1} + extbf{b}_{h} ight) & = fleft(left( extbf{W}_{hx}, extbf{W}_{hh} ight) egin{pmatrix} extbf{x}_t \\textbf{h}_{t-1} end{pmatrix} + extbf{b}_{h} ight) & = fleft( extbf{W} egin{pmatrix} extbf{x}_t \\textbf{h}_{t-1} end{pmatrix} + extbf{b}_{h} ight) end{align*} ]


注意到在计算时,每一 time step 中使用的参数 ( extbf{W}, ; extbf{b}) 是一样的,也就是说每个步骤的参数都是共享的,这是RNN的重要特点。

和普通的全连接层相比,RNN 除了输入 ( extbf{x}_t) 外,还有输入隐藏层上一节点 (mathbf{h}_{t-1}) ,RNN 每一层的输出就是这两个输入用矩阵 ( extbf{W}_{hx})( extbf{W}_{hh})和激活函数进行组合的结果。从 ((2a)) 式可以看出 ( extbf{x}_t)(mathbf{h}_{t-1}) 都是与 ( extbf{h}_h) 全连接的,下图形象展示了各个时间节点 RNN 隐藏层记忆的变化。随着时间流逝,最初的蓝色结点保留地越来越少,这意味着RNN对于长时记忆的困难。

技术图片




Vanishing & Exploding Gradient Problems

RNN 对于长时记忆的困难主要来源于梯度爆炸 / 消失问题,下面进行说明。RNN 中 Loss 的计算图示例:


技术图片


总的 Loss 是每个 time step 的加和 : (mathcal{large{L}} (hat{ extbf{y}}, extbf{y}) = sum_{t = 1}^{T} mathcal{ large{L} }(hat{ extbf{y}_t}, extbf{y}_{t}))


backpropagation through time (BPTT) 算法,参数的梯度为:
[ frac{partial oldsymbol{mathcal{L}}}{partial extbf{W}} = sum_{t=1}^{T} frac{partial oldsymbol{mathcal{L}}_{t}}{partial extbf{W}} = sum_{t=1}^{T} frac{partial oldsymbol{mathcal{L}}_t}{partial extbf{y}_{t}} frac{partial extbf{y}_{t}}{partial extbf{h}_{t}} overbrace{frac{partial extbf{h}_{t}}{partial extbf{h}_{k}}}^{ igstar } frac{partial extbf{h}_{k}}{partial extbf{W}} ]
其中 (frac{partial extbf{h}_{t}}{partial extbf{h}_{k}}) 包含一系列 ( ext{Jacobian}) 矩阵,
[ frac{partial extbf{h}_{t}}{partial extbf{h}_{k}} = frac{partial extbf{h}_{t}}{partial extbf{h}_{t-1}} frac{partial extbf{h}_{t-1}}{partial extbf{h}_{t-2}} cdots frac{partial extbf{h}_{k+1}}{partial extbf{h}_{k}} = prod_{i=k+1}^{t} frac{partial extbf{h}_{i}}{partial extbf{h}_{i-1}} ]
由于 RNN 中每个 time step 都是用相同的 ( extbf{W}) ,所以由 ((2a)) 式可得:
[ prod_{i=k+1}^{t} frac{partial extbf{h}_{i}}{partial extbf{h}_{i-1}} = prod_{i=k+1}^{t} extbf{W}^ op ext{diag} left[ f‘left( extbf{h}_{i-1} ight) ight] ]


由于 ( extbf{W}_{hh} in mathbb{R}^{h imes h}) 为方阵,对其进行特征值分解:
[ mathbf{W} = mathbf{V} , ext{diag}(oldsymbol{lambda}) , mathbf{V}^{-1} ]
由于上式是连乘 ( ext{t})(mathbf{W}) :
[ mathbf{W}^t = (mathbf{V} , ext{diag}(oldsymbol{lambda}) , mathbf{V}^{-1})^t = mathbf{V} , ext{diag}(oldsymbol{lambda})^t , mathbf{V}^{-1} ]
连乘的次数多了之后,则若最大的特征值 (lambda >1) ,会产生梯度爆炸; (lambda < 1) ,则会产生梯度消失 。不论哪种情况,都会导致模型难以学到有用的模式。


下左图显示一个 time step 中 tanh 函数的计算结果,右图显示整个神经网络的计算结果,可以清楚地看到哪个区域最容易产生梯度爆炸/消失问题。

技术图片



梯度爆炸的解决办法:

(1) Truncated Backpropagation through time:每次只 BP 固定的 time step 数,类似于 mini-batch SGD。缺点是丧失了长距离记忆的能力。

技术图片


(2) Clipping Gradients: 当梯度超过一定的 threshold 后,就进行 element-wise 的裁剪,该方法的缺点是又引入了一个新的参数 threshold。同时该方法也可视为一种基于瞬时梯度大小来自适应 learning rate 的方法:
[ ext{if} quad lVert extbf{g} Vert ge ext{threshold} \\[1ex] extbf{g} leftarrow frac{ ext{threshold}}{lVert extbf{g} Vert} extbf{g} ]

技术图片



梯度消失的解决办法

(1) 使用 LSTM、GRU等升级版 RNN,使用各种 gates 控制信息的流通。

(2) 在这篇论文 ( https://arxiv.org/pdf/1602.06662.pdf ) 中提出将权重矩阵 ( extbf{W}) 初始化为正交矩阵。正交矩阵有如下性质:(A^T A =A A^T = I, ; A^T = A^{-1}), 正交矩阵的特征值的绝对值为 ( ext{1}) 。证明如下, 对矩阵 (A) 有:
[ egin{align*} & A mathbf{v} = lambda mathbf{v} \\[1ex] ||A mathbf{v}||^2& = (A mathbf{v})^ ext{T} (A mathbf{v}) &= mathbf{v}^ ext{T}A ^{ ext{T}}A mathbf{v} & = mathbf{v}^{ ext{T}}mathbf{v} \\ & = ||mathbf{v}||^2 \\ & = |lambda|^2 ||mathbf{v}||^2 end{align*} ]
由于 (mathbf{v}) 为特征向量,(mathbf{v} eq 0) ,所以 (|lambda| = 1) ,这样连乘之后 (lambda^t) 不会出现越来越小的情况。

(3) 反转输入序列。像在机器翻译中使用 seq2seq 模型,若使用正常序列输入,则输入序列的第一个词和输出序列的第一个词相距较远,难以学到长期依赖。将输入序列反向后,输入序列的第一个词就会和输出序列的第一个词非常接近,二者的相互关系也就比较容易学习了。这样模型可以先学前几个词的短期依赖,再学后面词的长期依赖关系。见下图正常输入顺序是 (| ext{ABC}|),反向是 (| ext{CBA}|) ,则 ( ext{A}) 与第一个输出词 ( ext{W}) 接近:

技术图片





LSTM



虽然 Vanilla RNN 理论上可以建立长时间间隔状态之间的依赖关系,但由于梯度爆炸或消失问题,实际上只能学到短期依赖关系。为了学到长期依赖关系,LSTM 中引入了门控机制来控制信息的累计速度,包括有选择地加入新的信息,并有选择地遗忘之前累计的信息,整个 LSTM 单元结构如下图所示:

技术图片

[ egin{align} ext{input gate}&: quad extbf{i}_t = sigma( extbf{W}_i extbf{x}_t + extbf{U}_i extbf{h}_{t-1} + extbf{b}_i) ag{1} \\text{forget gate}&: quad extbf{f}_t = sigma( extbf{W}_f extbf{x}_t + extbf{U}_f extbf{h}_{t-1} + extbf{b}_f) ag{2}\\text{output gate}&: quad extbf{o}_t = sigma( extbf{W}_o extbf{x}_t + extbf{U}_o extbf{h}_{t-1} + extbf{b}_o) ag{3}\\text{new memory cell}&: quad ilde{ extbf{c}}_t = ext{tanh}( extbf{W}_c extbf{x}_t + extbf{U}_c extbf{h}_{t-1} + extbf{b}_c) ag{4}\\text{final memory cell}& : quad extbf{c}_t = extbf{f}_t odot extbf{c}_{t-1} + extbf{i}_t odot ilde{ extbf{c}}_t ag{5}\\text{final hidden state} &: quad extbf{h}_t= extbf{o}_t odot ext{tanh}( extbf{c}_t) ag{6} end{align} ]
式 $(1) sim (4) $ 的输入都一样,因而可以合并:
[ egin{pmatrix} extbf{i}_t \\textbf{f}_{t} \\textbf{o}_t \\tilde{ extbf{c}}_t end{pmatrix} = egin{pmatrix} sigma \\sigma \\sigma \\text{tanh} end{pmatrix} left( extbf{W} egin{bmatrix} extbf{x}_t \\textbf{h}_{t-1} end{bmatrix} + extbf{b} ight) ]

$ ilde{ extbf{c}}_t $ 为时刻 t 的候选状态,( extbf{i}_t) 控制 ( ilde{ extbf{c}}_t) 中有多少新信息需要保存,( extbf{f}_{t}) 控制上一时刻的内部状态 ( extbf{c}_{t-1}) 需要遗忘多少信息,( extbf{o}_t) 控制当前时刻的内部状态 ( extbf{c}_t) 有多少信息需要输出给外部状态 ( extbf{h}_t)

下表显示 forget gate 和 input gate 的关系,可以看出 forget gate 其实更应该被称为 “remember gate”, 因为其开启时之前的记忆信息 ( extbf{c}_{t-1}) 才会被保留,关闭时则会遗忘所有:

forget gate input gate result
1 0 保留上一时刻的状态 ( extbf{c}_{t-1})
1 1 保留上一时刻 ( extbf{c}_{t-1}) 和添加新信息 ( ilde{ extbf{c}}_t)
0 1 清空历史信息,引入新信息 ( ilde{ extbf{c}}_t)
0 0 清空所有新旧信息


对比 Vanilla RNN,可以发现在时刻 t,Vanilla RNN 通过 ( extbf{h}_t) 来保存和传递信息,上文已分析了如果时间间隔较大容易产生梯度消失的问题。 LSTM 则通过记忆单元 ( extbf{c}_t) 来传递信息,通过 ( extbf{i}_t)( extbf{f}_{t}) 的调控,( extbf{c}_t) 可以在 t 时刻捕捉到某个关键信息,并有能力将此关键信息保存一定的时间间隔。


原始的 LSTM 中是没有 forget gate 的,即:
[ extbf{c}_t = extbf{c}_{t-1} + extbf{i}_t odot ilde{ extbf{c}}_t ]
这样 (frac{partial extbf{c}_t}{partial extbf{c}_{t-1}}) 恒为 ( ext{1}) 。但是这样 ( extbf{c}_t) 会不断增大,容易饱和从而降低模型性能。后来引入了 forget gate ,则梯度变为 ( extbf{f}_{t}) ,事实上连乘多个 ( extbf{f}_{t} in (0,1)) 同样会导致梯度消失,但是 LSTM 的一个初始化技巧就是将 forget gate 的 bias 置为正数(例如 1 或者 5,如 tensorflow 中的默认值就是 (1.0) ),这样一来模型刚开始训练时 forget gate 的值都接近 1,不会发生梯度消失 (反之若 forget gate 的初始值过小则意味着前一时刻的大部分信息都丢失了,这样很难捕捉到长距离依赖关系)。 随着训练过程的进行,forget gate 就不再恒为 1 了。不过,一个训好的模型里各个 gate 值往往不是在 [0, 1] 这个区间里,而是要么 0 要么 1,很少有类似 0.5 这样的中间值,其实相当于一个二元的开关。假如在某个序列里,forget gate 全是 1,那么梯度不会消失;某一个 forget gate 是 0,模型选择遗忘上一时刻的信息。


LSTM 的一种变体增加 peephole 连接,这样三个 gate 不仅依赖于 ( extbf{x}_t)( extbf{h}_{t-1}),也依赖于记忆单元 ( extbf{c})
[ egin{align*} ext{input gate}&: quad extbf{i}_t = sigma( extbf{W}_i extbf{x}_t + extbf{U}_i extbf{h}_{t-1} + extbf{V}_i extbf{c}_{t-1} + extbf{b}_i) \\text{forget gate}&: quad extbf{f}_t = sigma( extbf{W}_f extbf{x}_t + extbf{U}_f extbf{h}_{t-1} + extbf{V}_f extbf{c}_{t-1} + extbf{b}_f) \\text{output gate}&: quad extbf{o}_t = sigma( extbf{W}_o extbf{x}_t + extbf{U}_o extbf{h}_{t-1} + extbf{V}_o extbf{c}_{t} + extbf{b}_o) \\end{align*} ]

注意 input gate 和 forget gate 连接的是 ( extbf{c}_{t-1}) ,而 output gate 连接的是 ( extbf{c}_t) 。下图来自 《LSTM: A Search Space Odyssey》,标注了 peephole 连接的样貌。

技术图片





GRU



相比于 Vanilla RNN (每个 time step 有一个输入 ( extbf{x}_t) ),从上面的 ((1) sim (4)) 式可以看出 一个 LSTM 单元有四个输入 (如下图,不考虑 peephole) ,因而参数是 Vanilla RNN 的四倍,带来的结果是训练起来很慢,因而在2014年 Cho 等人提出了 GRU ,对 LSTM 进行了简化,在不影响效果的前提下加快了训练速度。


技术图片


(largescr{LSTM:})
[ ormalsize egin{align} ext{input gate}&: quad extbf{i}_t = sigma( extbf{W}_i extbf{x}_t + extbf{U}_i extbf{h}_{t-1} + extbf{b}_i) ag{1} \\text{forget gate}&: quad extbf{f}_t = sigma( extbf{W}_f extbf{x}_t + extbf{U}_f extbf{h}_{t-1} + extbf{b}_f) ag{2}\\text{output gate}&: quad extbf{o}_t = sigma( extbf{W}_o extbf{x}_t + extbf{U}_o extbf{h}_{t-1} + extbf{b}_o) ag{3}\\text{new memory cell}&: quad ilde{ extbf{c}}_t = ext{tanh}( extbf{W}_c extbf{x}_t + extbf{U}_c extbf{h}_{t-1} + extbf{b}_c) ag{4}\\text{final memory cell}& : quad extbf{c}_t = extbf{f}_t odot extbf{c}_{t-1} + extbf{i}_t odot ilde{ extbf{c}}_t ag{5}\\text{final hidden state} &: quad extbf{h}_t= extbf{o}_t odot ext{tanh}( extbf{c}_t) ag{6} end{align} ]
在式 ((5)?) 中 forget gate 和 input gate 是互补关系,因而比较冗余,GRU 将其合并为一个 update gate。同时 GRU 也不引入额外的记忆单元 (LSTM 中的 ( extbf{c}?)) ,而是直接在当前状态 ( extbf{h}_t?) 和历史状态 ( extbf{h}_{t-1}?) 之间建立线性依赖关系。

技术图片


(largescr{GRU:})
[ ormalsize egin{align} ext{reset gate}&: quad extbf{r}_t = sigma( extbf{W}_r extbf{x}_t + extbf{U}_r extbf{h}_{t-1} + extbf{b}_r) ag{7} \\text{update gate}&: quad extbf{z}_t = sigma( extbf{W}_z extbf{x}_t + extbf{U}_z extbf{h}_{t-1} + extbf{b}_z) ag{8} \\text{new memory cell}&: quad ilde{ extbf{h}}_t = ext{tanh}( extbf{W}_h extbf{x}_t + extbf{r}_t odot ( extbf{U}_h extbf{h}_{t-1}) + extbf{b}_h) ag{9}\\text{final hidden state}&: quad extbf{h}_t = extbf{z}_t odot extbf{h}_{t-1} + (1 - extbf{z}_t) odot ilde{ extbf{h}}_t ag{10} end{align} ]
$ ilde{ extbf{h}}_t $ 为时刻 t 的候选状态,( extbf{r}_t) 控制 $ ilde{ extbf{h}}_t $ 有多少依赖于上一时刻的状态 ( extbf{h}_{t-1}) ,如果 ( extbf{r}_t = 1) ,则式 ((9)) 与 Vanilla RNN 一致,对于短依赖的 GRU 单元,reset gate 通常会更新频繁。( extbf{z}_t) 控制当前的内部状态 ( extbf{h}_t) 中有多少来自于上一时刻的 ( extbf{h}_{t-1}) 。如果 ( extbf{z}_t = 1) ,则会每步都传递同样的信息,和当前输入 ( extbf{x}_t) 无关。


另一方面看,( extbf{r}_t) 与 LSTM 中的 ( extbf{o}_t) 角色有些类似,因为将上面的 ((6)) 式代入 ((4)) 式可以得到:

[ egin{align*} ilde{ extbf{c}}_t &= ext{tanh}( extbf{W}_c extbf{x}_t + extbf{U}_c extbf{h}_{t-1} + extbf{b}_c) \\ extbf{h}_t &= extbf{o}_t odot ext{tanh}( extbf{c}_t) end{align*} quad Longrightarrow quad ilde{ extbf{c}}_t = ext{tanh}( extbf{W}_c extbf{x}_t + extbf{U}_c left( extbf{o}_{t-1} odot ext{tanh}( extbf{c}_{t-1}) ight) + extbf{b}_c) ]


最后是 cs224n 中提出的 RNN 训练 tips:

技术图片





/


































































以上是关于RNN - LSTM - GRU的主要内容,如果未能解决你的问题,请参考以下文章

RNN架构解析LSTM 模型

RNN和LSTM

RNN 与 LSTM 的应用

『cs231n』RNN之理解LSTM网络

自用RNN+LSTM笔记

普通RNN,LSTM长短期记忆