tensorflow 基础学习十:RNN
Posted blackx
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow 基础学习十:RNN相关的知识,希望对你有一定的参考价值。
RNN网络的结构:
上图展示了一个简单的循环神经网络结构,在这个循环体中仅使用了一个类似全连接的神经网络结构。循环神经网络中的状态是通过一个向量来表示的,这个向量的维度称为循环神经网络隐藏层的大小,设其为h。从上图可以看出,循环体中的神经网络的输入包含两部分,一分部为上一时刻的状态,另一部分为当前时刻的输入样本。假设输入向量的维度为x,则上图中循环体的全连接层神经网络的输入大小为h+x,即将上一时刻的状态和当前时刻的输入拼接成一个大的向量作为循环体中神经网络的输入。因为该神经网络的输出为当前时刻的状态,于是输出层的节点个数也为h,循环体中的参数个数为(h+x)×h+h个。从上图可以看到,循环体中的神经网络输出不仅提供给下一时刻作为状态,同时也会提供给当前时刻的输出。为了将当前时刻的状态转换为最终的输出,循环神经网络还需要另外一个全连接神经网络来完成这个过程。不同时刻用于输出的全连接神经网络中的参数也是一致的。
下面展示一个循环神经网络前向传播的具体计算过程:
上图中,假设状态的维度为2,输入、输出的维度都为1,循环体中的全连接层中权重为:$w_{rnn}=\\begin{bmatrix} 0.1 & 0.2\\\\ 0.3 & 0.4\\\\ 0.5 & 0.6 \\end{bmatrix}$
偏置项的大小为brnn=[0.1,-0.1],用于输出的全连接层权重为:$w_{output}=\\begin{bmatrix} 1.0 \\\\ 2.0 \\end{bmatrix}$ ,偏置项大小为$b_{output}=0.1$,那么在$t_{0}$时刻,因为没有上一时刻,所以将状态初始化为[0,0],而当前的输入为1,所以拼接得到向量[0,0,1],通过循环体中的全连接层神经网络得到结果为:
$tanh\\left ( [0,0,1] \\times \\begin{bmatrix} 0.1 & 0.2 \\\\ 0.3 & 0.4 \\\\ 0.5 & 0.6\\end{bmatrix}+[0.1,-0.1] \\right )=tanh\\left ( [0.6,0.5]\\right )=[0.537,0.462]$
这个结果将作为下一时刻的输入状态,同时循环神经网络也会使用该状态生成输出,最终得到$t_{0}$的输出为:$[0.537,0.462] \\times \\begin{bmatrix} 1.0 \\\\ 2.0 \\end{bmatrix}+0.1=1.56$
使用$t_{0}$时刻的状态可以类似地推导得出$t_{1}$时刻的状态为[0.860,0.884],而$t_{1}$时刻的输出为2.73。在得到循环神经网络的前向传播结果后,可以和其他神经网络类似的定义损失函数。循环神经网络唯一的区别在于它每个时刻都有一个输出,所以循环神经网络的总损失为所有时刻上的损失函数的总和,以下代码实现了这个循环神经网络前向传播的过程。
import numpy as np X=[1,2] state=[0.0,0.0] # 分开定义不同输入部分的权重以方便计算 w_cell_state=np.array([[0.1,0.2],[0.3,0.4]]) w_cell_input=np.array([0.5,0.6]) b_cell=np.array([0.1,-0.1]) # 定义用于输出的全连接层参数。 w_output=np.array([[1.0],[2.0]]) b_output=0.1 # 按照时间顺序执行循环神经网络的前向传播过程。 for i in range(len(X)): # 计算循环体中的全连接层神经网络 before_activation=np.dot(state,w_cell_state)+X[i]*w_cell_input+b_cell state=np.tanh(before_activation) # 根据当前时刻状态计算最终输出 final_output=np.dot(state,w_output)+b_output # 输出每个时刻的信息 print(\'before activation: \',before_activation) print(\'state: \',state) print(\'output: \',final_output)
在实际应用中,如果序列过长会导致优化时出现梯度消散的问题,所以实际中一般会规定一个最大长度,当序列长度超过规定长度之后会对序列进行截断。
以上是关于tensorflow 基础学习十:RNN的主要内容,如果未能解决你的问题,请参考以下文章