机器学习 - KNN识别MNIST
Posted Chobits的文集
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了机器学习 - KNN识别MNIST相关的知识,希望对你有一定的参考价值。
代码
https://github.com/s055523/MNISTTensorFlowSharp
数据的获得
数据可以由http://yann.lecun.com/exdb/mnist/下载。之后,储存在trainDir中,下次就不需要下载了。
/// <summary> /// 如果文件不存在就去下载 /// </summary> /// <param name="urlBase">下载地址</param> /// <param name="trainDir">文件目录地址</param> /// <param name="file">文件名</param> /// <returns></returns> public static Stream MaybeDownload(string urlBase, string trainDir, string file) { if (!Directory.Exists(trainDir)) { Directory.CreateDirectory(trainDir); } var target = Path.Combine(trainDir, file); if (!File.Exists(target)) { var wc = new WebClient(); wc.DownloadFile(urlBase + file, target); } return File.OpenRead(target); }
数据格式处理
下载下来的文件共有四个,都是扩展名为gz的压缩包。
train-images-idx3-ubyte.gz 55000张训练图片和5000张验证图片
train-labels-idx1-ubyte.gz 训练图片对应的数字标签(即答案)
t10k-images-idx3-ubyte.gz 10000张测试图片
t10k-labels-idx1-ubyte.gz 测试图片对应的数字标签(即答案)
处理图片数据压缩包
每个压缩包的格式为:
偏移量 |
类型 |
值 |
意义 |
0 |
Int32 |
2051或2049 |
一个定死的魔术数。用来验证该压缩包是训练集(2051)或测试集(2049) |
4 |
Int32 |
60000或10000 |
压缩包的图片数 |
8 |
Int32 |
28 |
每个图片的行数 |
12 |
Int32 |
28 |
每个图片的列数 |
16 |
Unsigned byte |
0 - 255 |
第一张图片的第一个像素 |
17 |
Unsigned byte |
0 - 255 |
第一张图片的第二个像素 |
… |
… |
… |
… |
因此,我们可以使用一个统一的方式将数据处理。我们只需要那些图片像素。
/// <summary> /// 从数据流中读取下一个int32 /// </summary> /// <param name="s"></param> /// <returns></returns> int Read32(Stream s) { var x = new byte[4]; s.Read(x, 0, 4); return DataConverter.BigEndian.GetInt32(x, 0); } /// <summary> /// 处理图片数据 /// </summary> /// <param name="input"></param> /// <param name="file"></param> /// <returns></returns> MnistImage[] ExtractImages(Stream input, string file) { //文件是gz格式的 using (var gz = new GZipStream(input, CompressionMode.Decompress)) { //不是2051说明下载的文件不对 if (Read32(gz) != 2051) { throw new Exception("不是2051说明下载的文件不对: " + file); } //图片数 var count = Read32(gz); //行数 var rows = Read32(gz); //列数 var cols = Read32(gz); Console.WriteLine($"准备读取{count}张图片。"); var result = new MnistImage[count]; for (int i = 0; i < count; i++) { //图片的大小(每个像素占一个bit) var size = rows * cols; var data = new byte[size]; //从数据流中读取这么大的一块内容 gz.Read(data, 0, size); //将读取到的内容转换为MnistImage类型 result[i] = new MnistImage(cols, rows, data); } return result; } }
准备一个MnistImage类型:
/// <summary> /// 图片类型 /// </summary> public struct MnistImage { public int Cols, Rows; public byte[] Data; public float[] DataFloat; public MnistImage(int cols, int rows, byte[] data) { Cols = cols; Rows = rows; Data = data; DataFloat = new float[data.Length]; for (int i = 0; i < data.Length; i++) { //数据归一化(这里将0-255除255变成了0-1之间的小数) //也可以归一为-0.5到0.5之间 DataFloat[i] = Data[i] / 255f; } } }
这样一来,图片数据就处理完成了。
处理数字标签数据压缩包
数字标签数据压缩包和图片数据压缩包的格式类似。
偏移量 |
类型 |
值 |
意义 |
0 |
Int32 |
2051或2049 |
一个定死的魔术数。用来验证该压缩包是训练集(2051)或测试集(2049) |
4 |
Int32 |
60000或10000 |
压缩包的数字标签数 |
5 |
Unsigned byte |
0 - 9 |
第一张图片对应的数字 |
6 |
Unsigned byte |
0 - 9 |
第二张图片对应的数字 |
… |
… |
… |
… |
它的处理更加简单。
/// <summary> /// 处理标签数据 /// </summary> /// <param name="input"></param> /// <param name="file"></param> /// <returns></returns> byte[] ExtractLabels(Stream input, string file) { using (var gz = new GZipStream(input, CompressionMode.Decompress)) { //不是2049说明下载的文件不对 if (Read32(gz) != 2049) { throw new Exception("不是2049说明下载的文件不对:" + file); } var count = Read32(gz); var labels = new byte[count]; gz.Read(labels, 0, count); return labels; } }
将数字标签转化为二维数组:one-hot编码
由于我们的数字为0-9,所以,可以视为有十个class。此时,为了后续的处理方便,我们将数字标签转化为数组。因此,一组标签就转换为了一个二维数组。
例如,标签0变成[1,0,0,0,0,0,0,0,0,0]
标签1变成[0,1,0,0,0,0,0,0,0,0]
以此类推。
/// <summary> /// 将数字标签一维数组转为一个二维数组 /// </summary> /// <param name="labels"></param> /// <param name="numClasses">多少个类别,这里是10(0到9)</param> /// <returns></returns> byte[,] OneHot(byte[] labels, int numClasses) { var oneHot = new byte[labels.Length, numClasses]; for (int i = 0; i < labels.Length; i++) { oneHot[i, labels[i]] = 1; } return oneHot; }
到此为止,数据格式处理就全部结束了。下面的代码展示了数据处理的全过程。
/// <summary> /// 处理数据集 /// </summary> /// <param name="trainDir">数据集所在文件夹</param> /// <param name="numClasses"></param> /// <param name="validationSize">拿出多少做验证?</param> public void ReadDataSets(string trainDir, int numClasses = 10, int validationSize = 5000) { const string SourceUrl = "http://yann.lecun.com/exdb/mnist/"; const string TrainImagesName = "train-images-idx3-ubyte.gz"; const string TrainLabelsName = "train-labels-idx1-ubyte.gz"; const string TestImagesName = "t10k-images-idx3-ubyte.gz"; const string TestLabelsName = "t10k-labels-idx1-ubyte.gz"; //获得训练数据,然后处理训练数据和测试数据 TrainImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TrainImagesName), TrainImagesName); TestImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TestImagesName), TestImagesName); TrainLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TrainLabelsName), TrainLabelsName); TestLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TestLabelsName), TestLabelsName); //拿出前面的一部分做验证 ValidationImages = Pick(TrainImages, 0, validationSize); ValidationLabels = Pick(TrainLabels, 0, validationSize); //拿出剩下的做训练(输入0意味着拿剩下所有的) TrainImages = Pick(TrainImages, validationSize, 0); TrainLabels = Pick(TrainLabels, validationSize, 0); //将数字标签转换为二维数组 //例如,标签3 =》 [0,0,0,1,0,0,0,0,0,0] //标签0 =》 [1,0,0,0,0,0,0,0,0,0] if (numClasses != -1) { OneHotTrainLabels = OneHot(TrainLabels, numClasses); OneHotValidationLabels = OneHot(ValidationLabels, numClasses); OneHotTestLabels = OneHot(TestLabels, numClasses); } } /// <summary> /// 获得source集合中的一部分,从first开始,到last结束 /// </summary> /// <typeparam name="T"></typeparam> /// <param name="source"></param> /// <param name="first"></param> /// <param name="last"></param> /// <returns></returns> T[] Pick<T>(T[] source, int first, int last) { if (last == 0) { last = source.Length; } var count = last - first; var ret = source.Skip(first).Take(count).ToArray(); return ret; } public static Mnist Load() { var x = new Mnist(); x.ReadDataSets(@"D:\人工智能\C#代码\MNISTTensorFlowSharp\MNISTTensorFlowSharp\data"); return x; }
在这里,数据共有下面几部分:
- 训练图片数据55000 TrainImages及对应标签TrainLabels
- 验证图片数据5000 ValidationImages及对应标签ValidationLabels
- 测试图片数据10000 TestImages及对应标签TestLabels
KNN算法的实现
现在,我们已经有了所有的数据在手。需要实现的是:
- 拿出数据中的一部分(例如,5000张图片)作为KNN的训练数据,然后,再从数据中的另一部分拿一张图片A
- 对这张图片A,求它和5000张训练图片的距离,并找出一张训练图片B,它是所有训练图片中,和A距离最小的那张(这意味着K=1)
- 此时,就认为A所代表的数字等同于B所代表的数字b
- 重复1-3,N次
首先进行数据的收集:
//三个Reader分别从总的数据库中获得数据 public BatchReader GetTrainReader() => new BatchReader(TrainImages, TrainLabels, OneHotTrainLabels); public BatchReader GetTestReader() => new BatchReader(TestImages, TestLabels, OneHotTestLabels); public BatchReader GetValidationReader() => new BatchReader(ValidationImages, ValidationLabels, OneHotValidationLabels); /// <summary> /// 数据的一部分,包括了所有的有用信息 /// </summary> public class BatchReader { int start = 0; //图片库 MnistImage[] source; //数字标签 byte[] labels; //oneHot之后的数字标签 byte[,] oneHotLabels; internal BatchReader(MnistImage[] source, byte[] labels, byte[,] oneHotLabels) { this.source = source; this.labels = labels; this.oneHotLabels = oneHotLabels; } /// <summary> /// 返回两个浮点二维数组(C# 7的新语法) /// </summary> /// <param name="batchSize"></param> /// <returns></returns> public (float[,], float[,]) NextBatch(int batchSize) { //一张图 var imageData = new float[batchSize, 784]; //标签 var labelData = new float[batchSize, 10]; int p = 0; for (int item = 0; item < batchSize; item++) { Buffer.BlockCopy(source[start + item].DataFloat, 0, imageData, p, 784 * sizeof(float)); p += 784 * sizeof(float); for (var j = 0; j < 10; j++) labelData[item, j] = oneHotLabels[item + start, j]; } start += batchSize; return (imageData, labelData); } }
然后,在算法中,获取数据:
static void KNN() { //取得数据 var mnist = Mnist.Load(); //拿5000个训练数据,200个测试数据 const int trainCount = 5000; const int testCount = 200; //获得的数据有两个 //一个是图片,它们都是28*28的 //一个是one-hot的标签,它们都是1*10的 (var trainingImages, var trainingLabels) = mnist.GetTrainReader().NextBatch(trainCount); (var testImages, var testLabels) = mnist.GetTestReader().NextBatch(testCount); Console.WriteLine($"MNIST 1NN");
下面进行计算。这里使用了K=1的L1距离。这是最简单的情况。
//建立一个图表示计算任务 using (var graph = new TFGraph()) { var session = new TFSession(graph); //用来feed数据的占位符。trainingInput表示N张用来进行训练的图片,N是一个变量,所以这里使用-1 TFOutput trainingInput = graph.Placeholder(TFDataType.Float, new TFShape(-1, 784)); //xte表示一张用来测试的图片 TFOutput xte = graph.Placeholder(TFDataType.Float, new TFShape(784)); //计算这两张图片的L1距离。这很简单,实际上就是把784个数字逐对相减,然后取绝对值,最后加起来变成一个总和 var distance = graph.ReduceSum(graph.Abs(graph.Sub(trainingInput, xte)), axis: graph.Const(1)); //这里只是用了最近的那个数据 //也就是说,最近的那个数据是什么,那pred(预测值)就是什么 TFOutput pred = graph.ArgMin(distance, graph.Const(0));
最后是开启Session计算的过程:
var accuracy = 0f; //开始循环进行计算,循环trainCount次 for (int i = 0; i < testCount; i++) { var runner = session.GetRunner(); //每次,对一张新的测试图,计算它和trainCount张训练图的距离,并获得最近的那张 var result = runner.Fetch(pred).Fetch(distance) //trainCount张训练图(数据是trainingImages) .AddInput(trainingInput, trainingImages) //testCount张测试图(数据是从testImages中拿出来的) .AddInput(xte, Extract(testImages, i)) .Run(); //最近的点的序号 var nn_index = (int)(long)result[0].GetValue(); //从trainingLabels中找到答案(这是预测值) var prediction = ArgMax(trainingLabels, nn_index); //正确答案位于testLabels[i]中 var real = ArgMax(testLabels, i); //PrintImage(testImages, i); Console.WriteLine($"测试 {i}: " + $"预测: {prediction} " + $"正确答案: {real} (最近的点的序号={nn_index})"); //Console.WriteLine(testImages); if (prediction == real) { accuracy += 1f / testCount; } } Console.WriteLine("准确率: " + accuracy);
对KNN的改进
本文只是对KNN识别MNIST数据集进行了一个非常简单的介绍。在实现了最简单的K=1的L1距离计算之后,正确率约为91%。大家可以试着将算法进行改进,例如取K=2或者其他数,或者计算L2距离等。L2距离的结果比L1好一些,可以达到93-94%的正确率。
以上是关于机器学习 - KNN识别MNIST的主要内容,如果未能解决你的问题,请参考以下文章