K近邻法

Posted Tao-Coder

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了K近邻法相关的知识,希望对你有一定的参考价值。

K近邻算法

  给定一个训练数据集,对新的输入实例,在训练数据集中找到跟它最近的k个实例,根据这k个实例的类判断它自己的类(一般采用多数表决的方法)。

  算法详解:

  输入:训练数据集

            

  其中,为实例的特征向量,为实例的类别, 实例特征向量 ;

  输出:实例所属的类 ;

  (1)根据给定的距离度量,在训练集中找出与最邻近的个点,涵盖这个点的的邻域记作 ;

  (2)在中根据分类决策规则(如多数表决)决定的类别

             

  注:为指示函数,即当为1,否则为0

 

K近邻模型

  模型三要素:距离度量方法k值的选择分类决策规则

  当三要素确定的时候,对任何实例(训练或输入),它所属的类都是确定的,相当于将特征空间分为一些子空间。

                                                            

距离度量

对n维实数向量空间Rn,经常用Lp距离或曼哈顿Minkowski距离。

Lp距离定义如下:

              

  当p=2时,称为欧氏距离:

                

  当p=1时,称为曼哈顿距离:  

   

                 

   当p=∞,它是各个坐标距离的最大值,即:

                

  用图表示如下:

                     

  

k值的选择

k较小,容易被噪声影响,发生过拟合。

k较大,较远的训练实例也会对预测起作用,容易发生错误。

分类决策规则

使用0-1损失函数衡量,那么误分类率是:

              

 

    是近邻集合,要使左边最小,右边的必须最大,所以多数表决=经验最小化 。

 

k近邻法的实现:kd树

算法核心在于怎么快速搜索k个近邻出来,朴素做法是线性扫描,不可取,这里介绍的方法是kd树。

构造kd树

对数据集T中的子集S初始化S=T,取当前节点node=root取维数的序数i=0,对S递归执行:找出S的第i维的中位数对应的点,通过该点,且垂直于第i维坐标轴做一个超平面。该点加入node的子节点。该超平面将空间分为两个部分,对这两个部分分别重复此操作(S=S\',++i,node=current),直到不可再分。

 

例子:给定一个二维空间的数据集

      

     构造一个平衡kd树 。

 

Python代码实现:

  

# -*- coding:utf-8 -*-

import copy
import itertools
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib import animation

T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]


