跟我学算法- tensorflow 实现RNN操作
Posted my-love-is-python
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了跟我学算法- tensorflow 实现RNN操作相关的知识,希望对你有一定的参考价值。
对一张图片实现rnn操作,主要是通过先得到一个整体,然后进行切分,得到的最后input结果输出*_w[‘out’] + _b[‘out‘] = 最终输出结果
第一步: 数据载入
import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tutorials.mnist import input_data import numpy as np import matplotlib.pyplot as plt print("Packages imported") mnist = input_data.read_data_sets("data/", one_hot=True) trainimgs, trainlabels, testimgs, testlabels = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels ntrain, ntest, dim, nclasses = trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
第二步: 初始化参数
diminput = 28 dimhidden = 128 # nclasses = 10 dimoutput = nclasses nsteps = 28 # w参数初始化 weights = { ‘hidden‘: tf.Variable(tf.random_normal([diminput, dimhidden])), ‘out‘: tf.Variable(tf.random_normal([dimhidden, dimoutput])) } # b参数初始化 biases = { ‘hidden‘: tf.Variable(tf.random_normal([dimhidden])), ‘out‘: tf.Variable(tf.random_normal([dimoutput])) }
第三步: 构建RNN函数
def _RNN(_X, _W, _b, _nsteps, _name): # 第一步:转换输入,输入_X是还有batchSize=5的5张28*28图片,需要将输入从 # [batchSize,nsteps,diminput]==>[nsteps,batchSize,diminput] _X = tf.transpose(_X, [1, 0, 2]) # 第二步:reshape _X为[nsteps*batchSize,diminput] _X = tf.reshape(_X, [-1, diminput]) # 第三步:input layer -> hidden layer _H = tf.matmul(_X, _W[‘hidden‘]) + _b[‘hidden‘] # 第四步:将数据切分为‘nsteps’个切片,第i个切片为第i个batch data # tensoflow >0.12 _Hsplit = tf.split(_H, _nsteps, 0) # tensoflow <0.12 _Hsplit = tf.split(0,_nsteps,_H) # 第五步:计算LSTM final output(_LSTM_O) 和 state(_LSTM_S) # _LSTM_O和_LSTM_S都有‘batchSize’个元素 # _LSTM_O用于预测输出 with tf.variable_scope(_name) as scope: # 表示公用一份变量 scope.reuse_variables() # forget_bias = 1.0不忘记数据 ###tensorflow <1.0 # lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden,forget_bias = 1.0) # _LSTM_O,_SLTM_S = tf.nn.rnn(lstm_cell,_Hsplit,dtype=tf.float32) ###tensorflow 1.0 lstm_cell = rnn.BasicLSTMCell(dimhidden) _LSTM_O, _LSTM_S = rnn.static_rnn(lstm_cell, _Hsplit, dtype=tf.float32) # 第六步:输出,需要最后一个RNN单元作为预测输出所以取_LSTM_O[-1] _O = tf.matmul(_LSTM_O[-1], _W[‘out‘]) + _b[‘out‘] return { ‘X‘: _X, ‘H‘: _H, ‘_Hsplit‘: _Hsplit, ‘LSTM_O‘: _LSTM_O, ‘LSTM_S‘: _LSTM_S, ‘O‘: _O }
第四步: 构建cost函数和准确度函数
learning_rate = 0.001 x = tf.placeholder("float", [None, nsteps, diminput]) y = tf.placeholder("float", [None, dimoutput]) myrnn = _RNN(x, weights, biases, nsteps, ‘basic‘) pred = myrnn[‘O‘] cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) # Adam accr = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)), tf.float32)) init = tf.global_variables_initializer() print("Network Ready!")
第五步: 训练模型, 降低cost值,优化参数
# 训练次数 training_epochs = 5 # 每次训练的图片数 batch_size = 16 # 循环的展示次数 display_step = 1 sess = tf.Session() sess.run(init) print("Start optimization") for epoch in range(training_epochs): avg_cost = 0. # total_batch = int(mnist.train.num_examples/batch_size) total_batch = 100 # Loop over all batches for i in range(total_batch): batch_xs, batch_ys = mnist.train.next_batch(batch_size) batch_xs = batch_xs.reshape((batch_size, nsteps, diminput)) # print(batch_xs.shape) # print(batch_ys.shape) # batch_ys = batch_ys.reshape((batch_size, dimoutput)) # Fit training using batch data feeds = {x: batch_xs, y: batch_ys} sess.run(optm, feed_dict=feeds) # Compute average loss avg_cost += sess.run(cost, feed_dict=feeds) / total_batch # Display logs per epoch step if epoch % display_step == 0: print("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost)) feeds = {x: batch_xs, y: batch_ys} train_acc = sess.run(accr, feed_dict=feeds) print(" Training accuracy: %.3f" % (train_acc)) testimgs = testimgs.reshape((ntest, nsteps, diminput)) feeds = {x: testimgs, y: testlabels} test_acc = sess.run(accr, feed_dict=feeds) print(" Test accuracy: %.3f" % (test_acc)) print("Optimization Finished.")
以上是关于跟我学算法- tensorflow 实现RNN操作的主要内容,如果未能解决你的问题,请参考以下文章
跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()