tensorflow-实现knn算法-识别mnist数据集

Posted wyply115

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow-实现knn算法-识别mnist数据集相关的知识,希望对你有一定的参考价值。

概述

Mnist数据集被分成两部分:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)。每一张图片包含28像素X28像素的灰度图片。
我们要做的是对测试数据集的每一个数据从训练数据集中找出最临近的类别,进行预测。
那么如何找最临近的(距离),我们通过计算L1距离或L2距离来计算最临近。

实现步骤

  • 1.获取mnist数据
  • 2.计算距离(L1或L2)
  • 3.获取最小距离得索引(计算准确度用)
  • 4.开启会话,初始化变量op
  • 5.循环对每一个测试数据分别查找和训练数据集的最小距离,得出索引。
  • 6.计算准确度(预测值和真实值相等的数据/所有测试数据)

实现代码

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np


def knn_tensorflow():
    """ tensorflow实现knn算法,对mnist数据识别分类
    :return None
    """
    mnist = input_data.read_data_sets("./data/mnist/", one_hot=True)

    # 数据全部取出, 普通pc计算需要半小时左右,如果嫌太慢,可以少取一些数据。
    train_x, train_y = mnist.train.next_batch(60000)
    test_x, test_y = mnist.test.next_batch(10000)

    # 占位符
    train_x_p = tf.placeholder(tf.float32, [None, 784])
    test_x_p = tf.placeholder(tf.float32, [784])

    # L1距离计算:dist = sum(|X1-X2|)
    #dist_l1 = tf.reduce_sum(tf.abs(train_x_p + tf.negative(test_x_p)), reduction_indices=1)

    # L2距离计算:dist = sqrt(sum(|X1-X2|^2))
    dist_l2 = tf.sqrt(tf.reduce_sum(tf.square(tf.abs(train_x_p + tf.negative(test_x_p))), reduction_indices=1))

    # 获得最小距离的索引
    prediction = tf.arg_min(dist_l2, 0)

    # 定义准确率
    accuracy = 0.

    init_op = tf.initialize_all_variables()

    with tf.Session() as sess:
        sess.run(init_op)

        for i in range(len(test_x)):
            # 获取最近邻的值得索引
            nn_index = sess.run(prediction, feed_dict=train_x_p: train_x, test_x_p: test_x[i, :])
            print("测试集第 %d 条,实际值:%d,预测值:%d" % (i, np.argmax(test_y[i]), np.argmax(train_y[nn_index])))

            # 当预测值==真实值时,计算准确率。
            if np.argmax(test_y[i]) == np.argmax(train_y[nn_index]):
                accuracy += 1. / len(test_x)

        print("准确率:%f " % accuracy)

    return None

if __name__ == '__main__':
    knn_tensorflow()

L1距离输出:

L2距离输出:

以上是knn算法的tensorflow实现,但大家可以看出,并没有k值的调整,以上默认k值是1,那假设k值是3或者5的时候会不会更加准确,可以自己尝试下。

注:tf.reduce_sum()方法中有个reduction_indices参数,表示函数的处理纬度,默认为None,即会把input_tensor降到0维,即1个数值。如下图所示(图片来自网络):

以上是关于tensorflow-实现knn算法-识别mnist数据集的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow实现knn(k近邻)算法

机器学习KNN算法实现手写板字迹识别

机器学习 - TensorflowSharp 简单使用与KNN识别MNIST流程

TensorFlow TensorFlow图像识别(KNN)

KNN算法实现数字识别

基于OpenCV的KNN算法实现手写数字识别