TensorFlow2 Hands-On: Training on the MNIST Dataset, Part 2
Posted by 我是小白呀
Overview
MNIST contains the handwritten digits 0-9, with 60,000 training samples and 10,000 test samples. Each sample is a single-channel 28*28 grayscale image.
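If you want to check the raw data before building the input pipeline, a minimal sketch like the one below (using only tf.keras.datasets.mnist, the same loader used later in this post) prints the shapes and pixel range:

import tensorflow as tf

# Keras downloads and caches MNIST on first use
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

print(X_train.shape)                 # (60000, 28, 28) uint8 images
print(X_test.shape)                  # (10000, 28, 28)
print(y_train.shape)                 # (60000,) integer labels 0~9
print(X_train.min(), X_train.max())  # 0 255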
The get_data function
def get_data():
    """
    Load the data
    :return: the batched training and test sets
    """
    # Load the raw data
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
    # Build the training pipeline
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(60000, seed=0)
    train_db = train_db.batch(batch_size).map(pre_processing)
    # Build the test pipeline
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test)).shuffle(10000, seed=0)
    test_db = test_db.batch(batch_size).map(pre_processing)
    # Return both datasets
    return train_db, test_db
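As a quick sanity check (a hypothetical snippet, assuming batch_size, pre_processing and get_data are defined as in this post), you can pull a single batch from the returned dataset and look at its shape:

train_db, test_db = get_data()
for x, y in train_db.take(1):
    print(x.shape)  # (256, 784): flattened images, with batch_size = 256
    print(y.shape)  # (256, 10): one-hot labels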
The pre_processing function
def pre_processing(x, y):
    """
    Preprocess the data
    :param x: features (images)
    :param y: labels
    :return: the processed x, y
    """
    # Convert x: scale pixels to [0, 1] and flatten to 784-dim vectors
    x = tf.cast(x, tf.float32) / 255
    x = tf.reshape(x, [-1, 784])
    # Convert y: cast to int and one-hot encode into 10 classes
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y
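To see what these two transforms do to a single value, here is a small illustrative sketch; the pixel value and label are made up purely for demonstration:

import tensorflow as tf

pixel = tf.constant(255, dtype=tf.uint8)  # a fake pixel value
label = tf.constant(3)                    # a fake class label

print(tf.cast(pixel, tf.float32) / 255)   # 1.0 -> pixel values are scaled into [0, 1]
print(tf.one_hot(label, depth=10))        # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]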
The train function
def train(train_db):
    """
    Train for one pass over the data
    :param train_db: the batched training set
    :return: None
    """
    for step, (x, y) in enumerate(train_db):
        with tf.GradientTape() as tape:
            # Forward pass: get the model's raw outputs (logits)
            logits = model(x)
            # Compute MSE (for monitoring only, not used for training)
            MSE = tf.reduce_mean(tf.losses.MSE(y, logits))
            # Compute the cross-entropy loss, summed over the batch
            Cross_Entropy = tf.losses.categorical_crossentropy(y, logits, from_logits=True)
            Cross_Entropy = tf.reduce_sum(Cross_Entropy)
        # Compute gradients of the loss w.r.t. the trainable parameters
        grads = tape.gradient(Cross_Entropy, model.trainable_variables)
        # Update the parameters
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Print the losses every 100 batches
        if step % 100 == 0:
            print("step:", step, "Cross_Entropy:", float(Cross_Entropy), "MSE:", float(MSE))
The test function
def test(epoch, test_db):
    """
    Evaluate the model
    :param epoch: current epoch number
    :param test_db: the batched test set
    :return: None
    """
    total_correct = 0  # number of correct predictions
    total_num = 0      # total number of samples
    for x, y in test_db:
        # Forward pass: get the model's outputs
        logits = model(x)
        # Predicted class: index of the largest logit
        pred = tf.argmax(logits, axis=1)
        # Convert the one-hot labels back to class indices
        y = tf.argmax(y, axis=1)
        # Count correct predictions in this batch
        correct = tf.equal(pred, y)
        correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))
        # Accumulate correct count and sample count
        total_correct += int(correct)
        total_num += x.shape[0]
    # Compute the accuracy
    accuracy = total_correct / total_num
    # Print the result
    print("epoch:", epoch, "Accuracy:", accuracy * 100, "%")
The main function
def main():
    """
    Main entry point
    :return: None
    """
    # Load the data
    train_db, test_db = get_data()
    # Loop over epochs (note: range(1, 20) runs 19 epochs, matching the output below)
    for epoch in range(1, iteration_num):
        train(train_db)
        test(epoch, test_db)
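For comparison, the same network could also be trained with the high-level compile/fit API instead of a hand-written GradientTape loop. The following is only a sketch of that alternative (it is not what this post does), assuming model, train_db, test_db, learning_rate and iteration_num from the complete code below:

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    # from_logits=True because the last Dense layer outputs raw logits
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.fit(train_db, validation_data=test_db, epochs=iteration_num)

Writing the loop by hand, as this post does, mainly makes it easier to see exactly where the loss and gradients come from.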
Complete code
import tensorflow as tf

# Hyperparameters
batch_size = 256       # number of samples per batch
learning_rate = 0.001  # learning rate
iteration_num = 20     # number of epochs

# Optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Model: a stack of fully connected layers; the last layer outputs raw logits
model = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    tf.keras.layers.Dense(64, activation=tf.nn.relu),
    tf.keras.layers.Dense(32, activation=tf.nn.relu),
    tf.keras.layers.Dense(10)
])

# Build the model and print its summary
# (model.summary() prints the table itself and returns None, hence the "None" in the output)
model.build(input_shape=[None, 28 * 28])
print(model.summary())
def pre_processing(x, y):
    """
    Preprocess the data
    :param x: features (images)
    :param y: labels
    :return: the processed x, y
    """
    # Convert x: scale pixels to [0, 1] and flatten to 784-dim vectors
    x = tf.cast(x, tf.float32) / 255
    x = tf.reshape(x, [-1, 784])
    # Convert y: cast to int and one-hot encode into 10 classes
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y
def get_data():
    """
    Load the data
    :return: the batched training and test sets
    """
    # Load the raw data
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
    # Build the training pipeline
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(60000, seed=0)
    train_db = train_db.batch(batch_size).map(pre_processing)
    # Build the test pipeline
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test)).shuffle(10000, seed=0)
    test_db = test_db.batch(batch_size).map(pre_processing)
    # Return both datasets
    return train_db, test_db
def train(train_db):
    """
    Train for one pass over the data
    :param train_db: the batched training set
    :return: None
    """
    for step, (x, y) in enumerate(train_db):
        with tf.GradientTape() as tape:
            # Forward pass: get the model's raw outputs (logits)
            logits = model(x)
            # Compute MSE (for monitoring only, not used for training)
            MSE = tf.reduce_mean(tf.losses.MSE(y, logits))
            # Compute the cross-entropy loss, summed over the batch
            Cross_Entropy = tf.losses.categorical_crossentropy(y, logits, from_logits=True)
            Cross_Entropy = tf.reduce_sum(Cross_Entropy)
        # Compute gradients of the loss w.r.t. the trainable parameters
        grads = tape.gradient(Cross_Entropy, model.trainable_variables)
        # Update the parameters
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Print the losses every 100 batches
        if step % 100 == 0:
            print("step:", step, "Cross_Entropy:", float(Cross_Entropy), "MSE:", float(MSE))
def test(epoch, test_db):
    """
    Evaluate the model
    :param epoch: current epoch number
    :param test_db: the batched test set
    :return: None
    """
    total_correct = 0  # number of correct predictions
    total_num = 0      # total number of samples
    for x, y in test_db:
        # Forward pass: get the model's outputs
        logits = model(x)
        # Predicted class: index of the largest logit
        pred = tf.argmax(logits, axis=1)
        # Convert the one-hot labels back to class indices
        y = tf.argmax(y, axis=1)
        # Count correct predictions in this batch
        correct = tf.equal(pred, y)
        correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))
        # Accumulate correct count and sample count
        total_correct += int(correct)
        total_num += x.shape[0]
    # Compute the accuracy
    accuracy = total_correct / total_num
    # Print the result
    print("epoch:", epoch, "Accuracy:", accuracy * 100, "%")
def main():
    """
    Main entry point
    :return: None
    """
    # Load the data
    train_db, test_db = get_data()
    # Loop over epochs (note: range(1, 20) runs 19 epochs, matching the output below)
    for epoch in range(1, iteration_num):
        train(train_db)
        test(epoch, test_db)

if __name__ == "__main__":
    main()
Output:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 256) 200960
_________________________________________________________________
dense_1 (Dense) (None, 128) 32896
_________________________________________________________________
dense_2 (Dense) (None, 64) 8256
_________________________________________________________________
dense_3 (Dense) (None, 32) 2080
_________________________________________________________________
dense_4 (Dense) (None, 10) 330
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
None
step: 0 Cross_Entropy: 589.905029296875 MSE: 0.1215171366930008
step: 100 Cross_Entropy: 61.73141098022461 MSE: 16.245494842529297
step: 200 Cross_Entropy: 46.609832763671875 MSE: 17.865381240844727
epoch: 1 Accuracy: 95.44 %
step: 0 Cross_Entropy: 47.514892578125 MSE: 20.183507919311523
step: 100 Cross_Entropy: 35.65019226074219 MSE: 18.90221405029297
step: 200 Cross_Entropy: 33.837703704833984 MSE: 16.84846305847168
epoch: 2 Accuracy: 96.61 %
step: 0 Cross_Entropy: 17.38262939453125 MSE: 18.48729133605957
step: 100 Cross_Entropy: 27.96572494506836 MSE: 21.008562088012695
step: 200 Cross_Entropy: 27.25030517578125 MSE: 21.703704833984375
epoch: 3 Accuracy: 97.22 %
step: 0 Cross_Entropy: 21.492198944091797 MSE: 22.19614028930664
step: 100 Cross_Entropy: 11.623129844665527 MSE: 27.867923736572266
step: 200 Cross_Entropy: 7.261983394622803 MSE: 25.641494750976562
epoch: 4 Accuracy: 97.41 %
step: 0 Cross_Entropy: 11.380800247192383 MSE: 26.688203811645508
step: 100 Cross_Entropy: 10.21794319152832 MSE: 27.864110946655273
step: 200 Cross_Entropy: 14.44814682006836 MSE: 31.53815460205078
epoch: 5 Accuracy: 97.18 %
step: 0 Cross_Entropy: 5.241445541381836 MSE: 30.080406188964844
step: 100 Cross_Entropy: 3.1642959117889404 MSE: 33.59324645996094
step: 200 Cross_Entropy: 9.680063247680664 MSE: 34.96605682373047
epoch: 6 Accuracy: 97.95 %
step: 0 Cross_Entropy: 11.292088508605957 MSE: 36.604915618896484
step: 100 Cross_Entropy: 4.599205017089844 MSE: 38.455101013183594
step: 200 Cross_Entropy: 13.383275032043457 MSE: 41.19858932495117
epoch: 7 Accuracy: 97.65 %
step: 0 Cross_Entropy: 6.985865592956543 MSE: 33.687713623046875
step: 100 Cross_Entropy: 5.281797409057617 MSE: 44.13557815551758
step: 200 Cross_Entropy: 6.665032863616943 MSE: 44.898216247558594
epoch: 8 Accuracy: 97.72 %
step: 0 Cross_Entropy: 1.8101396560668945 MSE: 42.560211181640625
step: 100 Cross_Entropy: 4.517214298248291 MSE: 46.41954803466797
step: 200 Cross_Entropy: 5.113927364349365 MSE: 47.692081451416016
epoch: 9 Accuracy: 97.84 %
step: 0 Cross_Entropy: 5.45690393447876 MSE: 44.61886978149414
step: 100 Cross_Entropy: 6.035201549530029 MSE: 51.11096954345703
step: 200 Cross_Entropy: 7.727978229522705 MSE: 50.56428527832031
epoch: 10 Accuracy: 97.78 %
step: 0 Cross_Entropy: 6.566008567810059 MSE: 53.64844512939453
step: 100 Cross_Entropy: 12.636188507080078 MSE: 59.566192626953125
step: 200 Cross_Entropy: 0.9305715560913086 MSE: 63.96886444091797
epoch: 11 Accuracy: 97.68 %
step: 0 Cross_Entropy: 3.799677610397339 MSE: 57.57715606689453
step: 100 Cross_Entropy: 7.782512664794922 MSE: 63.94820785522461
step: 200 Cross_Entropy: 6.952803611755371 MSE: 59.19414138793945
epoch: 12 Accuracy: 97.85000000000001 %
step: 0 Cross_Entropy: 1.316650152206421 MSE: 57.405555725097656
step: 100 Cross_Entropy: 3.3630568981170654 MSE: 65.93612670898438
step: 200 Cross_Entropy: 2.8188657760620117 MSE: 63.6553955078125
epoch: 13 Accuracy: 97.71 %
step: 0 Cross_Entropy: 1.0694936513900757 MSE: 73.58941650390625
step: 100 Cross_Entropy: 1.1532164812088013 MSE: 72.19602966308594
step: 200 Cross_Entropy: 4.054533958435059 MSE: 66.22490692138672
epoch: 14 Accuracy: 97.69 %
step: 0 Cross_Entropy: 0.5501946806907654 MSE: 67.73658752441406
step: 100 Cross_Entropy: 1.6239964962005615 MSE: 75.26908874511719
step: 200 Cross_Entropy: 0.25266233086586 MSE: 79.37750244140625
epoch: 15 Accuracy: 97.96000000000001 %
step: 0 Cross_Entropy: 0.5946800112724304 MSE: 78.45301818847656
step: 100 Cross_Entropy: 3.876664638519287 MSE: 86.45103454589844
step: 200 Cross_Entropy: 13.129545211791992 MSE: 70.39665222167969
epoch: 16 Accuracy: 97.67 %
step: 0 Cross_Entropy: 4.019548416137695 MSE: 66.26248168945312
step: 100 Cross_Entropy: 0.7121025323867798 MSE: 67.56402587890625
step: 200 Cross_Entropy: 3.106649875640869 MSE: 77.95216369628906
epoch: 17 Accuracy: 97.71 %
step: 0 Cross_Entropy: 0.797190248966217 MSE: 70.34780883789062
step: 100 Cross_Entropy: 5.868640422821045 MSE: 74.68391418457031
step: 200 Cross_Entropy: 2.415027141571045 MSE: 85.03378295898438
epoch: 18 Accuracy: 97.54 %
step: 0 Cross_Entropy: 1.5692293643951416 MSE: 90.47661590576172
step: 100 Cross_Entropy: 0.6557420492172241 MSE: 81.88681030273438
step: 200 Cross_Entropy: 5.726837158203125 MSE: 76.24435424804688
epoch: 19 Accuracy: 98.0 %
Previous post: TensorFlow2 手把手教你训练 MNIST 数据集 part1