机器学习-MNIST数据集-神经网络

Posted david2018098

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了机器学习-MNIST数据集-神经网络相关的知识,希望对你有一定的参考价值。

 1 #设置随机种子
 2 seed = 7 
 3 numpy.random.seed(seed)
 4 
 5 #加载数据
 6 (X_train,y_train),(X_test,y_test) = mnist.load_data() 
 7 #print(X_train.shape[0])
 8 
 9 #数据集是3维的向量(instance length,width,height).对于多层感知机,模型的输入是二维的向量,因此这里需要将数据集reshape,即将28*28的向量转成784长度的数组。可以用numpy的reshape函数轻松实现这个过程。
10 num_pixels = X_train.shape[1] * X_train.shape[2] 
11 X_train = X_train.reshape(X_train.shape[0],num_pixels).astype(float32)
12 X_test = X_test.reshape(X_test.shape[0],num_pixels).astype(float32)
13 
14 #给定的像素的灰度值在0-255,为了使模型的训练效果更好,通常将数值归一化映射到0-1
15 X_train = X_train / 255
16 X_test = X_test / 255
17 # one hot encoding
18 y_train = np_utils.to_categorical(y_train)
19 y_test = np_utils.to_categorical(y_test)
20 num_classes = y_test.shape[1]
21 
22 # 搭建神经网络模型了,创建一个函数,建立含有一个隐层的神经网络
23 def baseline_model():
24     model = Sequential() # 建立一个Sequential模型,然后一层一层加入神经元
25     # 第一步是确定输入层的数目正确:在创建模型时用input_dim参数确定。例如,有784个个输入变量,就设成num_pixels。
26     #全连接层用Dense类定义:第一个参数是本层神经元个数,然后是初始化方式和激活函数。这里的初始化方法是0到0.05的连续型均匀分布(uniform),Keras的默认方法也是这个。也可以用高斯分布进行初始化(normal)。
27     # 具体定义参考:https://cnbeining.github.io/deep-learning-with-python-cn/3-multi-layer-perceptrons/ch7-develop-your-first-neural-network-with-keras.html
28     model.add(Dense(num_pixels,input_dim=num_pixels,kernel_initializer=normal,activation=relu))
29     model.add(Dense(num_classes,kernel_initializer=normal,activation=softmax))
30     model.compile(loss=categorical_crossentropy,optimizer=adam,metrics=[accuracy])
31     return model
32 
33 model = baseline_model()
34 #model.fit() 函数每个参数的意义参考:https://blog.csdn.net/a1111h/article/details/82148497
35 model.fit(X_train,y_train,validation_data=(X_test,y_test),epochs=10,batch_size=200,verbose=2) 
36 # 1、模型概括打印
37 model.summary()
38 
39 scores = model.evaluate(X_test,y_test,verbose=0) #model.evaluate 返回计算误差和准确率
40 print(scores)
41 print("Base Error:%.2f%%"%(100-scores[1]*100))

 

以上是关于机器学习-MNIST数据集-神经网络的主要内容,如果未能解决你的问题,请参考以下文章

深度学习与TensorFlow 2.0卷积神经网络(CNN)

MNIST机器学习数据集

机器学习算法专题(蓄力计划)二十实操代码MNIST 数据集

手写数字识别——基于全连接层和MNIST数据集

机器学习(TensorFlow)---Fashion MNIST数据集使用范例(计算机视觉)

神经网络的学习-搭建神经网络实现mnist数据集分类