TensorFlow学习笔记--- 使用CPABD实现最简单的CNN模型
Posted 一朵包纸
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了TensorFlow学习笔记--- 使用CPABD实现最简单的CNN模型相关的知识,希望对你有一定的参考价值。
import os from tensorflow.keras.datasets import mnist import tensorflow as tf from tensorflow.python.keras import Model from tensorflow.python.keras.datasets import cifar10 from tensorflow.python.keras.layers import Flatten, Dense, Conv2D, BatchNormalization, AvgPool2D, Activation, MaxPool2D, \\ Dropout (x_train, y_train), (x_test, y_test) = cifar10.load_data() x_train, x_test = x_train/255.0, x_test/255.0 checkpoint_save_path = \'./checkpoint/model.ckpt\' # 搭建模型类, 口诀:CBAPD,C卷积B批标准化A激活P池化D全连接 class ConvModel(Model): def __init__(self): super(ConvModel, self).__init__() # filters: 卷积核个数 kernel_size:卷积核尺寸 strides:横纵向步长 padding:是否使用全零填充,same为是 activation:激活函数 self.conv1 = Conv2D(filters=6, kernel_size=(5, 5), strides=(1, 1), padding=\'same\', activation=None) # 在激活函数前,先进行一次批标准化,使得输入值更靠近0均值 self.bn = BatchNormalization() # 激活函数 self.activation = Activation(\'relu\') # 池化,减少输入特征值 self.pool = MaxPool2D(pool_size=(2, 2), strides=2, padding=\'same\') # Dropout防止过拟合 self.dropout1 = Dropout(0.2) # 特征抽取完,拉直维度后通过全连接层输出 self.flatten = Flatten() self.d1 = Dense(128, activation=\'relu\') self.dropout2 = Dropout(0.2) self.d2 = Dense(10, activation=\'softmax\') def call(self, x): x = self.conv1(x) x = self.bn(x) x = self.activation(x) x = self.pool(x) x = self.dropout1(x) x = self.flatten(x) x = self.d1(x) x = self.dropout2(x) y = self.d2(x) return y model = ConvModel() # 模型优化 model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.sparse_categorical_crossentropy, metrics=[\'sparse_categorical_accuracy\']) # callback保存模型 model_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True, save_best_only=True) # 曾经保存过,直接加载权重参数 if os.path.exists(checkpoint_save_path + \'.index\'): model.load_weights(checkpoint_save_path) # 开始训练 model.fit(x=x_train, y=y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), callbacks=[model_callback]) # 结果总览 model.summary()
以上是关于TensorFlow学习笔记--- 使用CPABD实现最简单的CNN模型的主要内容,如果未能解决你的问题,请参考以下文章