kd tree学习 (最近邻域查询)

Posted Dirge

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了kd tree学习 (最近邻域查询)相关的知识,希望对你有一定的参考价值。

https://zhuanlan.zhihu.com/p/22557068

http://blog.csdn.net/zhjchengfeng5/article/details/7855241

 

KD树在算法竞赛中主要用来做各种各样的平面区域查询,包含则累加直接返回,相交则继续递归,相离的没有任何贡献也直接返回。可以处理圆,三角形,矩形等判断起来相对容易的平面区域内的符合加法性质的操作。

比如查询平面内欧几里得距离最近的点的距离。

kdtree其实有点像搜索,暴力+剪枝。

每次从根结点向下搜索,并进行剪枝操作,判断是否有必要继续搜索。

它是通过横一刀,竖一刀,横一刀再竖一刀将平面进行分割,建立二叉树。

建树的复杂度是O(nlogn), 每次用nth_element()在线性时间内取出中位数。 T(n) = 2T(n/2) + O(n) = O(nlogn)

查询复杂度呢? 据第二个链接的博客说最坏是O( sqrt(n) ) 的。并不会分析查询复杂度。

 

HDU2966 裸kdtree

题意:给平面图上N(≤ 100000)个点,对每个点,找到其他 欧几里德距离 离他最近的点,输出他们之间的距离。保证没有重点。

 1 #include <bits/stdc++.h>
 2 #define ll long long
 3 using namespace std;
 4 #define N 200010
 5 const ll inf = 1e18;
 6 int n,i,id[N],root,cmp_d;
 7 int x, y;
 8 struct node{int d[2],l,r,Max[2],Min[2],val,sum,f;}t[N];
 9 bool cmp(const node&a,const node&b){return a.d[cmp_d]<b.d[cmp_d];}
10 void umax(int&a,int b){if(a<b)a=b;}
11 void umin(int&a,int b){if(a>b)a=b;}
12 void up(int x){
13     if(t[x].l){
14         umax(t[x].Max[0],t[t[x].l].Max[0]);
15         umin(t[x].Min[0],t[t[x].l].Min[0]);
16         umax(t[x].Max[1],t[t[x].l].Max[1]);
17         umin(t[x].Min[1],t[t[x].l].Min[1]);
18     }
19     if(t[x].r){
20         umax(t[x].Max[0],t[t[x].r].Max[0]);
21         umin(t[x].Min[0],t[t[x].r].Min[0]);
22         umax(t[x].Max[1],t[t[x].r].Max[1]);
23         umin(t[x].Min[1],t[t[x].r].Min[1]);
24     }
25 }
26 int build(int l,int r,int D,int f){
27     int mid=(l+r)>>1;
28     cmp_d=D,std::nth_element(t+l,t+mid,t+r+1,cmp);
29     id[t[mid].f]=mid;
30     t[mid].f=f;
31     t[mid].Max[0]=t[mid].Min[0]=t[mid].d[0];
32     t[mid].Max[1]=t[mid].Min[1]=t[mid].d[1];
33     //t[mid].val=t[mid].sum=0;
34     if(l!=mid)t[mid].l=build(l,mid-1,!D,mid);else t[mid].l=0;
35     if(r!=mid)t[mid].r=build(mid+1,r,!D,mid);else t[mid].r=0;
36     return up(mid),mid;
37 }
38 
39 ll dis(ll x1, ll y1, ll x, ll y) {
40     ll xx = x1-x, yy = y1-y;
41     return xx*xx+yy*yy;
42 }
43 ll dis(int p, ll x, ll y){//估价函数, 以p为子树的最小距离
44     ll xx = 0, yy = 0;
45     if(t[p].Max[0] < x) xx = x-t[p].Max[0];
46     if(t[p].Min[0] > x) xx = t[p].Min[0]-x;
47     if(t[p].Max[1] < y) yy = y-t[p].Max[1];
48     if(t[p].Min[1] > y) yy = t[p].Min[1]-y;
49     return xx*xx+yy*yy;
50 }
51 ll ans;
52 void query(int p){
53     ll dl = inf, dr = inf, d = dis(t[p].d[0], t[p].d[1], x, y);
54     if(d) ans = min(ans, d);
55 
56     if(t[p].l) dl = dis(t[p].l, x, y);
57     if(t[p].r) dr = dis(t[p].r, x, y);
58     if(dl < dr){
59         if(dl < ans) query(t[p].l);
60         if(dr < ans) query(t[p].r);
61     }
62     else {
63         if(dr < ans) query(t[p].r);
64         if(dl < ans) query(t[p].l);
65     }
66 }
67 
68 int main(){
69     int T; scanf("%d", &T);
70     while(T--){
71         scanf("%d", &n);
72         for(int i = 1; i <= n; i++){
73             scanf("%d%d", &t[i].d[0], &t[i].d[1]);
74             t[i].f = i;
75         }
76         int rt = build(1, n, 0, 0);
77         for(int i = 1; i <= n; i++){
78             ans = inf;
79             x = t[ id[i] ].d[0], y = t[ id[i] ].d[1];
80             query(rt);
81             printf("%lld\\n", ans);
82         }
83     }
84     return 0;
85 }
View Code

 

