使用mllib完成mnist手写识别任务

Posted 凌祈丶微光

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用mllib完成mnist手写识别任务相关的知识,希望对你有一定的参考价值。

使用mllib完成mnist手写识别任务

  1. 小提示,通过restart命令重启已经退出了的容器

    sudo docker restart <contain id>

  2. 完成识别任务准备工作

    1. 从以下网站下载数据集:

      MNIST手写数字数据库,Yann LeCun,Corinna Cortes和Chris Burges

      数据集包含以下四个压缩包,下载后解压得到数据集文件:

      • t10k-images-idx3-ubyte.gz
      • t10k-labels-idx1-ubyte.gz
      • train-images-idx3-ubyte.gz
      • train-labels-idx1-ubyte.gz
    2. 通过以下python程序,将数据集文件转换为csv文件

      def convert(imgf, labelf, outf, n):
          f = open(imgf, "rb")
          o = open(outf, "w")
          l = open(labelf, "rb")
      
          f.read(16)
          l.read(8)
          images = []
      
          for i in range(n):
              image = [ord(l.read(1))]
              for j in range(28 * 28):
                  image.append(ord(f.read(1)))
              images.append(image)
      
          for image in images:
              o.write(",".join(str(pix) for pix in image) + "\\n")
          f.close()
          o.close()
          l.close()
      
      
      # 数据集在 http://yann.lecun.com/exdb/mnist/ 下载
      convert("train-images.idx3-ubyte", "train-labels.idx1-ubyte",
              "mnist_train.csv", 60000)
      convert("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte",
              "mnist_test.csv", 10000)
      

      通过这个程序将在根目录下产生以下两个文件:

      • mnist_train.csv
      • mnist_test.csv
    3. 通过以下python程序转换csv文件为libsvm文件

      import csv
      
      
      def execute(data, savepath):
      
          csv_reader = csv.reader(open(data))
          f = open(savepath, 'wb')
          for line in csv_reader:
              label = line[0]
              features = line[1:]
              libsvm_line = label + ' '
      
              for index, feature in enumerate(features):
                  libsvm_line += str(index + 1) + ':' + feature + ' '
              f.write(bytes(libsvm_line.strip() + '\\n', 'UTF-8'))
      
          f.close()
      
      
      execute('mnist_train.csv', 'mnist_train.libsvm')
      execute('mnist_test.csv', 'mnist_test.libsvm')
      

      该程序将生成以下两个.libsvm文件:

      • mnist_test.libsvm
      • mnist_train.libsvm
    4. 通过共享目录传递数据集到spark-master容器内。

    5. 进入spark-master

      sudo docker exec -it spark-master /bin/bash

    6. 打开spark-shell

      spark-shell位于/spark/bin目录下

      使用./spark-shell命令进入spark-shell。

  3. 完成识别任务

    1. 读取训练集

      val train = spark.read.format("libsvm").load("/data/mnist_train.libsvm")
      

    2. 读取测试集

      val test = 		spark.read.format("libsvm").load("/data/mnist_test.libsvm")
      

    3. 定义网络结构。如果计算机性能不好可以降低隐藏层的参数。

      val layers = Array[Int](784, 784, 784, 10)
      

    4. 导入多层感知机与多分类评价器。

      import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
      import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
      

    5. 使用多层感知机初始化训练器。

      val trainer = new MultilayerPerceptronClassifier().setLayers(layers).setBlockSize(128).setSeed(1234L).setMaxIter(100)
      

    6. 训练模型

      var model = trainer.fit(train)
      

    7. 输入测试集进行识别

      val result = model.transform(test)
      

    8. 获取测试结果中的预测结果与实际结果

      val predictionAndLabels = result.select("prediction", "label")
      

    9. 初始化评价器

      val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
      

    10. 计算识别精度

      println(s"Test set accuracy = $evaluator.evaluate(predictionAndLabels)")
      

    11. 在result上创建临时视图

      result.toDF.createOrReplaceTempView("deep_learning")
      

    12. 使用Spark SQL的方式计算识别精度

      spark.sql("select (select count(*) from deep_learning where label=prediction)/count(*) as accuracy from deep_learning").show()
      

以上是关于使用mllib完成mnist手写识别任务的主要内容,如果未能解决你的问题,请参考以下文章

只用一个神经元可以完成MNIST手写体识别吗?

Mnist手写数字识别 Tensorflow

深度学习笔记:基于Keras库的MNIST手写数字识别

基于 Mindspore 框架与 ModelArts 平台的 MNIST 手写体识别实验

pytorch实现MNIST手写体识别(全连接神经网络)

Tensorflow编程基础之Mnist手写识别实验+关于cross_entropy的理解