统计学习三:2.K近邻法代码实现(以最近邻法为例)

Posted zhiyuxuan

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了统计学习三:2.K近邻法代码实现(以最近邻法为例)相关的知识,希望对你有一定的参考价值。

通过上文可知感知机模型的基本原理,以及算法的具体流程。本文实现了感知机模型算法的原始形式,通过对算法的具体实现,我们可以对算法有进一步的了解。具体代码可以在我的github上查看。

代码

#!/usr/bin/python3
# -*- coding:utf-8 -*-

import sys 
import numpy as np

class Kdtree(object):
    ‘‘‘ 
    类名: Kdtree
    用于存储kd树的数据
    成员:
    __value: 训练数据,保存数据点的坐标
     __type: 保存点对应的类型
      __dim: 保存当前kd树节点的切分平面维度
       left: 左子树
      right: 右子树
    ‘‘‘
    def __init__(self, node = None, node_type = -1, dim = 0, left = None, right = None):
        self.__value = node
        self.__type  = node_type
        self.__dim   = dim 
        self.left    = left
        self.right   = right

    @property
    def type(self):
        return self.__type
        
    @property
    def value(self):
        return self.__value

    @property
    def dim(self):
        return self.__dim

    def distance(self, node):
        ‘‘‘ 
        计算当前节点与传入节点之间的距离
        参数: 
        node: 需要计算距离的节点
        ‘‘‘
        if node == None:
            return sys.maxsize

        dis = 0 
        for i in range(len(self.__value)):
            dis = dis + (self.__value[i] - node.__value[i]) ** 2
        return dis
        
    def build_tree(self, nodes, dim = 0):
        ‘‘‘
        利用训练数据建立一棵kd树
        参数: nodes: 训练数据集
                dim: 树的切分平面维度
        return: a kd-tree
        ‘‘‘
        if len(nodes) == 0:
            return None
        elif len(nodes) == 1:
            self.__dim  = dim
            self.__value = nodes[0][:-1]
            self.__type  = nodes[0][-1]
            return self

        #将数据集按照第dim维度的值的大小进行排序
        sortNodes = sorted(nodes, key = lambda x:x[dim], reverse = False)

        #排序后,中间的点为当前节点值
        midNode      = sortNodes[len(sortNodes) // 2]
        self.__value = midNode[:-1]
        self.__type  = midNode[-1]
        self.__dim   = dim

        leftNodes  = list(filter(lambda x: x[dim] < midNode[dim], sortNodes[:len(sortNodes) // 2]))
        rightNodes = list(filter(lambda x: x[dim] >= midNode[dim], sortNodes[len(sortNodes) // 2 + 1:]))
        nextDim    = (dim + 1) % (len(midNode) - 1)

        self.left  = Kdtree().build_tree(leftNodes, nextDim)
        self.right = Kdtree().build_tree(rightNodes, nextDim)

        return self

    def find_type(self, fnode):
        ‘‘‘
        在kd树内查找传入点的最近邻点和对应的类型
        参数: fnode: 需要判断类型的点
        return: fnode的最近邻点和其类型
        ‘‘‘
        if fnode == None:
            return self, -1

        fNode = Kdtree(fnode)

        #首先搜索整棵树到达叶子节点
        path = []
        currentNode = self
        while currentNode != None:
            path.append(currentNode)

            dim   = currentNode.__dim
            if fNode.value[dim] < currentNode.value[dim]:
                currentNode = currentNode.left
            else:
                currentNode = currentNode.right

        #path的最后一个节点即为叶子节点
        nearestNode = path[-1]
        nearestDist = fNode.distance(nearestNode)
        path = path[:-1]

        #向上进行回溯
        while path != None and len(path) > 0:
            currentNode = path[-1]
            path = path[:-1]
            dim  = currentNode.__dim
            
            #判断当前点是否比最近点更近
            if fNode.distance(currentNode) < nearestDist:
                nearestNode = currentNode
                nearestDist = fNode.distance(currentNode)

            #当前最近点一定存在于当前点的一棵子树上,那么找到它的兄弟子树的节点
            brotherNode = currentNode.left
            if fNode.value[dim] < currentNode.value[dim]:
                brotherNode = currentNode.right

            if brotherNode == None:
                continue

            #若兄弟子树的节点对应的区域与以fnode为圆心,以nearestDist为半径的圆相交,则进入兄弟子树,进行递归查找
            bdim = brotherNode.__dim
            if np.abs(fnode[bdim] - brotherNode.__value[bdim]) < nearestDist:
                cNode, _ = brotherNode.find_type(fnode)
                if fNode.distance(cNode) < nearestDist:
                    nearestDist = fNode.distance(cNode)
                    nearestNode = cNode

        return nearestNode, nearestNode.type

if __name__ == "__main__":

   #训练数据集
   trainArray = [[1.0, 1.0, ‘a‘], [1.1, 1.1, ‘a‘], [1.5, 1.5, ‘a‘],            [5.0, 5.0, ‘b‘], [5.2, 5.2, ‘b‘], [5.5, 5.5, ‘b‘],            [3.0, 2.5, ‘c‘], [3.1, 2.8, ‘c‘], [3.2, 2.4, ‘c‘]]

   kdtree = Kdtree().build_tree(trainArray)

   #test1
   testNode = [1.6, 1.5]
   _, testType = kdtree.find_type(testNode)
   print("the type of ", testNode, "is ", testType)

   #test2
   testNode = [3.5, 2.7]
   _, testType = kdtree.find_type(testNode)
   print("the type of ", testNode, "is ", testType)

   #test3
   testNode = [4.3, 5.1]
   _, testType = kdtree.find_type(testNode)
   print("the type of ", testNode, "is ", testType)

测试结果

技术分享图片

通过测试结果可知,kd树可以有效地对输入数据进行类型的识别。

讨论

虽然通过测试结果正确,但代码依然存在许多需要改进的地方,如kd树的选择,可以通过改进为红黑平衡树,来提高搜索速度。以及对于树的每层切分平面的维度选择,可以选择各维度中方差最大的维度,这样在此维度下的点分布更加分散,使后续的查找难度更小等等。

以上是关于统计学习三:2.K近邻法代码实现(以最近邻法为例)的主要内容,如果未能解决你的问题,请参考以下文章

统计学习方法与Python实现——k近邻法

k近邻(KNN)复习总结

机器学习笔记——K近邻法

K近邻法机器学习

机器学习笔记三 K近邻法

统计学习方法 (第3章)K近邻法 学习笔记