BZOJ2648

题意:给出n个点,接下来m个操作,每次插入一个点,或者询问离询问点的最近曼哈顿距离。

  1 #include <bits/stdc++.h>
  2 #define ll long long
  3 using namespace std;
  4 #define N 1000010
  5 const ll inf = 1e18;
  6 int n,m,i,id[N],root,cmp_d,rt;
  7 int x, y;
  8 struct node{int d[2],l,r,Max[2],Min[2],val,sum,f;}t[N];
  9 bool cmp(const node&a,const node&b){return a.d[cmp_d]<b.d[cmp_d];}
 10 void umax(int&a,int b){if(a<b)a=b;}
 11 void umin(int&a,int b){if(a>b)a=b;}
 12 void up(int x){
 13     if(t[x].l){
 14         umax(t[x].Max[0],t[t[x].l].Max[0]);
 15         umin(t[x].Min[0],t[t[x].l].Min[0]);
 16         umax(t[x].Max[1],t[t[x].l].Max[1]);
 17         umin(t[x].Min[1],t[t[x].l].Min[1]);
 18     }
 19     if(t[x].r){
 20         umax(t[x].Max[0],t[t[x].r].Max[0]);
 21         umin(t[x].Min[0],t[t[x].r].Min[0]);
 22         umax(t[x].Max[1],t[t[x].r].Max[1]);
 23         umin(t[x].Min[1],t[t[x].r].Min[1]);
 24     }
 25 }
 26 int build(int l,int r,int D,int f){
 27     int mid=(l+r)>>1;
 28     cmp_d=D,std::nth_element(t+l,t+mid,t+r+1,cmp);
 29     id[t[mid].f]=mid;
 30     t[mid].f=f;
 31     t[mid].Max[0]=t[mid].Min[0]=t[mid].d[0];
 32     t[mid].Max[1]=t[mid].Min[1]=t[mid].d[1];
 33     //t[mid].val=t[mid].sum=0;
 34     if(l!=mid)t[mid].l=build(l,mid-1,!D,mid);else t[mid].l=0;
 35     if(r!=mid)t[mid].r=build(mid+1,r,!D,mid);else t[mid].r=0;
 36     return up(mid),mid;
 37 }
 38 
 39 ll dis(ll x1, ll y1, ll x, ll y) {
 40     return abs(x1-x)+abs(y1-y);
 41     //ll xx = x1-x, yy = y1-y;
 42     //return xx*xx+yy*yy;
 43 }
 44 ll dis(int p, ll x, ll y){//估价函数, 以p为子树的最小距离
 45     ll xx = 0, yy = 0;
 46     if(t[p].Max[0] < x) xx = x-t[p].Max[0];
 47     if(t[p].Min[0] > x) xx = t[p].Min[0]-x;
 48     if(t[p].Max[1] < y) yy = y-t[p].Max[1];
 49     if(t[p].Min[1] > y) yy = t[p].Min[1]-y;
 50     return xx+yy;
 51     //return xx*xx+yy*yy;
 52 }
 53 ll ans;
 54 void ins(int now, int k, int x){
 55     if(t[x].d[k] >= t[now].d[k]){
 56         if(t[now].r) ins(t[now].r, !k, x);
 57         else
 58             t[now].r = x, t[x].f = now;
 59     }
 60     else {
 61         if(t[now].l) ins(t[now].l, !k, x);
 62         else t[now].l = x, t[x].f = now;
 63     }
 64     up(now);
 65 }
 66 void query(int p){
 67     ll dl = inf, dr = inf, d = dis(t[p].d[0], t[p].d[1], x, y);
 68     ans = min(ans, d);
 69 
 70     if(t[p].l) dl = dis(t[p].l, x, y);
 71     if(t[p].r) dr = dis(t[p].r, x, y);
 72     if(dl < dr){
 73         if(dl < ans) query(t[p].l);
 74         if(dr < ans) query(t[p].r);
 75     }
 76     else {
 77         if(dr < ans) query(t[p].r);
 78         if(dl < ans) query(t[p].l);
 79     }
 80 }
 81 
 82 int main(){
 83     scanf("%d%d", &n, &m);
 84     for(int i = 1; i <= n; i++)
 85         scanf("%d%d", &t[i].d[0], &t[i].d[1]);
 86     rt = build(1, n, 0, 0);
 87     while(m--){
 88         int op;
 89         scanf("%d%d%d", &op, &x, &y);
 90         if(op == 1){
 91             n++;
 92             t[n].l = t[n].r = 0;
 93             t[n].Max[0] = t[n].Min[0] = t[n].d[0] = x;
 94             t[n].Max[1] = t[n].Min[1] = t[n].d[1] = y;
 95             ins(rt, 0, n);
 96         }
 97         else{
 98             ans = inf;
 99             query(rt);
100             printf("%lld\\n", ans);
101         }
102     }
103     return 0;
104 }
View Code

 

