Keras深度学习实战(29)——长短时记忆网络详解与实现
Posted 盼小辉丶
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Keras深度学习实战(29)——长短时记忆网络详解与实现相关的知识,希望对你有一定的参考价值。
Keras深度学习实战(29)——长短时记忆网络详解与实现
0. 前言
长短时记忆网络 (Long Short Term Memory
, LSTM
),顾名思义是具有记忆长短期信息能力的神经网络,解决了循环神经网络 (Recurrent neural networks
, RNN
) 梯度爆炸/消失的问题,是建立在循环神经网络上的一种新型深度学习的时间序列模型,它具有高度的学习能力与模拟能力,具有记忆可持续性的特点,且能预测未来的任意步长。本文首先介绍了 RNN
模型的局限性,从而引入介绍长短时记忆网络 (Long Short Term Memory
, LSTM
) 的基本原理,最后通过实现 LSTM
进行深入了解。
1. RNN 的局限性
我们首先可视化 RNN
在考虑多个时刻做出预测时的情况,如下所示,随着时间的增加,早期输入的影响会逐渐降低:
更具体的,我们也可以通过公式得到相同的结论,例如我们需要计算第 5
个时刻网络的中间状态:
h 5 = W X 5 + U h 4 = W X 5 + U W X 4 + U 2 W X 3 + U 3 W X 2 + U 4 W X 1 h_5 = WX_5 + Uh_4 = WX_5 + UWX_4 + U_2WX_3 + U_3WX_2 + U_4WX_1 h5=WX5+Uh4=WX5+UWX4+U2WX3+U3WX2+U4WX1
可以看到,随着时间的增加,如果
U
>
1
U>1
U>1,则网络中间状态的值高度依赖于
X
1
X_1
X1;而如果
U
<
1
U<1
U<1,则网络中间状态值对
X
1
X_1
X1 的依赖就少得多。对 U
矩阵的依赖性还可能在 U
值很小时导致梯度消失,而在 U
值很高时会导致梯度爆炸。
当在预测单词时存在长期依赖性时,RNN
的这种现象将导致无法学习长期依赖关系的问题。为了解决这个问题,我们将引入介绍长短期记忆 (Long Short Term Memory
, LSTM
) 体系结构。
2. LSTM 模型架构详解
在有关传统 RNN
的问题中,我们了解了 RNN
对于长期依赖问题无济于事。例如,假设输入句子如下:
I live in China. I speak ____.
可以通过关键字 China
来推测以上空中应填充的单词,但该关键字与我们要预测的单词距离 3
个时间戳。如果关键字远离要预测的单词,则需要解决消失/爆炸梯度问题。
2.1 LSTM 架构
在本节中,我们将学习 LSTM
如何帮助克服RNN体系结构的长期依赖缺点,并构建一个简单示例,以便了解 LSTM
的各个组成部分。LSTM
架构示意图如下所示:
可以看到,虽然每个时刻 (h
) 的输入 X
和输出保持不变,但是在网络中使用不同的计算方式和激活函数。
2.2 LSTM 各组成部分与计算流程
接下来,我们详细介绍在一个时间戳内的计算过程:
在上图中,
x
x
x 和
h
h
h 表示输入层和 LSTM
的输出向量,内部状态向量 Memory
存储在单元状态
c
c
c 中也就是说,相较于基础 RNN
而言,LSTM
将内部状态向量 Memory
和输出分开为两个变量,利用输入门 (Input Gate
)、遗忘门 (Forget Gate
)和输出门 (Output Gate
) 三个门控来控制内部信息的流动。门控机制是一种控制网络中数据流通量的手段,可以较好地控制数据流通的流量程度。
2.2.1 遗忘门
需要忘记的内容是通过“遗忘门
”获得的,用于控制上一个时间戳的记忆
c
t
−
1
c_t-1
ct−1 对当前时间戳的影响,遗忘门的控制变量
f
t
f_t
ft 由:
f t = σ ( W x f x ( t ) + W h f h ( t − 1 ) + b f ) f_t=\\sigma(W_xfx^(t)+W_hfh^(t-1)+b_f) ft=σ(Wxfx(t)+Whfh(t−1)+bf)
sigmoid
激活函数使网络能够选择性地识别需要忘记的内容。在确定需要忘记的内容后,更新后的单元状态如下:
c t = ( c ( t − 1 ) ⊗ f ) c_t=(c_(t-1)\\otimes f) ct=(c(t−1)⊗f)
其中,
⊗
\\otimes
⊗ 表示逐元素乘法。例如,如果句子的输入序列是 I live in China. I speak ___
,可以根据输入的单词 China
来填充空格,在之后,我们可能并不再需要有关国家名称的信息。我们根据当前时间戳需要忘记的内容来更新单元状态。
2.2.2 输入门
输入门用于控制 LSTM
对输入的接受程度,根据当前时间戳提供的输入将其他信息添加到单元状态中,通过 tanh
激活函数获得更新,因此也称为更新门。首先通过对当前时间戳的输入和上一时间戳的输出作非线性变换:
i t = σ ( W x i x ( t ) + W h i h ( t − 1 ) + b i ) i_t=\\sigma(W_xix^(t)+W_hih^(t-1)+b_i) it=σ(Wxix(t)+Whih(t−1)+bi)
输入门中,输入更新计算方法如下:
g t = t a n h ( W x g x ( t ) + W h g h ( t − 1 ) + b g ) g_t=tanh(W_xgx^(t)+W_hgh^(t-1)+b_g) gt=tanh(Wxgx(t)+Whgh(t−1)+bg)
在当前时间戳中需要忘记某些信息,并在其中添加一些其他信息,此时单元状态将按以下方式更新:
c ( t ) = ( c ( t 1 − ) ⊙ f t ) ⊕ ( i t ⊙ g t ) c^(t)=(c^(t1-)\\odot f_t)\\oplus(i_t\\odot g_t) c(t)=(c(t1−)⊙ft)⊕(it⊙gt)
得到的新的状态向量 c ( t ) c^(t) c(t) 即为当前时间戳的状态向量。
2.2.3 输入门
最后一个门称为输出门,我们需要指定输入组合和单元状态的哪一部分需要传递到下一个时刻,输入组合包括当前时间戳的输入和前一时间戳的输出值:
o t = σ ( W x o x ( t ) + W h o h ( t − 1 ) + b o ) o_t=\\sigma(W_xox^(t)+W_hoh^(t-1)+b_o) ot=σ(Wxox(t)+Whoh(t−1)+bo)
最终的网络状态值表示如下:
h ( t ) = o t ⊙ t a n h ( c ( t ) ) h^(t)=o_t\\odot tanh(c^(t)) h(t)=ot⊙tanh(c(t))
这样,我们就可以利用 LSTM
中的各个门来有选择地识别需要存储在存储器中的信息,从而克服了 RNN
的局限性。
3. 从零开始实现 LSTM
在本小节中,我们通过使用一个简单示例来了解 LSTM
的工作原理。
3.1 LSTM 模型实现
(1) 对输入数据进行预处理,该示例所用输入数据与预处理过程与在 RNN 模型中使用的完全相同:
# 定义输入与输出数据
docs = ['this is','is an']
# define class labels
labels = ['an','example']
from collections import Counter
counts = Counter()
for i,review in enumerate(docs+labels):
counts.update(review.split())
words = sorted(counts, key=counts.get, reverse=True)
vocab_size=len(words)
word_to_int = word: i for i, word in enumerate(words, 1)
encoded_docs = []
for doc in docs:
encoded_docs.append([word_to_int[word] for word in doc.split()])
encoded_labels = []
for label in labels:
encoded_labels.append([word_to_int[word] for word 以上是关于Keras深度学习实战(29)——长短时记忆网络详解与实现的主要内容,如果未能解决你的问题,请参考以下文章
Keras深度学习实战——使用长短时记忆网络构建情感分析模型
Keras深度学习实战(33)——基于LSTM的序列预测模型