tensorflow2.0 学习
Posted heze
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow2.0 学习相关的知识,希望对你有一定的参考价值。
用tensorflow2.0 版回顾了一下mnist的学习
代码如下,感觉这个版本下的mnist学习更简洁,更方便
关于tensorflow的基础知识,这里就不更新了,用到什么就到网上取搜索相关的知识
# encoding: utf-8 import numpy as np import tensorflow as tf import matplotlib.pyplot as plt #加载下载好的mnist数据库 60000张训练 10000张测试 每一张维度(28,28) path = r‘G:2019pythonmnist.npz‘ f = np.load(path) x_train, y_train = f[‘x_train‘], f[‘y_train‘] f.close() #预处理输入数据 x = 2*tf.convert_to_tensor(x_train, dtype = tf.float32)/255. - 1 x = tf.reshape(x, [-1, 28*28]) y = tf.convert_to_tensor(y_train, dtype=tf.int32) y = tf.one_hot(y, depth=10) #第一层输入256, 第二次输出128, 第三层输出10 #第一,二,三层参数w,b w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1)) #正态分布的一种 b1 = tf.Variable(tf.zeros([256])) w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1)) b2 = tf.Variable(tf.zeros([128])) w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1)) b3 = tf.Variable(tf.zeros([10])) #将60000组数据切分为600组,每组100个数据 train_db = tf.data.Dataset.from_tensor_slices((x, y)).batch(100) lr = 0.001 #学习率 losses = [] #储存每epoch的loss值,便于观察学习情况 for epoch in range(20): #一次性处理100组(x, y)数据 for step, (x, y) in enumerate(train_db): #遍历切分好的数据step:0->599 with tf.GradientTape() as tape: #向前传播第一,二,三层 h1 = x@w1 + tf.broadcast_to(b1, [x.shape[0], 256]) #可以直接写成 +b1 h1 = tf.nn.relu(h1) h2 = h1@w2 + b2 h2 = tf.nn.relu(h2) out = h2@w3 + b3 #计算mse loss = tf.square(y - out) loss = tf.reduce_mean(loss) #计算参数的梯度,tape.gradient为自动求导函数,loss为目标数据,目的使它越来越接近真实值 grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3]) #更新w,b w1.assign_sub(lr*grads[0]) #原地减去给定的值,实现参数的自我更新 b1.assign_sub(lr*grads[1]) w2.assign_sub(lr*grads[2]) b2.assign_sub(lr*grads[3]) w3.assign_sub(lr*grads[4]) b3.assign_sub(lr*grads[5]) #观察学习情况 if step%500 == 0: print(epoch, step, ‘loss:‘, float(loss)) #将每epoch的loss情况储存起来,最后观察 losses.append(float(loss)) plt.plot(losses, marker=‘s‘, label=‘training‘) plt.xlabel(‘Epoch‘) plt.ylabel(‘MSE‘) plt.legend()
plt.savefig(‘exam_mnist_forward.png‘) plt.show()
观察结果:
可由注释理解代码的含义!下一次更新mnist数据集训练的进阶!
以上是关于tensorflow2.0 学习的主要内容,如果未能解决你的问题,请参考以下文章