CNN实战--mnist
Posted cherrypill
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了CNN实战--mnist相关的知识,希望对你有一定的参考价值。
CNN实战--mnist
dataprocessing
我一般把数据处理单独写一个函数
因为网上大多数都是直接在线下载做学习,导致与实际应用的情况不相符,所以我这是直接下载下来并读取,处理数据
这个数据类型文档说的很清楚 雾
是图片二进制存储的(图片大小28*28),并且开头有一个magic num (需要跳过它)
不知道跳几位的可以多尝试一下不同的offset输出长度看是不是整除
def read_data():
with open(‘./t10k-labels.idx1-ubyte‘,‘rb‘) as f:
y_test=np.frombuffer(f.read(),np.uint8,offset=8)
y_test=tf.convert_to_tensor(y_test,tf.int32)
# offset代表从第几个byte后面开始读取,0则是从头开始读 1byte=8bit
# y_test=tf.one_hot(y_test,10)
with open(‘./train-labels.idx1-ubyte‘,‘rb‘) as f:
y_train=np.frombuffer(f.read(),np.uint8,offset=8)
y_train=tf.convert_to_tensor(y_train,tf.int32)
# 1*10000
# y_train=tf.one_hot(y_train,10)
with open(‘./t10k-images.idx3-ubyte‘, ‘rb‘) as f:
x_test = np.frombuffer(f.read(), np.uint8,offset=16).reshape(len(y_test), 28, 28,1)
x_test=tf.convert_to_tensor(x_test,tf.float32)/255
# #502098=28*28*60000
with open(‘./train-images.idx3-ubyte‘, ‘rb‘) as f:
x_train = np.frombuffer(f.read(), np.uint8,offset=16).reshape(len(y_train),28,28,1)
x_train=tf.convert_to_tensor(x_train,dtype=tf.float32)/255
#78400=28*28*10000
return x_train,y_train,x_test,y_test
train_model
#-*- coding:utf-8 -*-
# @Author : Dummerfu
# @Time : 2020/4/20 21:42
import tensorflow as tf
import data_processing
import numpy as np
import os
os.environ[‘TF_CPP_MIN_LOG_LEVEL‘] = ‘2‘
if __name__ == ‘__main__‘:
# x: [60k, 28, 28,1], [10k, 28, 28,1]
# y: [60k], [10k]
x_train, y_train, x_test, y_test = data_processing.read_data()
# print(y_test.shape,x_train.shape)
model=tf.keras.models.Sequential([
# 这里输入层还是要写单个输入的shape
tf.keras.layers.Conv2D(input_shape=(28,28,1),filters=32,
kernel_size=(3,3),strides=(1,1),padding=‘SAME‘,activation=‘relu‘),
tf.keras.layers.MaxPool2D(pool_size=(2,2),strides=(2,2),padding=‘SAME‘),
tf.keras.layers.Conv2D(filters=64,kernel_size=(3,3),
strides=(1,1),padding=‘SAME‘,activation=‘relu‘),
tf.keras.layers.MaxPool2D(pool_size=(2,2),strides=(2,2),padding=‘SAME‘),
tf.keras.layers.Dropout(0.7),
tf.keras.layers.Flatten(),
# FC1
tf.keras.layers.Dense(128,activation=‘relu‘),
tf.keras.layers.Dropout(0.5),
# FC2|output
tf.keras.layers.Dense(10,activation=‘softmax‘),
])
# 查看层的信息
# print(model.summary())
# 设置训练参数
model.compile(optimizer=‘adam‘,loss=‘sparse_categorical_crossentropy‘,metrics=[‘accuracy‘])
# 训练(你甚至都不需要自己转onehot)
# validation_split=x 将训练集*x变为测试集,进行预测
# verbose=1 显示训练信息
model.fit(x=x_train,y=y_train,batch_size=32,epochs=5,validation_split=0.3,verbose=1)
train_loss,train_accu=model.evaluate(x=x_test,y=y_test)
print(train_loss)
print(train_accu)
这个才训练到 98.5%好垃圾
model save|restore
有两种方式save
只保存weight和bias,不保存网络结构
这个知道就好了?其实是我懒得写,可以看那个链接里面写的
保存网络结构
import tensorflow as tf
# 这个model是前面的那个model类
model.save("path")
# model del
# 这里的测试可以自己输入
x_train,y_train,x_test,y_test=data_processing.read_data()
restore_model= tf.keras.models.load_model(‘./my_model.ckpt‘)
loss,acc=restore_model.evaluate(x_test,y_test)
print(loss)
print(acc)
predict
# draw 当然自己随便写,预测数据还是得本地导入
draw(x_test.numpy()[rad].reshape(28,28),y_test.numpy()[rad])
restore_model= tf.keras.models.load_model(‘./my_model.ckpt‘)
pro=np.argmax(restore_model.predict(x_test.numpy()[rad].reshape(1,28,28,1)))
print(‘???‘,pro)
以上是关于CNN实战--mnist的主要内容,如果未能解决你的问题,请参考以下文章
TensorFlow1.x 代码实战系列:MNIST手写数字识别
MATLAB可视化实战系列(四十)-基于MATLAB 自带手写数字集的CNN(LeNet5)手写数字识别-图像处理(附源代码)
TensorFlow入门实战|第1周:实现mnist手写数字识别