统计学习三:2.K近邻法代码实现(以最近邻法为例)
Posted zhiyuxuan
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了统计学习三:2.K近邻法代码实现(以最近邻法为例)相关的知识,希望对你有一定的参考价值。
通过上文可知感知机模型的基本原理,以及算法的具体流程。本文实现了感知机模型算法的原始形式,通过对算法的具体实现,我们可以对算法有进一步的了解。具体代码可以在我的github上查看。
代码
#!/usr/bin/python3
# -*- coding:utf-8 -*-
import sys
import numpy as np
class Kdtree(object):
‘‘‘
类名: Kdtree
用于存储kd树的数据
成员:
__value: 训练数据,保存数据点的坐标
__type: 保存点对应的类型
__dim: 保存当前kd树节点的切分平面维度
left: 左子树
right: 右子树
‘‘‘
def __init__(self, node = None, node_type = -1, dim = 0, left = None, right = None):
self.__value = node
self.__type = node_type
self.__dim = dim
self.left = left
self.right = right
@property
def type(self):
return self.__type
@property
def value(self):
return self.__value
@property
def dim(self):
return self.__dim
def distance(self, node):
‘‘‘
计算当前节点与传入节点之间的距离
参数:
node: 需要计算距离的节点
‘‘‘
if node == None:
return sys.maxsize
dis = 0
for i in range(len(self.__value)):
dis = dis + (self.__value[i] - node.__value[i]) ** 2
return dis
def build_tree(self, nodes, dim = 0):
‘‘‘
利用训练数据建立一棵kd树
参数: nodes: 训练数据集
dim: 树的切分平面维度
return: a kd-tree
‘‘‘
if len(nodes) == 0:
return None
elif len(nodes) == 1:
self.__dim = dim
self.__value = nodes[0][:-1]
self.__type = nodes[0][-1]
return self
#将数据集按照第dim维度的值的大小进行排序
sortNodes = sorted(nodes, key = lambda x:x[dim], reverse = False)
#排序后,中间的点为当前节点值
midNode = sortNodes[len(sortNodes) // 2]
self.__value = midNode[:-1]
self.__type = midNode[-1]
self.__dim = dim
leftNodes = list(filter(lambda x: x[dim] < midNode[dim], sortNodes[:len(sortNodes) // 2]))
rightNodes = list(filter(lambda x: x[dim] >= midNode[dim], sortNodes[len(sortNodes) // 2 + 1:]))
nextDim = (dim + 1) % (len(midNode) - 1)
self.left = Kdtree().build_tree(leftNodes, nextDim)
self.right = Kdtree().build_tree(rightNodes, nextDim)
return self
def find_type(self, fnode):
‘‘‘
在kd树内查找传入点的最近邻点和对应的类型
参数: fnode: 需要判断类型的点
return: fnode的最近邻点和其类型
‘‘‘
if fnode == None:
return self, -1
fNode = Kdtree(fnode)
#首先搜索整棵树到达叶子节点
path = []
currentNode = self
while currentNode != None:
path.append(currentNode)
dim = currentNode.__dim
if fNode.value[dim] < currentNode.value[dim]:
currentNode = currentNode.left
else:
currentNode = currentNode.right
#path的最后一个节点即为叶子节点
nearestNode = path[-1]
nearestDist = fNode.distance(nearestNode)
path = path[:-1]
#向上进行回溯
while path != None and len(path) > 0:
currentNode = path[-1]
path = path[:-1]
dim = currentNode.__dim
#判断当前点是否比最近点更近
if fNode.distance(currentNode) < nearestDist:
nearestNode = currentNode
nearestDist = fNode.distance(currentNode)
#当前最近点一定存在于当前点的一棵子树上,那么找到它的兄弟子树的节点
brotherNode = currentNode.left
if fNode.value[dim] < currentNode.value[dim]:
brotherNode = currentNode.right
if brotherNode == None:
continue
#若兄弟子树的节点对应的区域与以fnode为圆心,以nearestDist为半径的圆相交,则进入兄弟子树,进行递归查找
bdim = brotherNode.__dim
if np.abs(fnode[bdim] - brotherNode.__value[bdim]) < nearestDist:
cNode, _ = brotherNode.find_type(fnode)
if fNode.distance(cNode) < nearestDist:
nearestDist = fNode.distance(cNode)
nearestNode = cNode
return nearestNode, nearestNode.type
if __name__ == "__main__":
#训练数据集
trainArray = [[1.0, 1.0, ‘a‘], [1.1, 1.1, ‘a‘], [1.5, 1.5, ‘a‘], [5.0, 5.0, ‘b‘], [5.2, 5.2, ‘b‘], [5.5, 5.5, ‘b‘], [3.0, 2.5, ‘c‘], [3.1, 2.8, ‘c‘], [3.2, 2.4, ‘c‘]]
kdtree = Kdtree().build_tree(trainArray)
#test1
testNode = [1.6, 1.5]
_, testType = kdtree.find_type(testNode)
print("the type of ", testNode, "is ", testType)
#test2
testNode = [3.5, 2.7]
_, testType = kdtree.find_type(testNode)
print("the type of ", testNode, "is ", testType)
#test3
testNode = [4.3, 5.1]
_, testType = kdtree.find_type(testNode)
print("the type of ", testNode, "is ", testType)
测试结果
通过测试结果可知,kd树可以有效地对输入数据进行类型的识别。
讨论
虽然通过测试结果正确,但代码依然存在许多需要改进的地方,如kd树的选择,可以通过改进为红黑平衡树,来提高搜索速度。以及对于树的每层切分平面的维度选择,可以选择各维度中方差最大的维度,这样在此维度下的点分布更加分散,使后续的查找难度更小等等。
以上是关于统计学习三:2.K近邻法代码实现(以最近邻法为例)的主要内容,如果未能解决你的问题,请参考以下文章