Classifying MNIST handwritten digits with SVM (with and without kernel functions)

Posted by yaowuyangwei521



1. Classifying the MNIST dataset with an ordinary (linear) SVM

# Import the required packages
import os
import struct
import time

import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm              # SVM models
from sklearn import preprocessing    # data preprocessing (standardization)

# Directory containing the raw MNIST idx files
path = './dataset/mnist/raw'

def load_mnist(path, kind='train'):
    """Load MNIST images and labels: kind='train' for the training set, 't10k' for the test set."""
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)
    with open(labels_path, 'rb') as lbpath:
        # Header: magic number and item count (big-endian unsigned ints)
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        # Header: magic number, image count, rows, columns
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        # Flatten each 28x28 image into a 784-dimensional vector
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
    return images, labels

train_images, train_labels = load_mnist(path, kind='train')
test_images, test_labels = load_mnist(path, kind='t10k')

# Standardize the features: fit the scaler on the training set only,
# then reuse the same scaler for the test set
scaler = preprocessing.StandardScaler().fit(train_images)
X_train = scaler.transform(train_images)
y_train = train_labels

print(time.strftime('%Y-%m-%d %H:%M:%S'))   # start time
model_svc = svm.LinearSVC()                 # linear SVM, no kernel
#model_svc = svm.SVC()                      # kernel SVM (RBF kernel by default)
model_svc.fit(X_train, y_train)
print(time.strftime('%Y-%m-%d %H:%M:%S'))   # end time

# Evaluate on the 10,000 test images
X_test = scaler.transform(test_images)
y_test = test_labels
print(model_svc.score(X_test, y_test))      # mean accuracy on the test set
y_pred = model_svc.predict(X_test)

# Show the first 100 test images with their predicted labels
fig1 = plt.figure(figsize=(8, 8))
fig1.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
for i in range(100):
    ax = fig1.add_subplot(10, 10, i + 1, xticks=[], yticks=[])
    ax.imshow(np.reshape(test_images[i], [28, 28]), cmap=plt.cm.binary, interpolation='nearest')
    ax.text(0, 2, 'pred:' + str(y_pred[i]), color='red')
    #ax.text(0, 32, 'real:' + str(test_labels[i]), color='blue')
plt.show()
52 y=model_svc.predict(x)
53 
54 fig1=plt.figure(figsize=(8,8))
55 fig1.subplots_adjust(left=0,right=1,bottom=0,top=1,hspace=0.05,wspace=0.05)
56 for i in range(100):
57     ax=fig1.add_subplot(10,10,i+1,xticks=[],yticks=[])
58     ax.imshow(np.reshape(test_images[i], [28,28]),cmap=plt.cm.binary,interpolation=nearest)
59     ax.text(0,2,"pred:"+str(y[i]),color=red)
60     #ax.text(0,32,"real:"+str(test_labels[i]),color=‘blue‘)
61 plt.show()

2. Results:

Start time: 2018-11-17 08:31:09

End time: 2018-11-17 08:53:04

Elapsed: 21 min 55 s

Test accuracy: 0.9122
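As a quick sanity check on these numbers: the elapsed time follows from the two printed timestamps, and because `score` on a scikit-learn classifier returns mean accuracy, 0.9122 on the 10,000 test images corresponds to 9,122 correctly classified and 878 misclassified digits. A minimal Python check:

```python
from datetime import datetime

# Timestamps printed by the script in this run (from the results above)
fmt = '%Y-%m-%d %H:%M:%S'
start = datetime.strptime('2018-11-17 08:31:09', fmt)
end = datetime.strptime('2018-11-17 08:53:04', fmt)
minutes, seconds = divmod(int((end - start).total_seconds()), 60)
print('%d min %d s' % (minutes, seconds))  # 21 min 55 s

# Mean accuracy on the 10,000 MNIST test images, converted to counts
accuracy = 0.9122
n_test = 10000
n_correct = round(accuracy * n_test)
print(n_correct, n_test - n_correct)  # 9122 878
```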

Predicted images: [figure: the first 100 test images, each labeled with its predicted digit]
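The figures above are for `LinearSVC`, the kernel-free case. The title also covers the kernel case, which the commented-out `svm.SVC()` line selects (its default kernel is RBF). The scikit-learn documentation notes that `SVC` fit time grows at least quadratically with the number of samples, so a fit on all 60,000 MNIST images may take considerably longer than the linear model. The sketch below compares the two on scikit-learn's bundled 8x8 `load_digits` dataset, swapped in here as a small stand-in for MNIST so it runs in seconds without the raw files. The split, `max_iter`, and `C`/`gamma` values are illustrative assumptions, not tuned settings from this post. Wrapping `StandardScaler` in a pipeline also keeps the scaler fitted on the training split only.

```python
import time

from sklearn import svm
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Small bundled 8x8 digits dataset, used as a stand-in for MNIST (assumption)
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y)

models = {
    'LinearSVC (no kernel)': svm.LinearSVC(max_iter=10000),
    'SVC (RBF kernel)': svm.SVC(kernel='rbf', gamma='scale', C=1.0),
}

scores = {}
for name, clf in models.items():
    # The pipeline fits the scaler on the training split only
    model = make_pipeline(StandardScaler(), clf)
    t0 = time.time()
    model.fit(X_train, y_train)
    scores[name] = model.score(X_test, y_test)  # mean accuracy
    print('%-22s accuracy=%.4f  fit=%.2fs' % (name, scores[name], time.time() - t0))
```

The exact accuracies and timings depend on the scikit-learn version and hardware, so none are quoted here; the point is that switching between the two cases only changes the estimator passed to the pipeline.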

 

 

 
