机器学习 - KNN识别MNIST

Posted Chobits的文集

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了机器学习 - KNN识别MNIST相关的知识,希望对你有一定的参考价值。

代码

 https://github.com/s055523/MNISTTensorFlowSharp

数据的获得

数据可以由http://yann.lecun.com/exdb/mnist/下载。之后,储存在trainDir中,下次就不需要下载了。

技术分享图片
/// <summary>
        /// 如果文件不存在就去下载
        /// </summary>
        /// <param name="urlBase">下载地址</param>
        /// <param name="trainDir">文件目录地址</param>
        /// <param name="file">文件名</param>
        /// <returns></returns>
        public static Stream MaybeDownload(string urlBase, string trainDir, string file)
        {
            if (!Directory.Exists(trainDir))
            {
                Directory.CreateDirectory(trainDir);
            }

            var target = Path.Combine(trainDir, file);
            if (!File.Exists(target))
            {
                var wc = new WebClient();
                wc.DownloadFile(urlBase + file, target);
            }
            return File.OpenRead(target);
        }
View Code

数据格式处理

下载下来的文件共有四个,都是扩展名为gz的压缩包。

train-images-idx3-ubyte.gz  55000张训练图片和5000张验证图片

train-labels-idx1-ubyte.gz     训练图片对应的数字标签(即答案)

t10k-images-idx3-ubyte.gz   10000张测试图片

t10k-labels-idx1-ubyte.gz     测试图片对应的数字标签(即答案)

处理图片数据压缩包

每个压缩包的格式为:

偏移量

类型

意义

0

Int32

2051或2049

一个定死的魔术数。用来验证该压缩包是训练集(2051)或测试集(2049)

4

Int32

60000或10000

压缩包的图片数

8

Int32

28

每个图片的行数

12

Int32

28

每个图片的列数

16

Unsigned byte

0 - 255

第一张图片的第一个像素

17

Unsigned byte

0 - 255

第一张图片的第二个像素

 

因此,我们可以使用一个统一的方式将数据处理。我们只需要那些图片像素。

技术分享图片
/// <summary>
        /// 从数据流中读取下一个int32
        /// </summary>
        /// <param name="s"></param>
        /// <returns></returns>
        int Read32(Stream s)
        {
            var x = new byte[4];
            s.Read(x, 0, 4);
            return DataConverter.BigEndian.GetInt32(x, 0);
        }

        /// <summary>
        /// 处理图片数据
        /// </summary>
        /// <param name="input"></param>
        /// <param name="file"></param>
        /// <returns></returns>
        MnistImage[] ExtractImages(Stream input, string file)
        {
            //文件是gz格式的
            using (var gz = new GZipStream(input, CompressionMode.Decompress))
            {
                //不是2051说明下载的文件不对
                if (Read32(gz) != 2051)
                {
                    throw new Exception("不是2051说明下载的文件不对: " + file);
                }
                //图片数
                var count = Read32(gz);
                //行数
                var rows = Read32(gz);
                //列数
                var cols = Read32(gz);

                Console.WriteLine($"准备读取{count}张图片。");

                var result = new MnistImage[count];
                for (int i = 0; i < count; i++)
                {
                    //图片的大小(每个像素占一个bit)
                    var size = rows * cols;
                    var data = new byte[size];

                    //从数据流中读取这么大的一块内容
                    gz.Read(data, 0, size);

                    //将读取到的内容转换为MnistImage类型
                    result[i] = new MnistImage(cols, rows, data);
                }
                return result;
            }
        }
View Code

准备一个MnistImage类型:

技术分享图片
/// <summary>
    /// 图片类型
    /// </summary>
    public struct MnistImage
    {
        public int Cols, Rows;
        public byte[] Data;
        public float[] DataFloat;

        public MnistImage(int cols, int rows, byte[] data)
        {
            Cols = cols;
            Rows = rows;
            Data = data;
            DataFloat = new float[data.Length];
            for (int i = 0; i < data.Length; i++)
            {
                //数据归一化(这里将0-255除255变成了0-1之间的小数)
                //也可以归一为-0.5到0.5之间
                DataFloat[i] = Data[i] / 255f;
            }
        }
    }
View Code

这样一来,图片数据就处理完成了。

处理数字标签数据压缩包

数字标签数据压缩包和图片数据压缩包的格式类似。

偏移量

类型

意义

0

Int32

2051或2049

一个定死的魔术数。用来验证该压缩包是训练集(2051)或测试集(2049)

4

Int32

60000或10000

压缩包的数字标签数

5

Unsigned byte

0 - 9

第一张图片对应的数字

6

Unsigned byte

0 - 9

第二张图片对应的数字

 

它的处理更加简单。

技术分享图片
/// <summary>
        /// 处理标签数据
        /// </summary>
        /// <param name="input"></param>
        /// <param name="file"></param>
        /// <returns></returns>
        byte[] ExtractLabels(Stream input, string file)
        {
            using (var gz = new GZipStream(input, CompressionMode.Decompress))
            {
                //不是2049说明下载的文件不对
                if (Read32(gz) != 2049)
                {
                    throw new Exception("不是2049说明下载的文件不对:" + file);
                }
                var count = Read32(gz);
                var labels = new byte[count];

                gz.Read(labels, 0, count);

                return labels;
            }
        }
