机器学习-MNIST数据集-神经网络
Posted david2018098
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了机器学习-MNIST数据集-神经网络相关的知识,希望对你有一定的参考价值。
1 #设置随机种子 2 seed = 7 3 numpy.random.seed(seed) 4 5 #加载数据 6 (X_train,y_train),(X_test,y_test) = mnist.load_data() 7 #print(X_train.shape[0]) 8 9 #数据集是3维的向量(instance length,width,height).对于多层感知机,模型的输入是二维的向量,因此这里需要将数据集reshape,即将28*28的向量转成784长度的数组。可以用numpy的reshape函数轻松实现这个过程。 10 num_pixels = X_train.shape[1] * X_train.shape[2] 11 X_train = X_train.reshape(X_train.shape[0],num_pixels).astype(‘float32‘) 12 X_test = X_test.reshape(X_test.shape[0],num_pixels).astype(‘float32‘) 13 14 #给定的像素的灰度值在0-255,为了使模型的训练效果更好,通常将数值归一化映射到0-1 15 X_train = X_train / 255 16 X_test = X_test / 255 17 # one hot encoding 18 y_train = np_utils.to_categorical(y_train) 19 y_test = np_utils.to_categorical(y_test) 20 num_classes = y_test.shape[1] 21 22 # 搭建神经网络模型了,创建一个函数,建立含有一个隐层的神经网络 23 def baseline_model(): 24 model = Sequential() # 建立一个Sequential模型,然后一层一层加入神经元 25 # 第一步是确定输入层的数目正确:在创建模型时用input_dim参数确定。例如,有784个个输入变量,就设成num_pixels。 26 #全连接层用Dense类定义:第一个参数是本层神经元个数,然后是初始化方式和激活函数。这里的初始化方法是0到0.05的连续型均匀分布(uniform),Keras的默认方法也是这个。也可以用高斯分布进行初始化(normal)。 27 # 具体定义参考:https://cnbeining.github.io/deep-learning-with-python-cn/3-multi-layer-perceptrons/ch7-develop-your-first-neural-network-with-keras.html 28 model.add(Dense(num_pixels,input_dim=num_pixels,kernel_initializer=‘normal‘,activation=‘relu‘)) 29 model.add(Dense(num_classes,kernel_initializer=‘normal‘,activation=‘softmax‘)) 30 model.compile(loss=‘categorical_crossentropy‘,optimizer=‘adam‘,metrics=[‘accuracy‘]) 31 return model 32 33 model = baseline_model() 34 #model.fit() 函数每个参数的意义参考:https://blog.csdn.net/a1111h/article/details/82148497 35 model.fit(X_train,y_train,validation_data=(X_test,y_test),epochs=10,batch_size=200,verbose=2) 36 # 1、模型概括打印 37 model.summary() 38 39 scores = model.evaluate(X_test,y_test,verbose=0) #model.evaluate 返回计算误差和准确率 40 print(scores) 41 print("Base Error:%.2f%%"%(100-scores[1]*100))
以上是关于机器学习-MNIST数据集-神经网络的主要内容,如果未能解决你的问题,请参考以下文章
深度学习与TensorFlow 2.0卷积神经网络(CNN)