kdTree
Posted 小丫头い
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了kdTree相关的知识,希望对你有一定的参考价值。
什么是K近邻法
- K近邻(k-nearest neighbor,k-NN)算法简单、直观:给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最近邻的k个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类。因此,k近邻法不具有显示的学习过程。K近邻法实际上利用训练数据集对特征向量空间进行划分,并作为其分类的“模型”。K值的选择、距离度量及分类决策规则是k近邻法的三个基本要素。
kd树
- 今天好好看了下K近邻法,这是一个很简单很直观的算法,但是实现起来,如果使用线性扫描,那是相当的耗时,不可取,因此在书中看到了kd树算法。书中的算法是这么描述的:
- 写了一早上终于写好了,我真不愧是慢慢~~~放代码咯
package meachineLearning;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Queue;
class KDTreeNode
double[] val;
KDTreeNode left;
KDTreeNode right;
KDTreeNode(double[] val)
this.val = val;
public class KDTree
//下面这段代码利用快速排序,分治的策略找到中位数,若当前base最后放置的位置i==median则直接返回,若i>median,那么中位数一定在start到i-1之间
public double[] getMedian(double[][] data,int cutPoint, int start,int end,int median)
double[] base = data[start];
int i = start;
int j = end;
if(i<j)
while(i<j)
while(i<j&&data[j][cutPoint]>=base[cutPoint])
j--;
if(i<j)
data[i] = data[j];
i++;
while(i<j&&data[i][cutPoint]<=base[cutPoint])
i++;
if(i<j)
data[j] = data[i];
j--;
data[i] = base;
if(i==median)
return data[i];
if(i<median)
return getMedian(data,cutPoint,i+1,end,median);
else
return getMedian(data,cutPoint,start,i-1,median);
return data[i];
//创建kdTree,返回树的根节点
public KDTreeNode createKDTree(double[][] data,int start,int end,int depth)
if(start>end)
return null;
int cutPoint = depth%2; //切分的维度,深度为depth节点以depth%k为切分点,此处k=2
double[] meadin = getMedian(data,cutPoint,start,end,(end+start+1)/2); //划分节点
KDTreeNode root = new KDTreeNode(meadin);
//原数据被排序后变为[[2, 3], [4, 7], [5, 4], [7, 2], [8, 1], [9, 6]];[7,2]是中位数,左边子数组进入左节点递归,右边子数组进入右节点递归
root.left = createKDTree(data,start,(end+start+1)/2-1,depth+1);
root.right = createKDTree(data,(end+start+1)/2+1,end,depth+1);
return root;
//使用kd树进行最近邻搜索,只用找到距离最近的那个节点
public KDTreeNode kdSearch(KDTreeNode root,KDTreeNode node,int depth)
if(root==null)
return null;
if(root.left==null&&root.right==null)
return root;
int cutPoint = depth%2;
KDTreeNode curNearest;
if(node.val[cutPoint]<root.val[cutPoint])
curNearest = kdSearch(root.left,node,depth+1);
else
curNearest = kdSearch(root.right,node,depth+1);
//下面进行两个判断是因为不知道当前叶子节点是在左子树还是右子树,所以干脆两个都判断一下,可能会进入到兄弟节点分支进行递归
if(root.left!=null&&distance(root.left.val,node.val)<distance(curNearest.val,node.val))
curNearest = kdSearch(root.left,node,depth+1);
else if(root.right!=null&&distance(root.right.val,node.val)<distance(curNearest.val,node.val))
curNearest = kdSearch(root.right,node,depth+1);
return curNearest;
public double distance(double[] val,double[] val2)
double sum = 0;
for(int i=0;i<val.length;i++)
sum += (val[i]-val2[i])*(val[i]-val2[i]);
return Math.sqrt(sum);
public static void main(String[] args)
double[][] data = 2,3,5,4,9,6,4,7,8,1,7,2;
KDTreeNode root = new KDTree().createKDTree(data,0,data.length-1,0);
//层序遍历所有的节点
Queue<KDTreeNode> q = new LinkedList<KDTreeNode>();
q.add(root);
while(!q.isEmpty())
KDTreeNode node = q.poll();
//System.out.println(Arrays.toString(node.val));
if(node.left!=null) q.offer(node.left);
if(node.right!=null) q.offer(node.right);
KDTreeNode node = new KDTreeNode(new double[]3,4.5);
KDTreeNode nearest = new KDTree().kdSearch(root,node,0);
System.out.println(Arrays.toString(nearest.val));
kd树的搜索策略是酱紫的:因为树就是分治的建的,所以搜索也是二分着找,区别在于找到当前分支的最近邻还要判断一下它的兄弟节点是否满足条件,如果兄弟节点比本节点更近,则进入到兄弟节点分支进行递归,否则回退到父节点。
实际应用
- 实际在使用它时我不会自己去写,sklearn包中调用如此简单。自己写一下加深理解,锻炼动手能力。
>>> X = [[0], [1], [2], [3]] # 数据集特征向量,一维特征
>>> y = [0, 0, 1, 1] # 数据集的label,我的代码里没有涉及这个
>>> from sklearn.neighbors import KNeighborsClassifier
>>> neigh = KNeighborsClassifier(n_neighbors=3)
>>> neigh.fit(X, y)
KNeighborsClassifier(...)
>>> print(neigh.predict([[1.1]]))
[0] # 预测该数据样本属于0类
>>> print(neigh.predict_proba([[0.9]]))
[[ 0.66666667 0.33333333]] #预测0.9属于0类的概率为0.66666667,属于1类的概率为0.33333333。
以上是关于kdTree的主要内容,如果未能解决你的问题,请参考以下文章