BZOJ3053

题意:k维坐标系下的最近的m个点。直接对于每一个询问都在kdtree中询问m次最近点,每次找到一个最近点对需要把它记录下来,用堆维护即可。

  1 #include <bits/stdc++.h>
  2 #define ll long long
  3 #define mp make_pair
  4 
  5 using namespace std;
  6 #define N 50010
  7 const ll inf = 1e18;
  8 int n,m,k,i,id[N],root,cmp_d,rt;
  9 int x, y, num;
 10 struct node{int d[5],l,r,Max[5],Min[5],val,sum,f;}t[N];
 11 bool cmp(const node&a,const node&b){return a.d[cmp_d]<b.d[cmp_d];}
 12 void umax(int&a,int b){if(a<b)a=b;}
 13 void umin(int&a,int b){if(a>b)a=b;}
 14 void up(int x){
 15     for(int i = 0; i < k; i++){
 16         if(t[x].l){
 17             umax(t[x].Max[i],t[t[x].l].Max[i]);
 18             umin(t[x].Min[i],t[t[x].l].Min[i]);
 19         }
 20         if(t[x].r){
 21             umax(t[x].Max[i],t[t[x].r].Max[i]);
 22             umin(t[x].Min[i],t[t[x].r].Min[i]);
 23         }
 24     }
 25 }
 26 int build(int l,int r,int D,int f){
 27     int mid=(l+r)>>1;
 28     cmp_d=D,std::nth_element(t+l,t+mid,t+r+1,cmp);
 29     id[t[mid].f]=mid;
 30     t[mid].f=f;
 31     for(int i = 0; i < k; i++)
 32     t[mid].Max[i]=t[mid].Min[i]=t[mid].d[i];
 33     //t[mid].Max[1]=t[mid].Min[1]=t[mid].d[1];
 34     //t[mid].val=t[mid].sum=0;
 35     if(l!=mid)t[mid].l=build(l,mid-1,(D+1)%k,mid);else t[mid].l=0;
 36     if(r!=mid)t[mid].r=build(mid+1,r,(D+1)%k,mid);else t[mid].r=0;
 37     return up(mid),mid;
 38 }
 39 int qx[5];
 40 ll dis(int p){//估价函数, 以p为子树的最小距离
 41     ll ret = 0, ans = 0;
 42     for(int i = 0; i < k; i++) {
 43         ret = 0;
 44         if(t[p].Max[i] < qx[i]) ret = qx[i]-t[p].Max[i];
 45         if(t[p].Min[i] > qx[i]) ret = t[p].Min[i]-qx[i];
 46         ans += ret*ret;
 47     }
 48     return ans;
 49 }
 50 ll getdis(int p){
 51     ll ans = 0;
 52     for(int i = 0; i < k; i++)
 53         ans += (qx[i]-t[p].d[i])*(qx[i]-t[p].d[i]);
 54     return ans;
 55 }
 56 void ins(int now, int k, int x){
 57     if(t[x].d[k] >= t[now].d[k]){
 58         if(t[now].r) ins(t[now].r, !k, x);
 59         else
 60             t[now].r = x, t[x].f = now;
 61     }
 62     else {
 63         if(t[now].l) ins(t[now].l, !k, x);
 64         else t[now].l = x, t[x].f = now;
 65     }
 66     up(now);
 67 }
 68 ll ret;
 69 multiset< pair<int, int> > ans;
 70 void query(int p){
 71     ll dl = inf, dr = inf, d = getdis(p);
 72     ans.insert( mp((int)d, p) );
 73     if(ans.size() > num){
 74         multiset< pair<int, int> >::iterator it = ans.end();
 75         it--;
 76         ans.erase(it);
 77     }
 78     ret = (*ans.rbegin()).first;
 79     if(t[p].l) dl = dis(t[p].l);
 80     if(t[p].r) dr = dis(t[p].r);
 81     if(dl < dr){
 82         if(dl < ret||ans.size() < num) query(t[p].l);
 83         if(dr < ret||ans.size() < num) query(t[p].r);
 84     }
 85     else {
 86         if(dr < ret||ans.size() < num) query(t[p].r);
 87         if(dl < ret||ans.size() < num) query(t[p].l);
 88     }
 89 }
 90 
 91 int main(){
 92     while(~scanf("%d%d", &n, &k)){
 93         for(int i = 1; i <= n; i++){
 94             for(int j = 0; j < k; j++)
 95                 scanf("%d", &t[i].d[j]);
 96         }
 97         rt = build(1, n, 0, 0);
 98         scanf("%d", &m);
 99         while(m--){
100             for(int i = 0; i < k; i++)
101                 scanf("%d", qx+i);
102             scanf("%d", &num);
103             ans.clear();
104             query(rt);
105             printf ("the closest %d points are:\\n", num);
106             for(multiset< pair<int, int> >::iterator it = ans.begin(); it != ans.end(); it++){
107                 int pos = (*it).second;
108                 for(int i = 0; i < k; i++)
109                     printf("%d%c", t[pos].d[i], " \\n"[i == k-1]);
110             }
111         }
112     }
113     return 0;
114 }
View Code

 

以上是关于kd tree学习 (最近邻域查询)的主要内容,如果未能解决你的问题,请参考以下文章

Kd-tree原理与实现

机器学习k近邻算法kd树实现优化查询

KD-tree讲解

SDOI2010 捉迷藏 —— KD-Tree

[bzoj4066/2683]简单题_KD-Tree

算法导论第四版学习——习题五Kd-Tree