Differences between tf.contrib.rnn.static_rnn and tf.nn.dynamic_rnn
Original post: https://blog.csdn.net/u014365862/article/details/78238807
MachineLP's GitHub (follows welcome): https://github.com/MachineLP
My GitHub: https://github.com/MachineLP/train_cnn-rnn-attention is a framework I built myself; the included models are vgg (vgg16, vgg19), resnet (resnet_v2_50, resnet_v2_101, resnet_v2_152), inception_v4, inception_resnet_v2, etc.
chunk_size = 256
chunk_n = 160
rnn_size = 256
num_layers = 2
n_output_layer = MAX_CAPTCHA * CHAR_SET_LEN  # size of the output layer
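The snippets below assume TensorFlow 1.x (tf.contrib was removed in TF 2.x), and MAX_CAPTCHA and CHAR_SET_LEN are defined elsewhere in the original captcha project, not in this excerpt. A hypothetical preamble, purely so the constants above are runnable:

import tensorflow as tf  # TensorFlow 1.x; tf.contrib no longer exists in 2.x

# hypothetical values, for illustration only: e.g. a 4-character captcha
# over digits and lowercase letters (10 + 26 = 36 classes per position)
MAX_CAPTCHA = 4
CHAR_SET_LEN = 36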
Single-layer RNN:
tf.contrib.rnn.static_rnn:
Input: [n_steps, batch, input], passed as a Python list of n_steps tensors, each of shape [batch, input]
Output: [n_steps, batch, n_hidden], likewise a list of per-step [batch, n_hidden] tensors
(A sketch of adding dropout inside the RNN follows the code below.)
def recurrent_neural_network(data):
    # [batch, chunk_n * chunk_size] -> [batch, chunk_n, chunk_size]
    data = tf.reshape(data, [-1, chunk_n, chunk_size])
    # -> [chunk_n, batch, chunk_size] (time-major)
    data = tf.transpose(data, [1, 0, 2])
    # -> [chunk_n * batch, chunk_size]
    data = tf.reshape(data, [-1, chunk_size])
    # -> list of chunk_n tensors, each [batch, chunk_size]
    data = tf.split(data, chunk_n)

    # RNN only
    layer = {'w_': tf.Variable(tf.random_normal([rnn_size, n_output_layer])),
             'b_': tf.Variable(tf.random_normal([n_output_layer]))}

    lstm_cell = tf.contrib.rnn.BasicLSTMCell(rnn_size)
    outputs, status = tf.contrib.rnn.static_rnn(lstm_cell, data, dtype=tf.float32)
    # outputs = tf.transpose(outputs, [1, 0, 2])
    # outputs = tf.reshape(outputs, [-1, chunk_n * rnn_size])

    # outputs[-1] is the last time step: [batch, rnn_size]
    output = tf.add(tf.matmul(outputs[-1], layer['w_']), layer['b_'])

    return output
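Note that the static_rnn version above never actually applies the dropout mentioned earlier. A minimal sketch of one way to add it, assuming a hypothetical keep_prob placeholder fed at session time (e.g. 0.5 for training, 1.0 for evaluation): wrap the cell with tf.contrib.rnn.DropoutWrapper before calling static_rnn.

# hypothetical placeholder; feed 0.5 while training, 1.0 while evaluating
keep_prob = tf.placeholder(tf.float32)

lstm_cell = tf.contrib.rnn.BasicLSTMCell(rnn_size)
# apply dropout to the cell's output at every time step
lstm_cell = tf.contrib.rnn.DropoutWrapper(lstm_cell, output_keep_prob=keep_prob)
outputs, status = tf.contrib.rnn.static_rnn(lstm_cell, data, dtype=tf.float32)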
Multi-layer RNN:
tf.nn.dynamic_rnn:
Input: [batch, n_steps, input]
Output: [batch, n_steps, n_hidden]
So we need tf.transpose(outputs, [1, 0, 2]) to move the time axis to the front, after which outputs[-1] is the output of the last step.
def recurrent_neural_network(data):
    # [batch, chunk_n, input]
    data = tf.reshape(data, [-1, chunk_n, chunk_size])
    # data = tf.transpose(data, [1, 0, 2])
    # data = tf.reshape(data, [-1, chunk_size])
    # data = tf.split(data, chunk_n)

    # RNN only
    layer = {'w_': tf.Variable(tf.random_normal([rnn_size, n_output_layer])),
             'b_': tf.Variable(tf.random_normal([n_output_layer]))}

    # option 1: single-layer static_rnn
    # lstm_cell1 = tf.contrib.rnn.BasicLSTMCell(rnn_size)
    # outputs1, status1 = tf.contrib.rnn.static_rnn(lstm_cell1, data, dtype=tf.float32)

    def lstm_cell():
        return tf.contrib.rnn.LSTMCell(rnn_size)

    def attn_cell():
        # dropout on each cell's output; keep_prob must be defined and fed elsewhere
        return tf.contrib.rnn.DropoutWrapper(lstm_cell(), output_keep_prob=keep_prob)

    # stack = tf.contrib.rnn.MultiRNNCell([attn_cell() for _ in range(0, num_layers)], state_is_tuple=True)
    stack = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(0, num_layers)], state_is_tuple=True)

    # outputs, _ = tf.nn.dynamic_rnn(stack, data, seq_len, dtype=tf.float32)
    outputs, _ = tf.nn.dynamic_rnn(stack, data, dtype=tf.float32)

    # [batch, chunk_n, rnn_size] -> [chunk_n, batch, rnn_size]
    outputs = tf.transpose(outputs, (1, 0, 2))

    # outputs[-1] is the last time step: [batch, rnn_size]
    output = tf.add(tf.matmul(outputs[-1], layer['w_']), layer['b_'])

    return output
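The closing transpose exists only to reach the last time step; slicing the batch-major output directly yields the same tensor. (tf.nn.dynamic_rnn can also be called with time_major=True, in which case it accepts and returns [n_steps, batch, ...] tensors like static_rnn and no transpose is needed.) A small sketch of the equivalence, under the shapes used above:

# raw dynamic_rnn output, before any transpose: [batch, chunk_n, rnn_size]
outputs, _ = tf.nn.dynamic_rnn(stack, data, dtype=tf.float32)
last_a = tf.transpose(outputs, (1, 0, 2))[-1]  # [batch, rnn_size]
last_b = outputs[:, -1, :]                     # same values, no transpose needed

Either form can feed the final fully connected layer.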