KD Tree Algorithm
Posted by ZisZ
Reference: http://blog.csdn.net/v_july_v/article/details/8203674
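The script below builds a 2-d KD tree by recursively splitting on the dimension with the largest variance, at the median point of that dimension, and then runs a k-nearest-neighbour search that keeps the current k best candidates in a bounded max-heap and backtracks through unexplored branches. As a small illustration of the split-dimension choice (not part of the original script; the points are the ones used in its __main__ section), the variance comparison at the root level looks like this:

# Illustrative only: reproduces the root-level split-dimension choice that
# build_kdtree() makes for the seven points used in the __main__ section below.
import numpy

points = [[7, 2], [5, 4], [9, 6], [2, 3], [4, 7], [8, 1], [2, 4.5]]
xs = [p[0] for p in points]
ys = [p[1] for p in points]

# The dimension with the larger variance is chosen as the split axis; here the
# x-coordinates spread more than the y-coordinates, so the root splits on dimension 0.
print numpy.var(xs), numpy.var(ys)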
#!/usr/bin/env python
# -*- coding:utf8 -*-
# NOTE: this script targets Python 2 (xrange, print statement, Queue module).

__author__ = '[email protected]'

import sys
import numpy
import heapq
import Queue


class KDNode(object):
    def __init__(self, name, feature):
        self.name = name
        self.ki = -1              # split dimension chosen for this node
        self.is_leaf = False
        self.feature = feature
        self.kd_left = None
        self.kd_right = None

    def traverse(self, seq, order='in'):
        if order == 'in':
            if self.kd_left:
                self.kd_left.traverse(seq, order)
            seq.append(self)
            if self.kd_right:
                self.kd_right.traverse(seq, order)
        elif order == 'pre':
            seq.append(self)
            if self.kd_left:
                self.kd_left.traverse(seq, order)
            if self.kd_right:
                self.kd_right.traverse(seq, order)
        elif order == 'post':
            if self.kd_left:
                self.kd_left.traverse(seq, order)
            if self.kd_right:
                self.kd_right.traverse(seq, order)
            seq.append(self)
        else:
            assert(False)


class NodeDistance(object):
    def __init__(self, kd_node, distance):
        self.kd_node = kd_node
        self.distance = distance

    # here i use a reversed result, because heapq can support only min heap
    def __cmp__(self, other):
        ret = other.distance - self.distance
        if ret > 0:
            return 1
        elif ret < 0:
            return -1
        else:
            return 0


def euclidean_distance(node1, node2):
    assert len(node1.feature) == len(node2.feature)
    sum = 0
    for i in xrange(len(node1.feature)):
        sum += numpy.square(node1.feature[i] - node2.feature[i])
    return numpy.sqrt(sum)


class KDTree(object):
    # n is num of dimension
    def __init__(self, nodes, n):
        self.root = self.build_kdtree(nodes, n)
        self.n = n

    def build_kdtree(self, nodes, n):
        if len(nodes) == 0:
            return None
        # choose the split dimension with the largest variance
        max_var = 0
        index = 0
        for i in xrange(n):
            features_n = map(lambda node: node.feature[i], nodes)
            var = numpy.var(features_n)
            if var > max_var:
                max_var = var
                index = i
        # split at the median point along that dimension
        sorted_nodes = sorted(nodes, key=lambda node: node.feature[index])
        mid = len(sorted_nodes) / 2
        root = sorted_nodes[mid]
        left_nodes = sorted_nodes[:mid]
        right_nodes = sorted_nodes[mid+1:]
        root.ki = index
        if len(left_nodes) == 0 and len(right_nodes) == 0:
            root.is_leaf = True
        root.kd_left = self.build_kdtree(left_nodes, n)
        root.kd_right = self.build_kdtree(right_nodes, n)
        return root

    def traverse_kdtree(self, order='in'):
        seq = []
        self.root.traverse(seq, order)
        print map(lambda n: n.name, seq)

    # return a list of NodeDistance sorted by distance
    def kdtree_bbf_knn(self, target, k):
        if len(target.feature) != self.n:
            return None
        knn = []   # max-heap of the k best candidates so far (knn[0] is the worst of them)
        # Queue.LifoQueue is a stack, so branches are revisited depth-first
        # (backtracking) rather than strictly best-bin-first.
        priority_queue = Queue.LifoQueue()
        priority_queue.put(self.root)
        while not priority_queue.empty():
            expl = priority_queue.get()
            while expl:
                ki = expl.ki
                kv = expl.feature[ki]
                if expl.name != target.name:
                    # ignore target node itself
                    # save a maybe result
                    distance = euclidean_distance(expl, target)
                    nd = NodeDistance(expl, distance)
                    assert len(knn) <= k
                    if len(knn) == k:
                        if distance < knn[0].distance:
                            heapq.heapreplace(knn, nd)
                    else:
                        # len(knn) < k
                        heapq.heappush(knn, nd)
                unexpl = None
                # find next expl
                if target.feature[ki] <= kv:
                    # left
                    unexpl = expl.kd_right
                    expl = expl.kd_left
                else:
                    unexpl = expl.kd_left
                    expl = expl.kd_right
                # ignore nodes over a long distance bin
                if unexpl:
                    # save a maybe next expl
                    if len(knn) < k:
                        priority_queue.put(unexpl)
                    elif (len(knn) == k) and (abs(kv - target.feature[ki]) < knn[0].distance):
                        priority_queue.put(unexpl)
        # pop from the max-heap and build the list in ascending order of distance
        ret = []
        for i in xrange(len(knn)):
            node = heapq.heappop(knn)
            ret.insert(0, node)
        return ret


if __name__ == '__main__':
    f1 = [7, 2]
    f2 = [5, 4]
    f3 = [9, 6]
    f4 = [2, 3]
    f5 = [4, 7]
    f6 = [8, 1]
    fx = [2, 4.5]
    n1 = KDNode('f1', f1)
    n2 = KDNode('f2', f2)
    n3 = KDNode('f3', f3)
    n4 = KDNode('f4', f4)
    n5 = KDNode('f5', f5)
    n6 = KDNode('f6', f6)
    nx = KDNode('fx', fx)
    # sanity checks for the reversed NodeDistance ordering (smaller distance compares greater)
    n1_distance = NodeDistance(n4, 1.5)
    n2_distance = NodeDistance(n5, 3.2)
    n3_distance = NodeDistance(n2, 3.04)
    assert n1_distance > n2_distance
    assert n1_distance > n3_distance
    assert n2_distance < n3_distance
    tree = KDTree([n1, n2, n3, n4, n5, n6, nx], 2)
    tree.traverse_kdtree('in')
    knn = tree.kdtree_bbf_knn(nx, 3)
    print map(lambda n: (n.kd_node.name, n.distance), knn)
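As a quick sanity check, not part of the referenced post, the tree-based result can be compared against a plain linear scan built on the same euclidean_distance helper. The sketch below assumes it is appended to the end of the __main__ section above, so that euclidean_distance and the nodes n1..n6 and nx are already in scope; the two printed lists should name the same points.

    # Brute-force cross-check (illustrative sketch, appended after the code above):
    # sort all candidate nodes by distance to the query and keep the closest three,
    # then compare the names against what kdtree_bbf_knn(nx, 3) printed.
    candidates = [n1, n2, n3, n4, n5, n6]
    brute = sorted(candidates, key=lambda node: euclidean_distance(node, nx))[:3]
    print map(lambda node: (node.name, euclidean_distance(node, nx)), brute)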