直观理解LSTM(长短时记忆网络)
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了直观理解LSTM(长短时记忆网络)相关的知识,希望对你有一定的参考价值。
参考技术A 长短时神经网络是一种特殊的递归神经网络,所谓递归神经网络就是网络能够解决时间序列问题的预测。所谓递归神经网络就是网络中具有循环结构。递归神经网路从某种程度来说和传统的神经网络并非完全不同。可以将递归神经网络想象成有多层相同网络结构的神经网络,每一层将信息传递给下一层(以下借鉴一些十分易懂的图片):上述是为了便于理解网络送展示的示意图,实际上网络结构只是上图左边的一小块。
普通的RNN没有办法解决需要长时记忆的功能。比如试图预测“I grew up in France… I speak fluent French.”中最后一个词。最近信息显示下一个词可能是一门语言的名字,但是如果我们想要缩小选择范围,我们需要包含“法国”的那段上下文,从前面的信息推断后面的单词。相关信息与预测位置的间隔很大是完全有可能的。然而RNNs并没有办法解决这种问题。
LSTM作为效果比较好的递归神经网络,拥有者对长时时间序列问题很好的解决能力。
LSTM也有这样的链式结构,但其重复模块内部的机构不同。具体如下:
之下说明一下内部四个网络的具体意义。
主要分为: 单元状态 + 门限。
单元状态:让信息以不变的方式向下流动,相当于一个传送带,但传送带上的东西会随着他通过每一个重复模块基于当时的输入有所增减。
门限:有能力向单元状态增加或者剔除信息的管理机构,相当于传送带上放东西或者拿走东西的那个人。在LSTM中由sigmoid函数和乘法加法来控制这个过程。
上图通过当前时间的输入和前一个时间的输出来通过sigmoid函数来使得单元状态乘以这个sigmoid函数的输出。若sigmoid函数输出0则该部分信息需要被遗忘,反之该部分信息继续在单元状态中继续传下去。
该门限功能是更新旧的单元状态。之前的遗忘门限层决定了遗忘或者添加哪些信息,由该门限层来执行实现。
最后,我们需要决定需要输出什么。这个输出将会建立在单元状态的基础上,但是个过滤版本。首先,我们运行一个sigmoid层来决定单元状态中哪些部分需要输出。然后我们将单元状态输入到tanh函数(将值转换成-1到1之间)中,然后乘以输出的sigmoid门限值,所以我们只输出了我们想要输出的那部分。
上面提到的是非常常规的LSTM网络,LSTM有许多不同的变种,下面来介绍几种。
就是使用耦合遗忘和输入门限。我们不单独决定遗忘哪些、添加哪些新信息,而是一起做出决定。在输入的时候才进行遗忘。在遗忘某些旧信息时才将新值添加到状态中。
它将遗忘和输入门限结合输入到单个“更新门限”中。同样还将单元状态和隐藏状态合并,并做出一些其他变化。所得模型比标准LSTM模型要简单,这种做法越来越流行。
Keras深度学习实战(29)——长短时记忆网络详解与实现
Keras深度学习实战(29)——长短时记忆网络详解与实现
0. 前言
长短时记忆网络 (Long Short Term Memory
, LSTM
),顾名思义是具有记忆长短期信息能力的神经网络,解决了循环神经网络 (Recurrent neural networks
, RNN
) 梯度爆炸/消失的问题,是建立在循环神经网络上的一种新型深度学习的时间序列模型,它具有高度的学习能力与模拟能力,具有记忆可持续性的特点,且能预测未来的任意步长。本文首先介绍了 RNN
模型的局限性,从而引入介绍长短时记忆网络 (Long Short Term Memory
, LSTM
) 的基本原理,最后通过实现 LSTM
进行深入了解。
1. RNN 的局限性
我们首先可视化 RNN
在考虑多个时刻做出预测时的情况,如下所示,随着时间的增加,早期输入的影响会逐渐降低:
更具体的,我们也可以通过公式得到相同的结论,例如我们需要计算第 5
个时刻网络的中间状态:
h 5 = W X 5 + U h 4 = W X 5 + U W X 4 + U 2 W X 3 + U 3 W X 2 + U 4 W X 1 h_5 = WX_5 + Uh_4 = WX_5 + UWX_4 + U_2WX_3 + U_3WX_2 + U_4WX_1 h5=WX5+Uh4=WX5+UWX4+U2WX3+U3WX2+U4WX1
可以看到,随着时间的增加,如果
U
>
1
U>1
U>1,则网络中间状态的值高度依赖于
X
1
X_1
X1;而如果
U
<
1
U<1
U<1,则网络中间状态值对
X
1
X_1
X1 的依赖就少得多。对 U
矩阵的依赖性还可能在 U
值很小时导致梯度消失,而在 U
值很高时会导致梯度爆炸。
当在预测单词时存在长期依赖性时,RNN
的这种现象将导致无法学习长期依赖关系的问题。为了解决这个问题,我们将引入介绍长短期记忆 (Long Short Term Memory
, LSTM
) 体系结构。
2. LSTM 模型架构详解
在有关传统 RNN
的问题中,我们了解了 RNN
对于长期依赖问题无济于事。例如,假设输入句子如下:
I live in China. I speak ____.
可以通过关键字 China
来推测以上空中应填充的单词,但该关键字与我们要预测的单词距离 3
个时间戳。如果关键字远离要预测的单词,则需要解决消失/爆炸梯度问题。
2.1 LSTM 架构
在本节中,我们将学习 LSTM
如何帮助克服RNN体系结构的长期依赖缺点,并构建一个简单示例,以便了解 LSTM
的各个组成部分。LSTM
架构示意图如下所示:
可以看到,虽然每个时刻 (h
) 的输入 X
和输出保持不变,但是在网络中使用不同的计算方式和激活函数。
2.2 LSTM 各组成部分与计算流程
接下来,我们详细介绍在一个时间戳内的计算过程:
在上图中,
x
x
x 和
h
h
h 表示输入层和 LSTM
的输出向量,内部状态向量 Memory
存储在单元状态
c
c
c 中也就是说,相较于基础 RNN
而言,LSTM
将内部状态向量 Memory
和输出分开为两个变量,利用输入门 (Input Gate
)、遗忘门 (Forget Gate
)和输出门 (Output Gate
) 三个门控来控制内部信息的流动。门控机制是一种控制网络中数据流通量的手段,可以较好地控制数据流通的流量程度。
2.2.1 遗忘门
需要忘记的内容是通过“遗忘门
”获得的,用于控制上一个时间戳的记忆
c
t
−
1
c_t-1
ct−1 对当前时间戳的影响,遗忘门的控制变量
f
t
f_t
ft 由:
f t = σ ( W x f x ( t ) + W h f h ( t − 1 ) + b f ) f_t=\\sigma(W_xfx^(t)+W_hfh^(t-1)+b_f) ft=σ(Wxfx(t)+Whfh(t−1)+bf)
sigmoid
激活函数使网络能够选择性地识别需要忘记的内容。在确定需要忘记的内容后,更新后的单元状态如下:
c t = ( c ( t − 1 ) ⊗ f ) c_t=(c_(t-1)\\otimes f) ct=(c(t−1)⊗f)
其中,
⊗
\\otimes
⊗ 表示逐元素乘法。例如,如果句子的输入序列是 I live in China. I speak ___
,可以根据输入的单词 China
来填充空格,在之后,我们可能并不再需要有关国家名称的信息。我们根据当前时间戳需要忘记的内容来更新单元状态。
2.2.2 输入门
输入门用于控制 LSTM
对输入的接受程度,根据当前时间戳提供的输入将其他信息添加到单元状态中,通过 tanh
激活函数获得更新,因此也称为更新门。首先通过对当前时间戳的输入和上一时间戳的输出作非线性变换:
i t = σ ( W x i x ( t ) + W h i h ( t − 1 ) + b i ) i_t=\\sigma(W_xix^(t)+W_hih^(t-1)+b_i) it=σ(Wxix(t)+Whih(t−1)+bi)
输入门中,输入更新计算方法如下:
g t = t a n h ( W x g x ( t ) + W h g h ( t − 1 ) + b g ) g_t=tanh(W_xgx^(t)+W_hgh^(t-1)+b_g) gt=tanh(Wxgx(t)+Whgh(t−1)+bg)
在当前时间戳中需要忘记某些信息,并在其中添加一些其他信息,此时单元状态将按以下方式更新:
c ( t ) = ( c ( t 1 − ) ⊙ f t ) ⊕ ( i t ⊙ g t ) c^(t)=(c^(t1-)\\odot f_t)\\oplus(i_t\\odot g_t) c(t)=(c(t1−)⊙ft)⊕(it⊙gt)
得到的新的状态向量 c ( t ) c^(t) c(t) 即为当前时间戳的状态向量。
2.2.3 输入门
最后一个门称为输出门,我们需要指定输入组合和单元状态的哪一部分需要传递到下一个时刻,输入组合包括当前时间戳的输入和前一时间戳的输出值:
o t = σ ( W x o x ( t ) + W h o h ( t − 1 ) + b o ) o_t=\\sigma(W_xox^(t)+W_hoh^(t-1)+b_o) ot=σ(Wxox(t)+Whoh(t−1)+bo)
最终的网络状态值表示如下:
h ( t ) = o t ⊙ t a n h ( c ( t ) ) h^(t)=o_t\\odot tanh(c^(t)) h(t)=ot⊙tanh(c(t))
这样,我们就可以利用 LSTM
中的各个门来有选择地识别需要存储在存储器中的信息,从而克服了 RNN
的局限性。
3. 从零开始实现 LSTM
在本小节中,我们通过使用一个简单示例来了解 LSTM
的工作原理。
3.1 LSTM 模型实现
(1) 对输入数据进行预处理,该示例所用输入数据与预处理过程与在 RNN 模型中使用的完全相同:
# 定义输入与输出数据
docs = ['this is','is an']
# define class labels
labels = ['an','example']
from collections import Counter
counts = Counter()
for i,review in enumerate(docs+labels):
counts.update(review.split())
words = sorted(counts, key=counts.get, reverse=True)
vocab_size=len(words)
word_to_int = word: i for i, word in enumerate(words, 1)
encoded_docs = []
for doc in docs:
encoded_docs.append([word_to_int[word] for word in doc.split()])
encoded_labels = []
for label in labels:
encoded_labels.append([word_to_int[word] for word 以上是关于直观理解LSTM(长短时记忆网络)的主要内容,如果未能解决你的问题,请参考以下文章
机器学习面试题:LSTM长短期记忆网络的理解?LSTM是怎么解决梯度消失的问题的?还有哪些其它的解决梯度消失或梯度爆炸的方法?