使用mllib完成mnist手写识别任务
Posted 凌祈丶微光
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用mllib完成mnist手写识别任务相关的知识,希望对你有一定的参考价值。
使用mllib完成mnist手写识别任务
-
小提示,通过restart命令重启已经退出了的容器
sudo docker restart <contain id>
-
完成识别任务准备工作
-
从以下网站下载数据集:
MNIST手写数字数据库,Yann LeCun,Corinna Cortes和Chris Burges
数据集包含以下四个压缩包,下载后解压得到数据集文件:
- t10k-images-idx3-ubyte.gz
- t10k-labels-idx1-ubyte.gz
- train-images-idx3-ubyte.gz
- train-labels-idx1-ubyte.gz
-
通过以下python程序,将数据集文件转换为csv文件
def convert(imgf, labelf, outf, n): f = open(imgf, "rb") o = open(outf, "w") l = open(labelf, "rb") f.read(16) l.read(8) images = [] for i in range(n): image = [ord(l.read(1))] for j in range(28 * 28): image.append(ord(f.read(1))) images.append(image) for image in images: o.write(",".join(str(pix) for pix in image) + "\\n") f.close() o.close() l.close() # 数据集在 http://yann.lecun.com/exdb/mnist/ 下载 convert("train-images.idx3-ubyte", "train-labels.idx1-ubyte", "mnist_train.csv", 60000) convert("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte", "mnist_test.csv", 10000)
通过这个程序将在根目录下产生以下两个文件:
- mnist_train.csv
- mnist_test.csv
-
通过以下python程序转换csv文件为libsvm文件
import csv def execute(data, savepath): csv_reader = csv.reader(open(data)) f = open(savepath, 'wb') for line in csv_reader: label = line[0] features = line[1:] libsvm_line = label + ' ' for index, feature in enumerate(features): libsvm_line += str(index + 1) + ':' + feature + ' ' f.write(bytes(libsvm_line.strip() + '\\n', 'UTF-8')) f.close() execute('mnist_train.csv', 'mnist_train.libsvm') execute('mnist_test.csv', 'mnist_test.libsvm')
该程序将生成以下两个.libsvm文件:
- mnist_test.libsvm
- mnist_train.libsvm
-
通过共享目录传递数据集到spark-master容器内。
-
进入spark-master
sudo docker exec -it spark-master /bin/bash
-
打开spark-shell
spark-shell位于/spark/bin目录下
使用
./spark-shell
命令进入spark-shell。
-
-
完成识别任务
-
读取训练集
val train = spark.read.format("libsvm").load("/data/mnist_train.libsvm")
-
读取测试集
val test = spark.read.format("libsvm").load("/data/mnist_test.libsvm")
-
定义网络结构。如果计算机性能不好可以降低隐藏层的参数。
val layers = Array[Int](784, 784, 784, 10)
-
导入多层感知机与多分类评价器。
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
-
使用多层感知机初始化训练器。
val trainer = new MultilayerPerceptronClassifier().setLayers(layers).setBlockSize(128).setSeed(1234L).setMaxIter(100)
-
训练模型
var model = trainer.fit(train)
-
输入测试集进行识别
val result = model.transform(test)
-
获取测试结果中的预测结果与实际结果
val predictionAndLabels = result.select("prediction", "label")
-
初始化评价器
val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
-
计算识别精度
println(s"Test set accuracy = $evaluator.evaluate(predictionAndLabels)")
-
在result上创建临时视图
result.toDF.createOrReplaceTempView("deep_learning")
-
使用Spark SQL的方式计算识别精度
spark.sql("select (select count(*) from deep_learning where label=prediction)/count(*) as accuracy from deep_learning").show()
-
以上是关于使用mllib完成mnist手写识别任务的主要内容,如果未能解决你的问题,请参考以下文章