tf.nn.rnn_cell.MultiRNNCell
Posted by yanshw
- Class tf.contrib.rnn.MultiRNNCell
- Class tf.nn.rnn_cell.MultiRNNCell
Builds a multi-layer (stacked) recurrent neural network from a list of cells.
__init__(cells, state_is_tuple=True)
cells: a list of RNN cells, stacked from bottom to top.
state_is_tuple: if True, the cell state Ct and the hidden state ht are kept separate in a tuple; the accepted and returned states are n-tuples, where n = len(cells). If False, the states are concatenated along the column axis. The False behavior is deprecated.
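As a quick check of this tuple layout, the minimal sketch below (assuming TensorFlow 1.x; the layer sizes are illustrative, not from the original post) builds a two-layer stack and inspects its state_size:

# A minimal sketch, assuming TensorFlow 1.x.
import tensorflow as tf

cells = [tf.nn.rnn_cell.BasicLSTMCell(n) for n in [100, 200]]
stack = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

# With state_is_tuple=True, state_size is an n-tuple (n = len(cells)),
# one LSTMStateTuple(c, h) per layer:
print(stack.state_size)
# expected: (LSTMStateTuple(c=100, h=100), LSTMStateTuple(c=200, h=200))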
BasicLSTMCell with a single hidden layer
BasicLSTMCell with multiple hidden layers (stacked)
Code example
# encoding: utf-8
import tensorflow as tf

batch_size = 10
depth = 128

inputs = tf.Variable(tf.random_normal([batch_size, depth]))

# Initial (c, h) state for each of the three layers.
previous_state0 = (tf.random_normal([batch_size, 100]), tf.random_normal([batch_size, 100]))
previous_state1 = (tf.random_normal([batch_size, 200]), tf.random_normal([batch_size, 200]))
previous_state2 = (tf.random_normal([batch_size, 300]), tf.random_normal([batch_size, 300]))

num_units = [100, 200, 300]
print(inputs)

cells = [tf.nn.rnn_cell.BasicLSTMCell(num_unit) for num_unit in num_units]
mul_cells = tf.nn.rnn_cell.MultiRNNCell(cells)

outputs, states = mul_cells(inputs, (previous_state0, previous_state1, previous_state2))

print(outputs.shape)      # (10, 300)
print(states[0])          # first-layer LSTM state
print(states[1])          # second-layer LSTM state
print(states[2])          # third-layer LSTM state
print(states[0].h.shape)  # first-layer h state, (10, 100)
print(states[0].c.shape)  # first-layer c state, (10, 100)
print(states[1].h.shape)  # second-layer h state, (10, 200)
Output
(10, 300)
LSTMStateTuple(c=<tf.Tensor 'multi_rnn_cell/cell_0/basic_lstm_cell/Add_1:0' shape=(10, 100) dtype=float32>, h=<tf.Tensor 'multi_rnn_cell/cell_0/basic_lstm_cell/Mul_2:0' shape=(10, 100) dtype=float32>)
LSTMStateTuple(c=<tf.Tensor 'multi_rnn_cell/cell_1/basic_lstm_cell/Add_1:0' shape=(10, 200) dtype=float32>, h=<tf.Tensor 'multi_rnn_cell/cell_1/basic_lstm_cell/Mul_2:0' shape=(10, 200) dtype=float32>)
LSTMStateTuple(c=<tf.Tensor 'multi_rnn_cell/cell_2/basic_lstm_cell/Add_1:0' shape=(10, 300) dtype=float32>, h=<tf.Tensor 'multi_rnn_cell/cell_2/basic_lstm_cell/Mul_2:0' shape=(10, 300) dtype=float32>)
(10, 100)
(10, 100)
(10, 200)
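The example above calls the stacked cell directly, which computes a single time step. In practice a MultiRNNCell is usually unrolled over a whole sequence with tf.nn.dynamic_rnn. A hedged sketch, again assuming TensorFlow 1.x; the names seq_inputs and time_steps are illustrative, not from the original post:

# A sketch, assuming TensorFlow 1.x; seq_inputs/time_steps are made up here.
import tensorflow as tf

batch_size, time_steps, depth = 10, 5, 128
seq_inputs = tf.random_normal([batch_size, time_steps, depth])

cells = [tf.nn.rnn_cell.BasicLSTMCell(n) for n in [100, 200, 300]]
mul_cells = tf.nn.rnn_cell.MultiRNNCell(cells)

# zero_state builds the n-tuple of LSTMStateTuples for the whole stack.
init_state = mul_cells.zero_state(batch_size, tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(
    mul_cells, seq_inputs, initial_state=init_state)

print(outputs.shape)           # (10, 5, 300): top-layer h at every step
print(final_state[0].c.shape)  # (10, 100): layer-0 cell state at the last step

Note that outputs contains only the top layer's hidden state at each step, while final_state still exposes the per-layer LSTMStateTuples, exactly as in the single-step example.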