View Code

将数字标签转化为二维数组:one-hot编码

由于我们的数字为0-9,所以,可以视为有十个class。此时,为了后续的处理方便,我们将数字标签转化为数组。因此,一组标签就转换为了一个二维数组。

例如,标签0变成[1,0,0,0,0,0,0,0,0,0]

标签1变成[0,1,0,0,0,0,0,0,0,0]

以此类推。

技术分享图片
/// <summary>
        /// 将数字标签一维数组转为一个二维数组
        /// </summary>
        /// <param name="labels"></param>
        /// <param name="numClasses">多少个类别,这里是10(0到9)</param>
        /// <returns></returns>
        byte[,] OneHot(byte[] labels, int numClasses)
        {
            var oneHot = new byte[labels.Length, numClasses];
            for (int i = 0; i < labels.Length; i++)
            {
                oneHot[i, labels[i]] = 1;
            }
            return oneHot;
        }
View Code

到此为止,数据格式处理就全部结束了。下面的代码展示了数据处理的全过程。

技术分享图片
        /// <summary>
        /// 处理数据集
        /// </summary>
        /// <param name="trainDir">数据集所在文件夹</param>
        /// <param name="numClasses"></param>
        /// <param name="validationSize">拿出多少做验证?</param>
        public void ReadDataSets(string trainDir, int numClasses = 10, int validationSize = 5000)
        {
            const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
            const string TrainImagesName = "train-images-idx3-ubyte.gz";
            const string TrainLabelsName = "train-labels-idx1-ubyte.gz";
            const string TestImagesName = "t10k-images-idx3-ubyte.gz";
            const string TestLabelsName = "t10k-labels-idx1-ubyte.gz";

            //获得训练数据,然后处理训练数据和测试数据
            TrainImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TrainImagesName), TrainImagesName);
            TestImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TestImagesName), TestImagesName);
            TrainLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TrainLabelsName), TrainLabelsName);
            TestLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TestLabelsName), TestLabelsName);

            //拿出前面的一部分做验证
            ValidationImages = Pick(TrainImages, 0, validationSize);
            ValidationLabels = Pick(TrainLabels, 0, validationSize);

            //拿出剩下的做训练(输入0意味着拿剩下所有的)
            TrainImages = Pick(TrainImages, validationSize, 0);
            TrainLabels = Pick(TrainLabels, validationSize, 0);

            //将数字标签转换为二维数组
            //例如,标签3 =》 [0,0,0,1,0,0,0,0,0,0]
            //标签0 =》 [1,0,0,0,0,0,0,0,0,0]
            if (numClasses != -1)
            {
                OneHotTrainLabels = OneHot(TrainLabels, numClasses);
                OneHotValidationLabels = OneHot(ValidationLabels, numClasses);
                OneHotTestLabels = OneHot(TestLabels, numClasses);
            }
        }

        /// <summary>
        /// 获得source集合中的一部分,从first开始,到last结束
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="source"></param>
        /// <param name="first"></param>
        /// <param name="last"></param>
        /// <returns></returns>
        T[] Pick<T>(T[] source, int first, int last)
        {
            if (last == 0)
            {
                last = source.Length;
            }

            var count = last - first;
            var ret = source.Skip(first).Take(count).ToArray();
            return ret;
        }

        public static Mnist Load()
        {
            var x = new Mnist();
            x.ReadDataSets(@"D:\人工智能\C#代码\MNISTTensorFlowSharp\MNISTTensorFlowSharp\data");
            return x;
        }
View Code

在这里,数据共有下面几部分:

  1. 训练图片数据55000 TrainImages及对应标签TrainLabels
  2. 验证图片数据5000 ValidationImages及对应标签ValidationLabels
  3. 测试图片数据10000 TestImages及对应标签TestLabels

KNN算法的实现

现在,我们已经有了所有的数据在手。需要实现的是:

  1. 拿出数据中的一部分(例如,5000张图片)作为KNN的训练数据,然后,再从数据中的另一部分拿一张图片A
  2. 对这张图片A,求它和5000张训练图片的距离,并找出一张训练图片B,它是所有训练图片中,和A距离最小的那张(这意味着K=1)
  3. 此时,就认为A所代表的数字等同于B所代表的数字b
  4. 重复1-3,N次

首先进行数据的收集:

