第五讲 卷积神经网络 --baseline

Posted by wbloger

tags:

篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了“第五讲 卷积神经网络 --baseline”相关的知识，希望对你有一定的参考价值。

 1 import tensorflow as tf
 2 import os
 3 import numpy as np
 4 from matplotlib import pyplot as plt
 5 from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
 6 from tensorflow.keras import Model
 7 
 8 
 9 np.set_printoptions(threshold=np.inf)
10 
11 cifar10 = tf.keras.datasets.cifar10
12 (x_train, y_train), (x_test, y_test) = cifar10.load_data()
13 x_train, x_test = x_train/25.0, x_test/255.0
14 
15 
class BaseLine(Model):
    """Baseline CNN classifier.

    Architecture:
      * one convolution block: Conv2D(6, 5x5, 'same') -> BatchNorm -> ReLU
        -> MaxPool(2x2, stride 2) -> Dropout(0.2)
      * a classifier head: Flatten -> Dense(128, ReLU) -> Dropout(0.2)
        -> Dense(10, softmax)

    The final layer applies softmax, so the model outputs probabilities and
    any cross-entropy loss used with it must be configured with
    ``from_logits=False``.
    """

    def __init__(self):
        super().__init__()
        # Convolution block.
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding='same')  # convolution layer
        self.b1 = BatchNormalization()  # batch-normalization layer
        self.a1 = Activation('relu')  # activation layer
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')  # pooling layer
        self.d1 = Dropout(0.2)  # dropout layer

        # Classifier head.
        self.flatten = Flatten()
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')

    def call(self, x):
        """Run the forward pass.

        Args:
            x: a batch of input images. Presumably (batch, height, width,
                channels) CIFAR-10 images, as produced by the data loading
                code; the shape is not checked here.

        Returns:
            Softmax probabilities over the 10 classes, one row per sample.
        """
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y


model = BaseLine()

# BaseLine ends in softmax, so the loss receives probabilities, not logits.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Resume from a previous run: if the checkpoint's ".index" file exists,
# restore the saved weights before training.
checkpoint_save_path = "./checkpoint/Baseline.ckpt"
if os.path.exists(checkpoint_save_path + ".index"):
    print("--------------------load the model-----------------")
    model.load_weights(checkpoint_save_path)

# Save weights only (not the full model), and only when the monitored
# quantity improves (Keras' default monitor is 'val_loss').
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

# Validate on the test split after every epoch (validation_freq=1), which also
# makes the val_* entries in history available for plotting.
history = model.fit(x_train, y_train, batch_size=32, epochs=5,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

model.summary()


# Dump every trainable variable to a text file: its name, its shape, and its
# full value (untruncated because of np.set_printoptions(threshold=np.inf)).
with open('./weights.txt', 'w') as f:
    for v in model.trainable_variables:
        f.write(str(v.name) + '\n')
        f.write(str(v.shape) + '\n')
        f.write(str(v.numpy()) + '\n')


def plot_acc_loss_curve(history):
    """Plot training/validation accuracy and loss curves side by side.

    Args:
        history: the object returned by ``model.fit``. Its ``history`` dict
            must contain the keys 'sparse_categorical_accuracy',
            'val_sparse_categorical_accuracy', 'loss' and 'val_loss'. These
            exist when the model is compiled with the
            'sparse_categorical_accuracy' metric and fit with validation data.
    """
    # Local import keeps the function usable on its own.
    from matplotlib import pyplot as plt

    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(15, 5))

    # Left panel: accuracy per epoch.
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    # Right panel: loss per epoch.
    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    plt.show()


plot_acc_loss_curve(history)

 

以上是关于“第五讲 卷积神经网络 --baseline”的主要内容，如果未能解决你的问题，请参考以下文章：

第五讲 卷积神经网路-- Inception10 --cifar10

第五讲 卷积神经网路-- Inception10 --cifar10

第五讲:工业网络——网络设备及其功能

第五讲 网络欺骗技术笔记

CDN百科第五讲 | CDN和游戏加速器有什么区别?

RocketMQ第五讲