第四讲 网络八股拓展--用mnist数据集实现断点续训, 绘制准确图像和损失图像

Posted wbloger

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了第四讲 网络八股拓展--用mnist数据集实现断点续训, 绘制准确图像和损失图像相关的知识,希望对你有一定的参考价值。

 1 import tensorflow as tf
 2 import os
 3 import numpy as np
 4 from matplotlib import pyplot as plt
 5 
 6 
 7 np.set_printoptions(threshold=np.inf)
 8 
 9 
10 mnist = tf.keras.datasets.mnist
11 (x_train, y_train), (x_test, y_test) = mnist.load_data()
12 x_train, x_test = x_train/255.0, x_test/255.0
13 
14 
15 
16 model = tf.keras.models.Sequential([
17       tf.keras.layers.Flatten(),
18       tf.keras.layers.Dense(128, activation=relu),
19       tf.keras.layers.Dense(10, activation=softmax)
20 ])
21 
22 model.compile(optimizer=adam, 
23               loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
24               metrics = [sparse_categorical_accuracy])
25 
26 
27 
28 checkpoint_save_path = "./checkpoint/mnist.ckpt"
29 if os.path.exists(checkpoint_save_path + .index):
30   print(----------------load the model-----------------)
31   model.load_weights(checkpoint_save_path)
32 
33 
34 
35 cp_callback = tf.keras.callbacks.ModelCheckpoint(
36     filepath=checkpoint_save_path,
37     save_weights_only=True,
38     save_best_only=True
39 )
40 
41 history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1, callbacks=[cp_callback])
42 
43 model.summary()
44 
45 
46 print(model.trainable_variables)
47 with open(,.weights.txt, w) as file:
48   for v in model.trainable_variables:
49     file.write(str(v.name) + 
)
50     file.write(str(v.shape) + 
)
51     file.write(str(v.numpy()) + 
)
52 
53 
54 
55 # 显示训练集和验证集的acc和loss曲线
56 acc = history.history[sparse_categorical_accuracy]
57 val_acc = history.history[val_sparse_categorical_accuracy]
58 loss = history.history[loss]
59 val_loss = history.history[val_loss]
60 plt.figure(figsize=(15, 5))
61 plt.subplot(1, 2, 1)
62 plt.plot(acc, label=Training Accuracy)
63 plt.plot(val_acc, label=Validation Accuracy)
64 plt.title(Training and Validation Accuracy)
65 #plt.legend()
66 plt.grid()
67 
68 plt.subplot(1, 2, 2)
69 plt.plot(loss, label=Training Loss)
70 plt.plot(val_loss, label=Validation Loss)
71 plt.title(Training and Validation Loss)
72 plt.legend()
73 #plt.grid()
74 plt.show()

 

以上是关于第四讲 网络八股拓展--用mnist数据集实现断点续训, 绘制准确图像和损失图像的主要内容,如果未能解决你的问题,请参考以下文章

第四讲 网络八股拓展--自定义数据集加载

PyTorch学习实战第四篇:MNIST数据集的读取显示以及全连接实现数字识别

深度学习------用NNCNNRNN神经网络实现mnist数据集处理

用tensorflow搭建简单神经网络测试iris 数据集和MNIST 数据集

探索用卷积神经网络实现MNIST数据集分类

用标准3层神经网络实现MNIST识别