kd tree学习 (最近邻域查询)
Posted Dirge
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了kd tree学习 (最近邻域查询)相关的知识,希望对你有一定的参考价值。
https://zhuanlan.zhihu.com/p/22557068
http://blog.csdn.net/zhjchengfeng5/article/details/7855241
KD树在算法竞赛中主要用来做各种各样的平面区域查询,包含则累加直接返回,相交则继续递归,相离的没有任何贡献也直接返回。可以处理圆,三角形,矩形等判断起来相对容易的平面区域内的符合加法性质的操作。
比如查询平面内欧几里得距离最近的点的距离。
kdtree其实有点像搜索,暴力+剪枝。
每次从根结点向下搜索,并进行剪枝操作,判断是否有必要继续搜索。
它是通过横一刀,竖一刀,横一刀再竖一刀将平面进行分割,建立二叉树。
建树的复杂度是O(nlogn), 每次用nth_element()在线性时间内取出中位数。 T(n) = 2T(n/2) + O(n) = O(nlogn)
查询复杂度呢? 据第二个链接的博客说最坏是O( sqrt(n) ) 的。并不会分析查询复杂度。
HDU2966 裸kdtree
题意:给平面图上N(1 ≤ N ≤100000)个点,对每个点,找到其他 欧几里德距离 离他最近的点,输出他们之间的距离。保证没有重点。
1 #include <bits/stdc++.h> 2 #define ll long long 3 using namespace std; 4 #define N 200010 5 const ll inf = 1e18; 6 int n,i,id[N],root,cmp_d; 7 int x, y; 8 struct node{int d[2],l,r,Max[2],Min[2],val,sum,f;}t[N]; 9 bool cmp(const node&a,const node&b){return a.d[cmp_d]<b.d[cmp_d];} 10 void umax(int&a,int b){if(a<b)a=b;} 11 void umin(int&a,int b){if(a>b)a=b;} 12 void up(int x){ 13 if(t[x].l){ 14 umax(t[x].Max[0],t[t[x].l].Max[0]); 15 umin(t[x].Min[0],t[t[x].l].Min[0]); 16 umax(t[x].Max[1],t[t[x].l].Max[1]); 17 umin(t[x].Min[1],t[t[x].l].Min[1]); 18 } 19 if(t[x].r){ 20 umax(t[x].Max[0],t[t[x].r].Max[0]); 21 umin(t[x].Min[0],t[t[x].r].Min[0]); 22 umax(t[x].Max[1],t[t[x].r].Max[1]); 23 umin(t[x].Min[1],t[t[x].r].Min[1]); 24 } 25 } 26 int build(int l,int r,int D,int f){ 27 int mid=(l+r)>>1; 28 cmp_d=D,std::nth_element(t+l,t+mid,t+r+1,cmp); 29 id[t[mid].f]=mid; 30 t[mid].f=f; 31 t[mid].Max[0]=t[mid].Min[0]=t[mid].d[0]; 32 t[mid].Max[1]=t[mid].Min[1]=t[mid].d[1]; 33 //t[mid].val=t[mid].sum=0; 34 if(l!=mid)t[mid].l=build(l,mid-1,!D,mid);else t[mid].l=0; 35 if(r!=mid)t[mid].r=build(mid+1,r,!D,mid);else t[mid].r=0; 36 return up(mid),mid; 37 } 38 39 ll dis(ll x1, ll y1, ll x, ll y) { 40 ll xx = x1-x, yy = y1-y; 41 return xx*xx+yy*yy; 42 } 43 ll dis(int p, ll x, ll y){//估价函数, 以p为子树的最小距离 44 ll xx = 0, yy = 0; 45 if(t[p].Max[0] < x) xx = x-t[p].Max[0]; 46 if(t[p].Min[0] > x) xx = t[p].Min[0]-x; 47 if(t[p].Max[1] < y) yy = y-t[p].Max[1]; 48 if(t[p].Min[1] > y) yy = t[p].Min[1]-y; 49 return xx*xx+yy*yy; 50 } 51 ll ans; 52 void query(int p){ 53 ll dl = inf, dr = inf, d = dis(t[p].d[0], t[p].d[1], x, y); 54 if(d) ans = min(ans, d); 55 56 if(t[p].l) dl = dis(t[p].l, x, y); 57 if(t[p].r) dr = dis(t[p].r, x, y); 58 if(dl < dr){ 59 if(dl < ans) query(t[p].l); 60 if(dr < ans) query(t[p].r); 61 } 62 else { 63 if(dr < ans) query(t[p].r); 64 if(dl < ans) query(t[p].l); 65 } 66 } 67 68 int main(){ 69 int T; scanf("%d", &T); 70 while(T--){ 71 scanf("%d", &n); 72 for(int i = 1; i <= n; i++){ 73 scanf("%d%d", &t[i].d[0], &t[i].d[1]); 74 t[i].f = i; 75 } 76 int rt = build(1, n, 0, 0); 77 for(int i = 1; i <= n; i++){ 78 ans = inf; 79 x = t[ id[i] ].d[0], y = t[ id[i] ].d[1]; 80 query(rt); 81 printf("%lld\\n", ans); 82 } 83 } 84 return 0; 85 }
BZOJ2648
题意:给出n个点,接下来m个操作,每次插入一个点,或者询问离询问点的最近曼哈顿距离。
1 #include <bits/stdc++.h> 2 #define ll long long 3 using namespace std; 4 #define N 1000010 5 const ll inf = 1e18; 6 int n,m,i,id[N],root,cmp_d,rt; 7 int x, y; 8 struct node{int d[2],l,r,Max[2],Min[2],val,sum,f;}t[N]; 9 bool cmp(const node&a,const node&b){return a.d[cmp_d]<b.d[cmp_d];} 10 void umax(int&a,int b){if(a<b)a=b;} 11 void umin(int&a,int b){if(a>b)a=b;} 12 void up(int x){ 13 if(t[x].l){ 14 umax(t[x].Max[0],t[t[x].l].Max[0]); 15 umin(t[x].Min[0],t[t[x].l].Min[0]); 16 umax(t[x].Max[1],t[t[x].l].Max[1]); 17 umin(t[x].Min[1],t[t[x].l].Min[1]); 18 } 19 if(t[x].r){ 20 umax(t[x].Max[0],t[t[x].r].Max[0]); 21 umin(t[x].Min[0],t[t[x].r].Min[0]); 22 umax(t[x].Max[1],t[t[x].r].Max[1]); 23 umin(t[x].Min[1],t[t[x].r].Min[1]); 24 } 25 } 26 int build(int l,int r,int D,int f){ 27 int mid=(l+r)>>1; 28 cmp_d=D,std::nth_element(t+l,t+mid,t+r+1,cmp); 29 id[t[mid].f]=mid; 30 t[mid].f=f; 31 t[mid].Max[0]=t[mid].Min[0]=t[mid].d[0]; 32 t[mid].Max[1]=t[mid].Min[1]=t[mid].d[1]; 33 //t[mid].val=t[mid].sum=0; 34 if(l!=mid)t[mid].l=build(l,mid-1,!D,mid);else t[mid].l=0; 35 if(r!=mid)t[mid].r=build(mid+1,r,!D,mid);else t[mid].r=0; 36 return up(mid),mid; 37 } 38 39 ll dis(ll x1, ll y1, ll x, ll y) { 40 return abs(x1-x)+abs(y1-y); 41 //ll xx = x1-x, yy = y1-y; 42 //return xx*xx+yy*yy; 43 } 44 ll dis(int p, ll x, ll y){//估价函数, 以p为子树的最小距离 45 ll xx = 0, yy = 0; 46 if(t[p].Max[0] < x) xx = x-t[p].Max[0]; 47 if(t[p].Min[0] > x) xx = t[p].Min[0]-x; 48 if(t[p].Max[1] < y) yy = y-t[p].Max[1]; 49 if(t[p].Min[1] > y) yy = t[p].Min[1]-y; 50 return xx+yy; 51 //return xx*xx+yy*yy; 52 } 53 ll ans; 54 void ins(int now, int k, int x){ 55 if(t[x].d[k] >= t[now].d[k]){ 56 if(t[now].r) ins(t[now].r, !k, x); 57 else 58 t[now].r = x, t[x].f = now; 59 } 60 else { 61 if(t[now].l) ins(t[now].l, !k, x); 62 else t[now].l = x, t[x].f = now; 63 } 64 up(now); 65 } 66 void query(int p){ 67 ll dl = inf, dr = inf, d = dis(t[p].d[0], t[p].d[1], x, y); 68 ans = min(ans, d); 69 70 if(t[p].l) dl = dis(t[p].l, x, y); 71 if(t[p].r) dr = dis(t[p].r, x, y); 72 if(dl < dr){ 73 if(dl < ans) query(t[p].l); 74 if(dr < ans) query(t[p].r); 75 } 76 else { 77 if(dr < ans) query(t[p].r); 78 if(dl < ans) query(t[p].l); 79 } 80 } 81 82 int main(){ 83 scanf("%d%d", &n, &m); 84 for(int i = 1; i <= n; i++) 85 scanf("%d%d", &t[i].d[0], &t[i].d[1]); 86 rt = build(1, n, 0, 0); 87 while(m--){ 88 int op; 89 scanf("%d%d%d", &op, &x, &y); 90 if(op == 1){ 91 n++; 92 t[n].l = t[n].r = 0; 93 t[n].Max[0] = t[n].Min[0] = t[n].d[0] = x; 94 t[n].Max[1] = t[n].Min[1] = t[n].d[1] = y; 95 ins(rt, 0, n); 96 } 97 else{ 98 ans = inf; 99 query(rt); 100 printf("%lld\\n", ans); 101 } 102 } 103 return 0; 104 }
BZOJ3053
题意:k维坐标系下的最近的m个点。直接对于每一个询问都在kdtree中询问m次最近点,每次找到一个最近点对需要把它记录下来,用堆维护即可。
1 #include <bits/stdc++.h> 2 #define ll long long 3 #define mp make_pair 4 5 using namespace std; 6 #define N 50010 7 const ll inf = 1e18; 8 int n,m,k,i,id[N],root,cmp_d,rt; 9 int x, y, num; 10 struct node{int d[5],l,r,Max[5],Min[5],val,sum,f;}t[N]; 11 bool cmp(const node&a,const node&b){return a.d[cmp_d]<b.d[cmp_d];} 12 void umax(int&a,int b){if(a<b)a=b;} 13 void umin(int&a,int b){if(a>b)a=b;} 14 void up(int x){ 15 for(int i = 0; i < k; i++){ 16 if(t[x].l){ 17 umax(t[x].Max[i],t[t[x].l].Max[i]); 18 umin(t[x].Min[i],t[t[x].l].Min[i]); 19 } 20 if(t[x].r){ 21 umax(t[x].Max[i],t[t[x].r].Max[i]); 22 umin(t[x].Min[i],t[t[x].r].Min[i]); 23 } 24 } 25 } 26 int build(int l,int r,int D,int f){ 27 int mid=(l+r)>>1; 28 cmp_d=D,std::nth_element(t+l,t+mid,t+r+1,cmp); 29 id[t[mid].f]=mid; 30 t[mid].f=f; 31 for(int i = 0; i < k; i++) 32 t[mid].Max[i]=t[mid].Min[i]=t[mid].d[i]; 33 //t[mid].Max[1]=t[mid].Min[1]=t[mid].d[1]; 34 //t[mid].val=t[mid].sum=0; 35 if(l!=mid)t[mid].l=build(l,mid-1,(D+1)%k,mid);else t[mid].l=0; 36 if(r!=mid)t[mid].r=build(mid+1,r,(D+1)%k,mid);else t[mid].r=0; 37 return up(mid),mid; 38 } 39 int qx[5]; 40 ll dis(int p){//估价函数, 以p为子树的最小距离 41 ll ret = 0, ans = 0; 42 for(int i = 0; i < k; i++) { 43 ret = 0; 44 if(t[p].Max[i] < qx[i]) ret = qx[i]-t[p].Max[i]; 45 if(t[p].Min[i] > qx[i]) ret = t[p].Min[i]-qx[i]; 46 ans += ret*ret; 47 } 48 return ans; 49 } 50 ll getdis(int p){ 51 ll ans = 0; 52 for(int i = 0; i < k; i++) 53 ans += (qx[i]-t[p].d[i])*(qx[i]-t[p].d[i]); 54 return ans; 55 } 56 void ins(int now, int k, int x){ 57 if(t[x].d[k] >= t[now].d[k]){ 58 if(t[now].r) ins(t[now].r, !k, x); 59 else 60 t[now].r = x, t[x].f = now; 61 } 62 else { 63 if(t[now].l) ins(t[now].l, !k, x); 64 else t[now].l = x, t[x].f = now; 65 } 66 up(now); 67 } 68 ll ret; 69 multiset< pair<int, int> > ans; 70 void query(int p){ 71 ll dl = inf, dr = inf, d = getdis(p); 72 ans.insert( mp((int)d, p) ); 73 if(ans.size() > num){ 74 multiset< pair<int, int> >::iterator it = ans.end(); 75 it--; 76 ans.erase(it); 77 } 78 ret = (*ans.rbegin()).first; 79 if(t[p].l) dl = dis(t[p].l); 80 if(t[p].r) dr = dis(t[p].r); 81 if(dl < dr){ 82 if(dl < ret||ans.size() < num) query(t[p].l); 83 if(dr < ret||ans.size() < num) query(t[p].r); 84 } 85 else { 86 if(dr < ret||ans.size() < num) query(t[p].r); 87 if(dl < ret||ans.size() < num) query(t[p].l); 88 } 89 } 90 91 int main(){ 92 while(~scanf("%d%d", &n, &k)){ 93 for(int i = 1; i <= n; i++){ 94 for(int j = 0; j < k; j++) 95 scanf("%d", &t[i].d[j]); 96 } 97 rt = build(1, n, 0, 0); 98 scanf("%d", &m); 99 while(m--){ 100 for(int i = 0; i < k; i++) 101 scanf("%d", qx+i); 102 scanf("%d", &num); 103 ans.clear(); 104 query(rt); 105 printf ("the closest %d points are:\\n", num); 106 for(multiset< pair<int, int> >::iterator it = ans.begin(); it != ans.end(); it++){ 107 int pos = (*it).second; 108 for(int i = 0; i < k; i++) 109 printf("%d%c", t[pos].d[i], " \\n"[i == k-1]); 110 } 111 } 112 } 113 return 0; 114 }
以上是关于kd tree学习 (最近邻域查询)的主要内容,如果未能解决你的问题,请参考以下文章