用SVM(有核和无核函数)进行MNIST手写字体的分类
Posted yaowuyangwei521
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了用SVM(有核和无核函数)进行MNIST手写字体的分类相关的知识,希望对你有一定的参考价值。
1.普通SVM分类MNIST数据集
1 #导入必备的包 2 import numpy as np 3 import struct 4 import matplotlib.pyplot as plt 5 import os 6 ##加载svm模型 7 from sklearn import svm 8 ###用于做数据预处理 9 from sklearn import preprocessing 10 import time 11 12 #加载数据的路径 13 path=‘./dataset/mnist/raw‘ 14 def load_mnist_train(path, kind=‘train‘): 15 labels_path = os.path.join(path,‘%s-labels-idx1-ubyte‘% kind) 16 images_path = os.path.join(path,‘%s-images-idx3-ubyte‘% kind) 17 with open(labels_path, ‘rb‘) as lbpath: 18 magic, n = struct.unpack(‘>II‘,lbpath.read(8)) 19 labels = np.fromfile(lbpath,dtype=np.uint8) 20 with open(images_path, ‘rb‘) as imgpath: 21 magic, num, rows, cols = struct.unpack(‘>IIII‘,imgpath.read(16)) 22 images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels), 784) 23 return images, labels 24 def load_mnist_test(path, kind=‘t10k‘): 25 labels_path = os.path.join(path,‘%s-labels-idx1-ubyte‘% kind) 26 images_path = os.path.join(path,‘%s-images-idx3-ubyte‘% kind) 27 with open(labels_path, ‘rb‘) as lbpath: 28 magic, n = struct.unpack(‘>II‘,lbpath.read(8)) 29 labels = np.fromfile(lbpath,dtype=np.uint8) 30 with open(images_path, ‘rb‘) as imgpath: 31 magic, num, rows, cols = struct.unpack(‘>IIII‘,imgpath.read(16)) 32 images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels), 784) 33 return images, labels 34 train_images,train_labels=load_mnist_train(path) 35 test_images,test_labels=load_mnist_test(path) 36 37 X=preprocessing.StandardScaler().fit_transform(train_images) 38 X_train=X[0:60000] 39 y_train=train_labels[0:60000] 40 41 print(time.strftime(‘%Y-%m-%d %H:%M:%S‘)) 42 model_svc = svm.LinearSVC() 43 #model_svc = svm.SVC() 44 model_svc.fit(X_train,y_train) 45 print(time.strftime(‘%Y-%m-%d %H:%M:%S‘)) 46 47 ##显示前30个样本的真实标签和预测值,用图显示 48 x=preprocessing.StandardScaler().fit_transform(test_images) 49 x_test=x[0:10000] 50 y_pred=test_labels[0:10000] 51 print(model_svc.score(x_test,y_pred)) 52 y=model_svc.predict(x) 53 54 fig1=plt.figure(figsize=(8,8)) 55 fig1.subplots_adjust(left=0,right=1,bottom=0,top=1,hspace=0.05,wspace=0.05) 56 for i in range(100): 57 ax=fig1.add_subplot(10,10,i+1,xticks=[],yticks=[]) 58 ax.imshow(np.reshape(test_images[i], [28,28]),cmap=plt.cm.binary,interpolation=‘nearest‘) 59 ax.text(0,2,"pred:"+str(y[i]),color=‘red‘) 60 #ax.text(0,32,"real:"+str(test_labels[i]),color=‘blue‘) 61 plt.show()
2.运行结果:
开始时间:2018-11-17 08:31:09
结束时间:2018-11-17 08:53:04
用时:21分55秒
精度:0.9122
预测图片:
以上是关于用SVM(有核和无核函数)进行MNIST手写字体的分类的主要内容,如果未能解决你的问题,请参考以下文章
实现手写数字识别(数据集50000张图片)比较3种算法神经网络灰度平均值SVM各自的准确率—Jason niu
SVM:利用SVM算法实现手写图片识别(数据集50000张图片)—Jason niu