def draw_point(data):
    X, Y = [], []
    for p in data:
        X.append(p[0])
        Y.append(p[1])
    plt.plot(X, Y, \'bo\')


def draw_line(xy_list):
    for xy in xy_list:
        x, y = xy
        plt.plot(x, y, \'g\', lw=2)


def draw_square(square_list):
    currentAxis = plt.gca()
    colors = itertools.cycle(["r", "b", "g", "c", "m", "y", \'#EB70AA\', \'#0099FF\'])
    for square in square_list:
        currentAxis.add_patch(
            Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
                      color=next(colors)))


def median(lst):
    m = len(lst) / 2
    return lst[m], m


history_quare = []


def build_kdtree(data, d, square):
    history_quare.append(square)
    data = sorted(data, key=lambda x: x[d])
    p, m = median(data)

    del data[m]
    print data, p

    if m >= 0:
        sub_square = copy.deepcopy(square)
        if d == 0:
            sub_square[1][0] = p[0]
        else:
            sub_square[1][1] = p[1]
        history_quare.append(sub_square)
        if m > 0: build_kdtree(data[:m], not d, sub_square)
    if len(data) > 1:
        sub_square = copy.deepcopy(square)
        if d == 0:
            sub_square[0][0] = p[0]
        else:
            sub_square[0][1] = p[1]
        build_kdtree(data[m:], not d, sub_square)


build_kdtree(T, 0, [[0, 0], [10, 10]])
print history_quare

# draw an animation to show how it works, the data comes from history
# first set up the figure, the axis, and the plot element we want to animate
fig = plt.figure()
ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
line, = ax.plot([], [], \'g\', lw=2)
label = ax.text([], [], \'\')


# initialization function: plot the background of each frame
def init():
    plt.axis([0, 10, 0, 10])
    plt.grid(True)
    plt.xlabel(\'x_1\')
    plt.ylabel(\'x_2\')
    plt.title(\'build kd tree (www.hankcs.com)\')
    draw_point(T)


currentAxis = plt.gca()
colors = itertools.cycle(["#FF6633", "g", "#3366FF", "c", "m", "y", \'#EB70AA\', \'#0099FF\', \'#66FFFF\'])


# animation function.  this is called sequentially
def animate(i):
    square = history_quare[i]
    currentAxis.add_patch(
        Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
                  color=next(colors)))
    return


# call the animator.  blit=true means only re-draw the parts that have changed.
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(history_quare), interval=1000, repeat=False,
                               blit=False)
plt.show()
anim.save(\'kdtree_build.gif\', fps=2, writer=\'imagemagick\')
View Code

 

 

 

搜索kd树

上面的代码其实并没有搜索kd树,现在来实现搜索。

搜索跟二叉树一样来,是一个递归的过程。先找到目标点的插入位置,然后往上走,逐步用自己到目标点的距离画个超球体,用超球体圈住的点来更新最近邻(或k最近邻)。以最近邻为例,实现如下(本实现由于测试数据简单,没有做超球体与超立体相交的逻辑):

 

    # -*- coding:utf-8 -*-
     
    T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
     
     
    class node:
        def __init__(self, point):
            self.left = None
            self.right = None
            self.point = point
            self.parent = None
            pass
     
        def set_left(self, left):
            if left == None: pass
            left.parent = self
            self.left = left
     
        def set_right(self, right):
            if right == None: pass
            right.parent = self
            self.right = right
     
     
    def median(lst):
        m = len(lst) / 2
        return lst[m], m
     
     
    def build_kdtree(data, d):
        data = sorted(data, key=lambda x: x[d])
        p, m = median(data)
        tree = node(p)
     
        del data[m]
     
        if m > 0: tree.set_left(build_kdtree(data[:m], not d))
        if len(data) > 1: tree.set_right(build_kdtree(data[m:], not d))
        return tree
     
     
    def distance(a, b):
        print a, b
        return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5
     
     
    def search_kdtree(tree, d, target):
        if target[d] < tree.point[d]:
            if tree.left != None:
                return search_kdtree(tree.left, not d, target)
        else:
            if tree.right != None:
                return search_kdtree(tree.right, not d, target)
     
        def update_best(t, best):
            if t == None: return
            t = t.point
            d = distance(t, target)
            if d < best[1]:
                best[1] = d
                best[0] = t
     
     
        best = [tree.point, 100000.0]
        while (tree.parent != None):
            update_best(tree.parent.left, best)
            update_best(tree.parent.right, best)
            tree = tree.parent
        return best[0]
     
     
    kd_tree = build_kdtree(T, 0)
    print search_kdtree(kd_tree, 0, [9, 4])
View Code

 

 输出:

[8, 1] [9, 4]
[5, 4] [9, 4]
[9, 6] [9, 4]
[9, 6]

可见对于点[9, 4],在n=6的数据集中,kdtree算法一共只进行了3次计算。

 

K-近邻算法实现手写数字识别     

  1 #-*- coding: utf-8 -*-
  2 
  3 from numpy import *
  4 from os import listdir
  5 import operator
  6 
  7 # 读取数据到矩阵
  8 def img2vector(filename):
  9     # 创建向量
 10     returnVect = zeros((1,1024))
 11     
 12     # 打开数据文件,读取每行内容
 13     fr = open(filename)
 14 
 15     for i  in range(32):
 16         # 读取每一行
 17         lineStr = fr.readline()
 18         
 19         # 将每行前32字符转成int存入向量
 20         for j in range(32):
 21             returnVect[0,32*i+j] = int(lineStr[j])
 22 
 23     return returnVect
 24 
 25 # kNN算法实现    
 26 def classify0(inX, dataSet, labels, k):
 27     # 获取样本数据数量
 28     dataSetSize = dataSet.shape[0]
 29     
 30     # 矩阵运算,计算测试数据与每个样本数据对应数据项的差值
 31     diffMat = tile(inX, (dataSetSize,1)) - dataSet
 32     
 33     # sqDistances 上一步骤结果平方和
 34     sqDiffMat = diffMat**2
 35     sqDistances = sqDiffMat.sum(axis=1)
 36     
 37     # 取平方根,得到距离向量
 38     distances = sqDistances**0.5
 39     
 40     # 按照距离从低到高排序
 41     sortedDistIndicies = distances.argsort()     
 42     classCount={}          
 43     
 44     # 依次取出最近的样本数据
 45     for i in range(k):
 46         # 记录该样本数据所属的类别
 47         voteIlabel = labels[sortedDistIndicies[i]]
 48         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
 49     
 50     # 对类别出现的频次进行排序,从高到低
 51     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
 52     
 53     # 返回出现频次最高的类别
 54     return sortedClassCount[0][0]
 55 
 56 # 算法测试    
 57 def handwritingClassTest():
 58     # 样本数据的类标签列表
 59     hwLabels = []
 60     
 61     # 样本数据文件列表
 62     trainingFileList = listdir(\'digits/trainingDigits\')
 63     m = len(trainingFileList)
 64     
 65     # 初始化样本数据矩阵(M*1024)
 66     trainingMat = zeros((m,1024))
 67     
 68     # 依次读取所有样本数据到数据矩阵
 69     for i in range(m):
 70         # 提取文件名中的数字
 71         fileNameStr = trainingFileList[i]
 72         fileStr = fileNameStr.split(\'.\')[0]
 73         classNumStr = int(fileStr.split(\'_\')[0])
 74         hwLabels.append(classNumStr)
 75         
 76         # 将样本数据存入矩阵
 77         trainingMat[i,:] = img2vector(\'digits/trainingDigits/%s\' % fileNameStr)
 78     
 79     # 循环读取测试数据
 80     testFileList = listdir(\'digits/testDigits\')
 81     
 82     # 初始化错误率
 83     errorCount = 0.0
 84     mTest = len(testFileList)
 85     
 86     # 循环测试每个测试数据文件
 87     for i in range(mTest):
 88         # 提取文件名中的数字
 89         fileNameStr = testFileList[i]
 90         fileStr = fileNameStr.split(\'.\')[0]
 91         classNumStr = int(fileStr.split(\'_\')[0])
 92         
 93         # 提取数据向量
 94         vectorUnderTest = img2vector(\'digits/testDigits/%s\' % fileNameStr)
 95         
 96         # 对数据文件进行分类
 97         classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
 98         
 99         # 打印KNN算法分类结果和真实的分类
100         print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
101         
102         # 判断KNN算法结果是否准确
103         if (classifierResult != classNumStr): errorCount += 1.0
104     
105     # 打印错误率
106     print "\\nthe total number of errors is: %d" % errorCount
107     print "\\nthe total error rate is: %f" % (errorCount/float(mTest))
108 
109 # 执行算法测试
110 handwritingClassTest()
View Code

 

以上是关于K近邻法的主要内容,如果未能解决你的问题,请参考以下文章

,k 近邻法

3.K近邻法

k近邻法

机器学习入门之K近邻法

机器学习入门之K近邻法

K近邻法