RNN - LSTM - GRU
Posted massquantity
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了RNN - LSTM - GRU相关的知识,希望对你有一定的参考价值。
循环神经网络 (Recurrent Neural Network,RNN) 是一类具有短期记忆能力的神经网络,因而常用于序列建模。本篇先总结 RNN 的基本概念,以及其训练中时常遇到梯度爆炸和梯度消失问题,再引出 RNN 的两个主流变种 —— LSTM 和 GRU。
Vanilla RNN
Vanilla RNN 的主体结构:
上图中 (f{X, h, y}) 都是向量,公式如下:
[
% <![CDATA[
egin{align}
extbf{h}_{t} &= f_{ extbf{W}}left( extbf{h}_{t-1}, extbf{x}_{t}
ight) ag{1} \\textbf{h}_{t} &= fleft( extbf{W}_{hx} extbf{x}_{t} + extbf{W}_{hh} extbf{h}_{t-1} + extbf{b}_{h}
ight) ag{2a} \\textbf{h}_{t} &= extbf{tanh}left( extbf{W}_{hx} extbf{x}_{t} + extbf{W}_{hh} extbf{h}_{t-1} + extbf{b}_{h}
ight) ag{2b} \\hat{ extbf{y}}_{t} &= extbf{softmax}left( extbf{W}_{yh} extbf{h}_{t} + extbf{b}_{y}
ight) ag{3}
end{align} %]]>
]
其中 ( extbf{W}_{hx} in mathbb{R}^{h imes x}, ; extbf{W}_{hh} in mathbb{R}^{h imes h}, ; extbf{W}_{yh} in mathbb{R}^{y imes h}, ; extbf{b}_{h} in mathbb{R}^{h}, ; extbf{b}_{y} in mathbb{R}^{y})
((2a)) 式中的两个矩阵 (mathbf{W}) 可以合并:
[
egin{align*}
extbf{h}_{t} &= fleft( extbf{W}_{hx} extbf{x}_{t} + extbf{W}_{hh} extbf{h}_{t-1} + extbf{b}_{h}
ight) & = fleft(left( extbf{W}_{hx}, extbf{W}_{hh}
ight)
egin{pmatrix}
extbf{x}_t \\textbf{h}_{t-1}
end{pmatrix}
+ extbf{b}_{h}
ight) & = fleft( extbf{W}
egin{pmatrix}
extbf{x}_t \\textbf{h}_{t-1}
end{pmatrix}
+ extbf{b}_{h}
ight)
end{align*}
]
注意到在计算时,每一 time step 中使用的参数 ( extbf{W}, ; extbf{b}) 是一样的,也就是说每个步骤的参数都是共享的,这是RNN的重要特点。
和普通的全连接层相比,RNN 除了输入 ( extbf{x}_t) 外,还有输入隐藏层上一节点 (mathbf{h}_{t-1}) ,RNN 每一层的输出就是这两个输入用矩阵 ( extbf{W}_{hx}),( extbf{W}_{hh})和激活函数进行组合的结果。从 ((2a)) 式可以看出 ( extbf{x}_t) 和 (mathbf{h}_{t-1}) 都是与 ( extbf{h}_h) 全连接的,下图形象展示了各个时间节点 RNN 隐藏层记忆的变化。随着时间流逝,最初的蓝色结点保留地越来越少,这意味着RNN对于长时记忆的困难。
Vanishing & Exploding Gradient Problems
RNN 对于长时记忆的困难主要来源于梯度爆炸 / 消失问题,下面进行说明。RNN 中 Loss 的计算图示例:
总的 Loss 是每个 time step 的加和 : (mathcal{large{L}} (hat{ extbf{y}}, extbf{y}) = sum_{t = 1}^{T} mathcal{ large{L} }(hat{ extbf{y}_t}, extbf{y}_{t}))
由 backpropagation through time (BPTT) 算法,参数的梯度为:
[
frac{partial oldsymbol{mathcal{L}}}{partial extbf{W}} = sum_{t=1}^{T} frac{partial oldsymbol{mathcal{L}}_{t}}{partial extbf{W}} = sum_{t=1}^{T} frac{partial oldsymbol{mathcal{L}}_t}{partial extbf{y}_{t}} frac{partial extbf{y}_{t}}{partial extbf{h}_{t}} overbrace{frac{partial extbf{h}_{t}}{partial extbf{h}_{k}}}^{ igstar } frac{partial extbf{h}_{k}}{partial extbf{W}}
]
其中 (frac{partial extbf{h}_{t}}{partial extbf{h}_{k}}) 包含一系列 ( ext{Jacobian}) 矩阵,
[
frac{partial extbf{h}_{t}}{partial extbf{h}_{k}} = frac{partial extbf{h}_{t}}{partial extbf{h}_{t-1}} frac{partial extbf{h}_{t-1}}{partial extbf{h}_{t-2}} cdots frac{partial extbf{h}_{k+1}}{partial extbf{h}_{k}}
= prod_{i=k+1}^{t} frac{partial extbf{h}_{i}}{partial extbf{h}_{i-1}}
]
由于 RNN 中每个 time step 都是用相同的 ( extbf{W}) ,所以由 ((2a)) 式可得:
[
prod_{i=k+1}^{t} frac{partial extbf{h}_{i}}{partial extbf{h}_{i-1}} = prod_{i=k+1}^{t} extbf{W}^ op ext{diag} left[ f‘left( extbf{h}_{i-1}
ight)
ight]
]
由于 ( extbf{W}_{hh} in mathbb{R}^{h imes h}) 为方阵,对其进行特征值分解:
[
mathbf{W} = mathbf{V} , ext{diag}(oldsymbol{lambda}) , mathbf{V}^{-1}
]
由于上式是连乘 ( ext{t}) 次 (mathbf{W}) :
[
mathbf{W}^t = (mathbf{V} , ext{diag}(oldsymbol{lambda}) , mathbf{V}^{-1})^t = mathbf{V} , ext{diag}(oldsymbol{lambda})^t , mathbf{V}^{-1}
]
连乘的次数多了之后,则若最大的特征值 (lambda >1) ,会产生梯度爆炸; (lambda < 1) ,则会产生梯度消失 。不论哪种情况,都会导致模型难以学到有用的模式。
下左图显示一个 time step 中 tanh 函数的计算结果,右图显示整个神经网络的计算结果,可以清楚地看到哪个区域最容易产生梯度爆炸/消失问题。
梯度爆炸的解决办法:
(1) Truncated Backpropagation through time:每次只 BP 固定的 time step 数,类似于 mini-batch SGD。缺点是丧失了长距离记忆的能力。
(2) Clipping Gradients: 当梯度超过一定的 threshold 后,就进行 element-wise 的裁剪,该方法的缺点是又引入了一个新的参数 threshold。同时该方法也可视为一种基于瞬时梯度大小来自适应 learning rate 的方法:
[
ext{if} quad lVert extbf{g}
Vert ge ext{threshold} \\[1ex]
extbf{g} leftarrow frac{ ext{threshold}}{lVert extbf{g}
Vert} extbf{g}
]
梯度消失的解决办法
(1) 使用 LSTM、GRU等升级版 RNN,使用各种 gates 控制信息的流通。
(2) 在这篇论文 ( https://arxiv.org/pdf/1602.06662.pdf ) 中提出将权重矩阵 ( extbf{W}) 初始化为正交矩阵。正交矩阵有如下性质:(A^T A =A A^T = I, ; A^T = A^{-1}), 正交矩阵的特征值的绝对值为 ( ext{1}) 。证明如下, 对矩阵 (A) 有:
[
egin{align*}
& A mathbf{v} = lambda mathbf{v} \\[1ex]
||A mathbf{v}||^2& = (A mathbf{v})^ ext{T} (A mathbf{v}) &= mathbf{v}^ ext{T}A ^{ ext{T}}A mathbf{v} & = mathbf{v}^{ ext{T}}mathbf{v} \\ &
= ||mathbf{v}||^2 \\ &
= |lambda|^2 ||mathbf{v}||^2
end{align*}
]
由于 (mathbf{v}) 为特征向量,(mathbf{v}
eq 0) ,所以 (|lambda| = 1) ,这样连乘之后 (lambda^t) 不会出现越来越小的情况。
(3) 反转输入序列。像在机器翻译中使用 seq2seq 模型,若使用正常序列输入,则输入序列的第一个词和输出序列的第一个词相距较远,难以学到长期依赖。将输入序列反向后,输入序列的第一个词就会和输出序列的第一个词非常接近,二者的相互关系也就比较容易学习了。这样模型可以先学前几个词的短期依赖,再学后面词的长期依赖关系。见下图正常输入顺序是 (| ext{ABC}|),反向是 (| ext{CBA}|) ,则 ( ext{A}) 与第一个输出词 ( ext{W}) 接近:
LSTM
虽然 Vanilla RNN 理论上可以建立长时间间隔状态之间的依赖关系,但由于梯度爆炸或消失问题,实际上只能学到短期依赖关系。为了学到长期依赖关系,LSTM 中引入了门控机制来控制信息的累计速度,包括有选择地加入新的信息,并有选择地遗忘之前累计的信息,整个 LSTM 单元结构如下图所示:
[
egin{align}
ext{input gate}&: quad extbf{i}_t = sigma( extbf{W}_i extbf{x}_t + extbf{U}_i extbf{h}_{t-1} + extbf{b}_i) ag{1} \\text{forget gate}&: quad extbf{f}_t = sigma( extbf{W}_f extbf{x}_t + extbf{U}_f extbf{h}_{t-1} + extbf{b}_f) ag{2}\\text{output gate}&: quad extbf{o}_t = sigma( extbf{W}_o extbf{x}_t + extbf{U}_o extbf{h}_{t-1} + extbf{b}_o) ag{3}\\text{new memory cell}&: quad ilde{ extbf{c}}_t = ext{tanh}( extbf{W}_c extbf{x}_t + extbf{U}_c extbf{h}_{t-1} + extbf{b}_c) ag{4}\\text{final memory cell}& : quad extbf{c}_t = extbf{f}_t odot extbf{c}_{t-1} + extbf{i}_t odot ilde{ extbf{c}}_t ag{5}\\text{final hidden state} &: quad extbf{h}_t= extbf{o}_t odot ext{tanh}( extbf{c}_t) ag{6}
end{align}
]
式 $(1) sim (4) $ 的输入都一样,因而可以合并:
[
egin{pmatrix}
extbf{i}_t \\textbf{f}_{t} \\textbf{o}_t \\tilde{ extbf{c}}_t
end{pmatrix}
=
egin{pmatrix}
sigma \\sigma \\sigma \\text{tanh}
end{pmatrix}
left( extbf{W}
egin{bmatrix}
extbf{x}_t \\textbf{h}_{t-1}
end{bmatrix} + extbf{b}
ight)
]
$ ilde{ extbf{c}}_t $ 为时刻 t 的候选状态,( extbf{i}_t) 控制 ( ilde{ extbf{c}}_t) 中有多少新信息需要保存,( extbf{f}_{t}) 控制上一时刻的内部状态 ( extbf{c}_{t-1}) 需要遗忘多少信息,( extbf{o}_t) 控制当前时刻的内部状态 ( extbf{c}_t) 有多少信息需要输出给外部状态 ( extbf{h}_t) 。
下表显示 forget gate 和 input gate 的关系,可以看出 forget gate 其实更应该被称为 “remember gate”, 因为其开启时之前的记忆信息 ( extbf{c}_{t-1}) 才会被保留,关闭时则会遗忘所有:
forget gate | input gate | result |
---|---|---|
1 | 0 | 保留上一时刻的状态 ( extbf{c}_{t-1}) |
1 | 1 | 保留上一时刻 ( extbf{c}_{t-1}) 和添加新信息 ( ilde{ extbf{c}}_t) |
0 | 1 | 清空历史信息,引入新信息 ( ilde{ extbf{c}}_t) |
0 | 0 | 清空所有新旧信息 |
对比 Vanilla RNN,可以发现在时刻 t,Vanilla RNN 通过 ( extbf{h}_t) 来保存和传递信息,上文已分析了如果时间间隔较大容易产生梯度消失的问题。 LSTM 则通过记忆单元 ( extbf{c}_t) 来传递信息,通过 ( extbf{i}_t) 和 ( extbf{f}_{t}) 的调控,( extbf{c}_t) 可以在 t 时刻捕捉到某个关键信息,并有能力将此关键信息保存一定的时间间隔。
原始的 LSTM 中是没有 forget gate 的,即:
[
extbf{c}_t = extbf{c}_{t-1} + extbf{i}_t odot ilde{ extbf{c}}_t
]
这样 (frac{partial extbf{c}_t}{partial extbf{c}_{t-1}}) 恒为 ( ext{1}) 。但是这样 ( extbf{c}_t) 会不断增大,容易饱和从而降低模型性能。后来引入了 forget gate ,则梯度变为 ( extbf{f}_{t}) ,事实上连乘多个 ( extbf{f}_{t} in (0,1)) 同样会导致梯度消失,但是 LSTM 的一个初始化技巧就是将 forget gate 的 bias 置为正数(例如 1 或者 5,如 tensorflow 中的默认值就是 (1.0) ),这样一来模型刚开始训练时 forget gate 的值都接近 1,不会发生梯度消失 (反之若 forget gate 的初始值过小则意味着前一时刻的大部分信息都丢失了,这样很难捕捉到长距离依赖关系)。 随着训练过程的进行,forget gate 就不再恒为 1 了。不过,一个训好的模型里各个 gate 值往往不是在 [0, 1] 这个区间里,而是要么 0 要么 1,很少有类似 0.5 这样的中间值,其实相当于一个二元的开关。假如在某个序列里,forget gate 全是 1,那么梯度不会消失;某一个 forget gate 是 0,模型选择遗忘上一时刻的信息。
LSTM 的一种变体增加 peephole 连接,这样三个 gate 不仅依赖于 ( extbf{x}_t) 和 ( extbf{h}_{t-1}),也依赖于记忆单元 ( extbf{c}) :
[
egin{align*}
ext{input gate}&: quad extbf{i}_t = sigma( extbf{W}_i extbf{x}_t + extbf{U}_i extbf{h}_{t-1} + extbf{V}_i extbf{c}_{t-1} + extbf{b}_i) \\text{forget gate}&: quad extbf{f}_t = sigma( extbf{W}_f extbf{x}_t + extbf{U}_f extbf{h}_{t-1} + extbf{V}_f extbf{c}_{t-1} + extbf{b}_f) \\text{output gate}&: quad extbf{o}_t = sigma( extbf{W}_o extbf{x}_t + extbf{U}_o extbf{h}_{t-1} + extbf{V}_o extbf{c}_{t} + extbf{b}_o) \\end{align*}
]
注意 input gate 和 forget gate 连接的是 ( extbf{c}_{t-1}) ,而 output gate 连接的是 ( extbf{c}_t) 。下图来自 《LSTM: A Search Space Odyssey》,标注了 peephole 连接的样貌。
GRU
相比于 Vanilla RNN (每个 time step 有一个输入 ( extbf{x}_t) ),从上面的 ((1) sim (4)) 式可以看出 一个 LSTM 单元有四个输入 (如下图,不考虑 peephole) ,因而参数是 Vanilla RNN 的四倍,带来的结果是训练起来很慢,因而在2014年 Cho 等人提出了 GRU ,对 LSTM 进行了简化,在不影响效果的前提下加快了训练速度。
(largescr{LSTM:})
[
ormalsize
egin{align}
ext{input gate}&: quad extbf{i}_t = sigma( extbf{W}_i extbf{x}_t + extbf{U}_i extbf{h}_{t-1} + extbf{b}_i) ag{1} \\text{forget gate}&: quad extbf{f}_t = sigma( extbf{W}_f extbf{x}_t + extbf{U}_f extbf{h}_{t-1} + extbf{b}_f) ag{2}\\text{output gate}&: quad extbf{o}_t = sigma( extbf{W}_o extbf{x}_t + extbf{U}_o extbf{h}_{t-1} + extbf{b}_o) ag{3}\\text{new memory cell}&: quad ilde{ extbf{c}}_t = ext{tanh}( extbf{W}_c extbf{x}_t + extbf{U}_c extbf{h}_{t-1} + extbf{b}_c) ag{4}\\text{final memory cell}& : quad extbf{c}_t = extbf{f}_t odot extbf{c}_{t-1} + extbf{i}_t odot ilde{ extbf{c}}_t ag{5}\\text{final hidden state} &: quad extbf{h}_t= extbf{o}_t odot ext{tanh}( extbf{c}_t) ag{6}
end{align}
]
在式 ((5)?) 中 forget gate 和 input gate 是互补关系,因而比较冗余,GRU 将其合并为一个 update gate。同时 GRU 也不引入额外的记忆单元 (LSTM 中的 ( extbf{c}?)) ,而是直接在当前状态 ( extbf{h}_t?) 和历史状态 ( extbf{h}_{t-1}?) 之间建立线性依赖关系。
(largescr{GRU:})
[
ormalsize
egin{align}
ext{reset gate}&: quad extbf{r}_t = sigma( extbf{W}_r extbf{x}_t + extbf{U}_r extbf{h}_{t-1} + extbf{b}_r) ag{7} \\text{update gate}&: quad extbf{z}_t = sigma( extbf{W}_z extbf{x}_t + extbf{U}_z extbf{h}_{t-1} + extbf{b}_z) ag{8} \\text{new memory cell}&: quad ilde{ extbf{h}}_t = ext{tanh}( extbf{W}_h extbf{x}_t + extbf{r}_t odot ( extbf{U}_h extbf{h}_{t-1}) + extbf{b}_h) ag{9}\\text{final hidden state}&: quad extbf{h}_t = extbf{z}_t odot extbf{h}_{t-1} + (1 - extbf{z}_t) odot ilde{ extbf{h}}_t ag{10}
end{align}
]
$ ilde{ extbf{h}}_t $ 为时刻 t 的候选状态,( extbf{r}_t) 控制 $ ilde{ extbf{h}}_t $ 有多少依赖于上一时刻的状态 ( extbf{h}_{t-1}) ,如果 ( extbf{r}_t = 1) ,则式 ((9)) 与 Vanilla RNN 一致,对于短依赖的 GRU 单元,reset gate 通常会更新频繁。( extbf{z}_t) 控制当前的内部状态 ( extbf{h}_t) 中有多少来自于上一时刻的 ( extbf{h}_{t-1}) 。如果 ( extbf{z}_t = 1) ,则会每步都传递同样的信息,和当前输入 ( extbf{x}_t) 无关。
另一方面看,( extbf{r}_t) 与 LSTM 中的 ( extbf{o}_t) 角色有些类似,因为将上面的 ((6)) 式代入 ((4)) 式可以得到:
[ egin{align*} ilde{ extbf{c}}_t &= ext{tanh}( extbf{W}_c extbf{x}_t + extbf{U}_c extbf{h}_{t-1} + extbf{b}_c) \\ extbf{h}_t &= extbf{o}_t odot ext{tanh}( extbf{c}_t) end{align*} quad Longrightarrow quad ilde{ extbf{c}}_t = ext{tanh}( extbf{W}_c extbf{x}_t + extbf{U}_c left( extbf{o}_{t-1} odot ext{tanh}( extbf{c}_{t-1}) ight) + extbf{b}_c) ]
最后是 cs224n 中提出的 RNN 训练 tips:
/
以上是关于RNN - LSTM - GRU的主要内容,如果未能解决你的问题,请参考以下文章