阅读项目:通过机器学习识别手写数字

Posted dayoulaoshi

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了阅读项目:通过机器学习识别手写数字相关的知识,希望对你有一定的参考价值。

地址:https://github.com/JosephPai/KaggleSolution/tree/master/DigitRec

数据集:https://www.kaggle.com/c/digit-recognizer/data

 

import pandas as pd

import matplotlib.pyplot as plt, matplotlib.image as mpimg

from sklearn.model_selection import train_test_split

from sklearn import svm

导入库,这里使用了panadas 库进行数据处理。

通过skllearn库选择分类器进行训练

 

 

labeled_images = pd.read_csv(‘C:/Users/75201/Desktop/train/input/train.csv‘)

images = labeled_images.iloc[0:5000,1:]

labels = labeled_images.iloc[0:5000,:1]

train_images,test_images,train_labels,test_labels=train_test_split(images,labels, train_size=0.8, random_state=0)

 

首先导入训练集,然后将标签与内容分开,labels保存标签,images保存内容,然后划分训练集和测试集

train_images:训练集内容

test_images:测试集内容

train_labels,:训练集标签

test_labels:测试集标签

 

 

这一步可以进行查看图片,将一维的数据展示成图片

i=8

img=train_images.iloc[i].as_matrix()

img=img.reshape((28,28))

plt.imshow(img,cmap=‘gray‘)

plt.title(train_labels.iloc[i,0])

 

可以看出图片是有灰度的

 

 

 

这一步开始训练,使用sklearn 包提供的 svm 模型来建立一个分类器 classifier,

clf = svm.SVC()

clf.fit(train_images, train_labels.values.ravel())

clf.score(test_images,test_labels)

训练结果0.10000,很不理想

 

 

 

这一步将图片转化为黑白,简化特征值,可以大幅提高准确率

test_images[test_images>0]=1

train_images[train_images>0]=1

img=train_images.iloc[i].as_matrix().reshape((28,28))

plt.imshow(img,cmap=‘binary‘)

plt.title(train_labels.iloc[i])

 

 

 

 

 再次使用相同的分类器进行训练

clf = svm.SVC()
clf.fit(train_images, train_labels.values.ravel())
clf.score(test_images,test_labels)

 训练结果0.887

 

成绩初步满意,可以开始测试,导入测试集,并将测试结果写入到test文件中

test_data=pd.read_csv(‘C:/Users/75201/Desktop/train/input/test.csv‘)

test_data[test_data>0]=1

results=clf.predict(test_data[0:5000])

df = pd.DataFrame(results)

df.index.name=‘ImageId‘

df.index+=1

df.columns=[‘Label‘]

df.to_csv(‘C:/Users/75201/Desktop/train/input/results.csv‘, header=True)

 

 

 

 

读完这个项目,我认为可以优化以下几点

  1. 增大训练样本,数据集中的数量不仅仅有5000,加大样本可以提高准确率
  2. 增加外部接口,将图片预处理为28*28像素的图片,方便进行外部测试
  3. 尝试其他分类器。
  4. 优化特征。

 

以上是关于阅读项目:通过机器学习识别手写数字的主要内容,如果未能解决你的问题,请参考以下文章

Andrew Ng 机器学习课程笔记 ———— 通过初步的神经网络实现手写数字的识别(尽力去向量化实现)

BP神经网络-手写数字的识别-机器学习实验二

机器学习初探(手写数字识别)matlab读取数据集

机器学习-kNN手写数字识别

机器学习教程 十四-利用tensorflow做手写数字识别

机器学习——15 手写数字识别-小数据集