Tensorflow v2 创建网络模型且保存参数至本地(非keras)
Posted lavender-pansy
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow v2 创建网络模型且保存参数至本地(非keras)相关的知识,希望对你有一定的参考价值。
//20201030
写在前面:最近几天在在学习Tensorflow v2框架搭建网络,今天在这里做一下summary,主要简述一下搭建的大致流程以及需要的要素,最后就是如何存储以及读取存储恢复网络
1.导包
(此处因为做了可视化以及使用mnist当做数据集,所以使用了matplotlib和keras)
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from tensorflow import keras
2.数据准备
此处直接使用mnist数据集
(x_train,y_train),(x_test,y_test) = keras.datasets.mnist.load_data() x_train,x_test = np.array(x_train,np.float32),np.array(x_test,np.float32) x_train = x_train/255.0# 将数据缩小,减小计算量 x_test = x_test/255.0 training_data = tf.data.Dataset.from_tensor_slices((x_train,y_train)) training_data = training_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)# 此行的意思是将数据变成数据池,且打乱顺序,一次取出batch_size个数据,且在去除是准备下次所需数据————类似流
如果因为墙的原因下载不了数据集,可单独下载数据集,然后使用如下方法解析
数据集链接(使用迅雷下载会很快):https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
解析方法:
import numpy as np def load_data(path): data_set = np.load(path) file_name_list = data_set.files dict = {} for each in file_name_list: dict[each] = data_set[each] return dict
3.设置网络(此处使用卷积网络)所需参数
num_classes = 10# 类别数,因为此网络用于分类0-9十个数字,所以共有10个类别 num_features = 28*28# 特性数,文件集中图片为28*28灰度图像,所以特征数为28*28 learning_rate = 0.001# 学习率,网络将使用Adam优化器,此处设置学习率为0.001(可改) training_step = 1000# 使用多少数据训练,此处设置1000,数据集总共有60000张图片(可自行打印shape查看) display_step = 100# 手动打印损失函数与准确率频率参数 batch_size = 256# 数据预处理时使用 fc_utils = 1024# fully connection layer单元数
4.继承keras.Model重写网络模型
class ConvNet(keras.Model): def __init__(self): super(ConvNet,self).__init__() self.conv1 = keras.layers.Conv2D(32,kernel_size=5,activation=‘relu‘) self.mp1 = keras.layers.MaxPool2D(2,strides=2) self.conv2 = keras.layers.Conv2D(64,kernel_size = 3,activation=‘relu‘) self.mp2 = keras.layers.MaxPool2D(2,strides=2) self.flatten = keras.layers.Flatten() self.fc = keras.layers.Dense(fc_utils) self.dropout = keras.layers.Dropout(rate = 0.5) self.out = keras.layers.Dense(num_classes) def call(self,x,training = False): x = tf.reshape(x,[-1,28,28,1]) x = self.conv1(x) x = self.mp1(x) x = self.conv2(x) x = self.mp2(x) x = self.flatten(x) x = self.fc(x) x = self.dropout(x,training = training) x = self.out(x) if not training: x = tf.nn.softmax(x) return x
此处网络模型为 [ 卷积层_32个过滤器——5个核心---->池化层_2x2——步长为2----->卷积层_64个过滤器_3个核心----->池化层_2x2步长为2----->flatten层_将数据拉平----->拥有1024个单元的全连接层----->dropout层(提升网络稀疏性,提升速度以及准确性)----->out层(输出)------>(如果不是在训练而是在预测,需要加一个softmax层来输出预测值)]
5.定义交叉熵函数(用于计算损失函数)
def cross_entropy(y_pred,y_true): y_true = tf.cast(y_true,tf.int64) loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true,logits=y_pred) return tf.reduce_mean(loss)
6.定义准确值函数(用于计算acc)
def accuracy(y_pred,y_true): correct_prediction = tf.equal(tf.argmax(y_pred,1),tf.cast(y_true,tf.int64)) return tf.reduce_mean(tf.cast(correct_prediction,tf.float32),axis = -1)
7.定义运行优化器函数
optimizer = tf.optimizers.Adam(learning_rate) def run_optimization(x,y): with tf.GradientTape() as g: pred = conv(x,True) loss = cross_entropy(pred,y) gradients = g.gradient(loss,conv.trainable_variables) optimizer.apply_gradients(zip(gradients,conv.trainable_variables))
8.开始优化
for steps,(batch_x,batch_y) in enumerate(training_data.take(training_step),1): run_optimization(batch_x,batch_y) if steps%display_step==0: pred = conv(batch_x,training=False) loss = cross_entropy(pred,batch_y) acc = accuracy(pred,batch_y) print("step:{}---->loss:{}---->acc:{}".format(steps,loss,acc))
输出如下图
9.可视化
此处使用测试数据集中前25个数据进行可视化
test_data = x_test[:25] label = y_test[:25] fig,ax = plt.subplots(5,5) plt.subplots_adjust(wspace=1,hspace=1) ax = ax.flatten() pred = conv(test_data,False) for i in range(25): ax[i].imshow(test_data[i],cmap=‘Greys‘) ax[i].set_title("pred:{},true:{}".format(np.argmax(pred[i]),label[i])) plt.show()
可视化结果如下
10.存储模型权重
conv.save_weights(‘./tfmodel.ckpt‘)
路径可自定义,但必须是.ckpt文件
11.读取权重参数恢复网络
conv = ConvNet()# 定义一个空网络(此网络必须和恢复权重网络相同)
conv.load_weights(‘./tfmodel.ckpt‘)
ps:恢复权重参数后,网络就是训练后的状态,可以直接用于预测或者进一步训练
以上
希望对大家有所帮助
以上是关于Tensorflow v2 创建网络模型且保存参数至本地(非keras)的主要内容,如果未能解决你的问题,请参考以下文章
tensorflow 1.0 学习:模型的保存与恢复(Saver)