技术分享图片
//三个Reader分别从总的数据库中获得数据
        public BatchReader GetTrainReader() => new BatchReader(TrainImages, TrainLabels, OneHotTrainLabels);
        public BatchReader GetTestReader() => new BatchReader(TestImages, TestLabels, OneHotTestLabels);
        public BatchReader GetValidationReader() => new BatchReader(ValidationImages, ValidationLabels, OneHotValidationLabels);

        /// <summary>
        /// 数据的一部分,包括了所有的有用信息
        /// </summary>
        public class BatchReader
        {
            int start = 0;
            //图片库
            MnistImage[] source;
            //数字标签
            byte[] labels;
            //oneHot之后的数字标签
            byte[,] oneHotLabels;

            internal BatchReader(MnistImage[] source, byte[] labels, byte[,] oneHotLabels)
            {
                this.source = source;
                this.labels = labels;
                this.oneHotLabels = oneHotLabels;
            }

            /// <summary>
            /// 返回两个浮点二维数组(C# 7的新语法)
            /// </summary>
            /// <param name="batchSize"></param>
            /// <returns></returns>
            public (float[,], float[,]) NextBatch(int batchSize)
            {
                //一张图
                var imageData = new float[batchSize, 784];
                //标签
                var labelData = new float[batchSize, 10];

                int p = 0;
                for (int item = 0; item < batchSize; item++)
                {
                    Buffer.BlockCopy(source[start + item].DataFloat, 0, imageData, p, 784 * sizeof(float));
                    p += 784 * sizeof(float);
                    for (var j = 0; j < 10; j++)
                        labelData[item, j] = oneHotLabels[item + start, j];
                }

                start += batchSize;
                return (imageData, labelData);
            }
        }
View Code

然后,在算法中,获取数据:

技术分享图片
        static void KNN()
        {
            //取得数据
            var mnist = Mnist.Load();

            //拿5000个训练数据,200个测试数据
            const int trainCount = 5000;
            const int testCount = 200;

            //获得的数据有两个
            //一个是图片,它们都是28*28的
            //一个是one-hot的标签,它们都是1*10的
            (var trainingImages, var trainingLabels) = mnist.GetTrainReader().NextBatch(trainCount);
            (var testImages, var testLabels) = mnist.GetTestReader().NextBatch(testCount);

            Console.WriteLine($"MNIST 1NN");
View Code

下面进行计算。这里使用了K=1的L1距离。这是最简单的情况。

技术分享图片
            //建立一个图表示计算任务
            using (var graph = new TFGraph())
            {
                var session = new TFSession(graph);

                //用来feed数据的占位符。trainingInput表示N张用来进行训练的图片,N是一个变量,所以这里使用-1
                TFOutput trainingInput = graph.Placeholder(TFDataType.Float, new TFShape(-1, 784));

                //xte表示一张用来测试的图片
                TFOutput xte = graph.Placeholder(TFDataType.Float, new TFShape(784));

                //计算这两张图片的L1距离。这很简单,实际上就是把784个数字逐对相减,然后取绝对值,最后加起来变成一个总和
                var distance = graph.ReduceSum(graph.Abs(graph.Sub(trainingInput, xte)), axis: graph.Const(1));

                //这里只是用了最近的那个数据
                //也就是说,最近的那个数据是什么,那pred(预测值)就是什么
                TFOutput pred = graph.ArgMin(distance, graph.Const(0));
View Code

最后是开启Session计算的过程:

技术分享图片
                var accuracy = 0f;

                //开始循环进行计算,循环trainCount次
                for (int i = 0; i < testCount; i++)
                {
                    var runner = session.GetRunner();

                    //每次,对一张新的测试图,计算它和trainCount张训练图的距离,并获得最近的那张
                    var result = runner.Fetch(pred).Fetch(distance)
                        //trainCount张训练图(数据是trainingImages)
                        .AddInput(trainingInput, trainingImages)
                        //testCount张测试图(数据是从testImages中拿出来的)
                        .AddInput(xte, Extract(testImages, i))
                        .Run();
                    
                    //最近的点的序号
                    var nn_index = (int)(long)result[0].GetValue();

                    //从trainingLabels中找到答案(这是预测值)
                    var prediction = ArgMax(trainingLabels, nn_index);

                    //正确答案位于testLabels[i]中
                    var real = ArgMax(testLabels, i);

                    //PrintImage(testImages, i);

                    Console.WriteLine($"测试 {i}: " +
                        $"预测: {prediction} " +
                        $"正确答案: {real} (最近的点的序号={nn_index})");
                    //Console.WriteLine(testImages);

                    if (prediction == real)
                    {
                        accuracy += 1f / testCount;
                    }
                }
                Console.WriteLine("准确率: " + accuracy);
View Code

对KNN的改进

本文只是对KNN识别MNIST数据集进行了一个非常简单的介绍。在实现了最简单的K=1的L1距离计算之后,正确率约为91%。大家可以试着将算法进行改进,例如取K=2或者其他数,或者计算L2距离等。L2距离的结果比L1好一些,可以达到93-94%的正确率。

以上是关于机器学习 - KNN识别MNIST的主要内容,如果未能解决你的问题,请参考以下文章

利用knn svm cnn 逻辑回归 mlp rnn等方法实现mnist数据集分类(pytorch实现)

tensorflow-实现knn算法-识别mnist数据集

tensorflow-实现knn算法-识别mnist数据集

小刘的深度学习---CNN

机器学习KNN算法实现手写板字迹识别

机器学习-kNN手写数字识别