树套树初探
Posted xu-daxia
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了树套树初探相关的知识,希望对你有一定的参考价值。
最近学了学树套树,做了几道模板题。
发现好像有点水
咳咳咳。
树套树,顾名思义,一个树套一个树。比如树状数组套平衡树,就是把树状数组的每一个结点作为一颗平衡树,线段树套权值线段树,就是一颗线段树,每一个结点都是一颗权值线段树。。。
如果有一个问题是要求一个区间([l,r])中比(x)小的数有多少个带单点修改(n<=50000),可以用树状数组套平衡树解决,首先在每个点上建立平衡树,然后再拿树状数组维护起来,然后我们可以先求出区间([1,r])中有多少个比(x)小的数减去区间([1,l-1])中有多少个比(x)小的数,每一个([1,x])的区间我们用树状数组选出(log)颗平衡树,在每颗平衡树上找出比(x)小的最后加起来就行了。修改就在树状数组上对应的平衡树中删除插入就行了。这样每次询问的复杂度为(log^2n)。
下面看几道题:
P3380 【模板】二逼平衡树(树套树)
嗯,这题有很多做法。
首先可以用树状数组套平衡树做。
排名本质上就是有多少比(x)小的。
第k小二分答案然后查排名加一个(log)是(log^3n)的。
前趋后继就是求出树状数组每一个点的前驱后继然后取( ext{max})或( ext{min})就行。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<ctime>
#include<cstdlib>
using namespace std;
const int INF=1e8;
const int N=5e4+100;
int tot,rad[N*30],size[N*30],v[N*30],ch[N*30][2],root[N<<3],x,y,z,n,m,a[N];
struct tree{
int l,r;
}tr[N<<3];
int new_node(int x){
int now=++tot;
rad[now]=rand();size[now]=1;v[now]=x;
return now;
}
void update(int now){
size[now]=size[ch[now][0]]+size[ch[now][1]]+1;
}
int merge(int x,int y){
if(x==0||y==0)return x+y;
if(rad[x]>rad[y]){
ch[x][1]=merge(ch[x][1],y);
update(x);
return x;
}
else{
ch[y][0]=merge(x,ch[y][0]);
update(y);
return y;
}
}
void split(int &x,int &y,int now,int k){
if(now==0)x=y=0;
else{
if(v[now]<=k){
x=now;
split(ch[x][1],y,ch[x][1],k);
}
else {
y=now;
split(x,ch[y][0],ch[y][0],k);
}
update(now);
}
}
void ins(int now,int w){
split(x,y,root[now],w);
root[now]=merge(merge(x,new_node(w)),y);
}
void ins(int x,int w,int now){
ins(now,w);
if(tr[now].l==tr[now].r)return;
int mid=(tr[now].l+tr[now].r)>>1;
if(x>mid)ins(x,w,now*2+1);
else ins(x,w,now*2);
}
int rank(int now,int k){
split(x,y,root[now],k-1);
int w=size[x];
root[now]=merge(x,y);
return w;
}
int rank(int l,int r,int k,int now){
if(tr[now].l==l&&tr[now].r==r){
return rank(now,k);
}
int mid=(tr[now].l+tr[now].r)>>1;
if(l>mid)return rank(l,r,k,now*2+1);
else if(r<=mid)return rank(l,r,k,now*2);
else return rank(l,mid,k,now*2)+rank(mid+1,r,k,now*2+1);
}
void work(int l,int r,int k){
int L=0,R=INF,ans;
while(L<=R){
int mid=(L+R)>>1;
if(rank(l,r,mid,1)<k){
ans=mid;
L=mid+1;
}
else R=mid-1;
}
printf("%d
",ans);
}
void del(int now,int w){
split(x,z,root[now],w);
split(x,y,x,w-1);
y=merge(ch[y][0],ch[y][1]);
root[now]=merge(merge(x,y),z);
}
void del(int x,int w,int now){
del(now,w);
if(tr[now].l==tr[now].r)return;
int mid=(tr[now].l+tr[now].r)>>1;
if(x>mid)del(x,w,now*2+1);
else del(x,w,now*2);
}
int kth(int now,int k){
int l=ch[now][0];
if(size[l]>=k)return kth(l,k);
else if(size[l]+1==k)return v[now];
else return kth(ch[now][1],k-size[l]-1);
}
int pre(int now,int k){
int ans;
split(x,y,root[now],k-1);
if(size[x]==0)ans=-2147483647;
else ans=kth(x,size[x]);
root[now]=merge(x,y);
return ans;
}
int pre(int l,int r,int k,int now){
if(tr[now].l==l&&tr[now].r==r){
return pre(now,k);
}
int mid=(tr[now].l+tr[now].r)>>1;
if(l>mid)return pre(l,r,k,now*2+1);
else if(r<=mid)return pre(l,r,k,now*2);
else return max(pre(l,mid,k,now*2),pre(mid+1,r,k,now*2+1));
}
int suc(int now,int k){
int ans;
split(x,y,root[now],k);
if(size[y]==0)ans=2147483647;
else ans=kth(y,1);
root[now]=merge(x,y);
return ans;
}
int suc(int l,int r,int k,int now){
if(tr[now].l==l&&tr[now].r==r){
return suc(now,k);
}
int mid=(tr[now].l+tr[now].r)>>1;
if(l>mid)return suc(l,r,k,now*2+1);
else if(r<=mid)return suc(l,r,k,now*2);
else return min(suc(l,mid,k,now*2),suc(mid+1,r,k,now*2+1));
}
void build(int l,int r,int now){
tr[now].l=l;tr[now].r=r;
if(l==r)return;
int mid=(l+r)>>1;
build(l,mid,now*2);
build(mid+1,r,now*2+1);
}
int read(){
int sum=0,f=1;char ch=getchar();
while(ch<‘0‘||ch>‘9‘){if(ch==‘-‘)f=-1;ch=getchar();}
while(ch>=‘0‘&&ch<=‘9‘){sum=sum*10+ch-‘0‘;ch=getchar();}
return sum*f;
}
int main(){
srand(time(NULL));
n=read();m=read();
build(1,n,1);
for(int i=1;i<=n;++i)ins(i,a[i]=read(),1);
while(m--){
int type=read(),l=read(),r=read();
if(type==1){
int k=read();
printf("%d
",rank(l,r,k,1)+1);
}
else if(type==2){
int k=read();
work(l,r,k);
}
else if(type==3){
del(l,a[l],1);
ins(l,a[l]=r,1);
}
else if(type==4){
int k=read();
printf("%d
",pre(l,r,k,1));
}
else{
int k=read();
printf("%d
",suc(l,r,k,1));
}
}
return 0;
}
然后我还写了一个线段树套权值线段树。
这个对于求第k大可以换一种求法。
就是先找到区间对应的线段树的区间编号(其实就是对应的权值线段树的根的编号),拿一个数组存下来,然后在权值线段树上二分就行了。这样复杂度就变成了(log^2n)的了。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
const int N=5e4+100;
int sum[15000000],ch[15000000][2],root[N<<6];
int cnt,tot,n,m,num,c[N],b[N<<6],a[N],l[N],r[N],x[N],type[N];
struct tree{
int l,r;
}tr[N<<6];
void build(int l,int r,int now){
tr[now].l=l;tr[now].r=r;
if(l==r)return;
int mid=(l+r)>>1;
build(l,mid,now*2);
build(mid+1,r,now*2+1);
}
void add(int l,int r,int w,int k,int &now){
if(now==0){
now=++cnt;
ch[now][0]=ch[now][1]=0;
sum[now]=0;
}
sum[now]+=k;
if(l==r)return;
int mid=(l+r)>>1;
if(w>mid)add(mid+1,r,w,k,ch[now][1]);
else add(l,mid,w,k,ch[now][0]);
}
void add(int x,int w,int k,int now){
add(1,tot,w,k,root[now]);
if(tr[now].l==tr[now].r)return;
int mid=(tr[now].l+tr[now].r)>>1;
if(x>mid)add(x,w,k,now*2+1);
else add(x,w,k,now*2);
}
int check(int l,int r,int L,int R,int &now){
if(L>R)return 0;
if(now==0){
now=++cnt;
ch[now][0]=ch[now][1]=0;
sum[now]=0;
}
if(sum[now]==0)return 0;
if(l==L&&r==R)return sum[now];
int mid=(l+r)>>1;
if(L>mid)return check(mid+1,r,L,R,ch[now][1]);
else if(R<=mid)return check(l,mid,L,R,ch[now][0]);
else return check(l,mid,L,mid,ch[now][0])+check(mid+1,r,mid+1,R,ch[now][1]);
}
int rank(int l,int r,int k,int now){
if(tr[now].l==l&&tr[now].r==r)return check(1,tot,1,k-1,root[now]);
int mid=(tr[now].l+tr[now].r)>>1;
if(l>mid)return rank(l,r,k,now*2+1);
else if(r<=mid)return rank(l,r,k,now*2);
else return rank(l,mid,k,now*2)+rank(mid+1,r,k,now*2+1);
}
void find(int l,int r,int now){
if(tr[now].l==l&&tr[now].r==r){
c[++num]=root[now];
return;
}
int mid=(tr[now].l+tr[now].r)>>1;
if(l>mid)find(l,r,now*2+1);
else if(r<=mid)find(l,r,now*2);
else find(l,mid,now*2),find(mid+1,r,now*2+1);
}
int kth(int l,int r,int k){
num=0;
find(l,r,1);
int L=1,R=tot;
while(L!=R){
int tmp=0,mid=(L+R)>>1;
for(int i=1;i<=num;i++)tmp+=sum[ch[c[i]][0]];
if(tmp>=k){
R=mid;
for(int i=1;i<=num;i++)c[i]=ch[c[i]][0];
}
else {
L=mid+1;k-=tmp;
for(int i=1;i<=num;i++)c[i]=ch[c[i]][1];
}
}
return L;
}
int kth(int l,int r,int k,int now){
if(l==r)return l;
int mid=(l+r)>>1;
if(sum[ch[now][0]]>=k)return kth(l,mid,k,ch[now][0]);
else return kth(mid+1,r,k-sum[ch[now][0]],ch[now][1]);
}
int pre(int now,int k){
if(now==0)return 0;
int tmp=check(1,tot,1,k-1,now);
if(tmp==0)return 0;
else return kth(1,tot,tmp,now);
}
int pre(int l,int r,int k,int now){
if(tr[now].l==l&&tr[now].r==r)return pre(root[now],k);
int mid=(tr[now].l+tr[now].r)>>1;
if(l>mid)return pre(l,r,k,now*2+1);
else if(r<=mid)return pre(l,r,k,now*2);
else return max(pre(l,mid,k,now*2),pre(mid+1,r,k,now*2+1));
}
int suc(int now,int k){
if(now==0)return tot+1;
int tmp=check(1,tot,k+1,tot,now);
if(tmp==0)return tot+1;
else return kth(1,tot,check(1,tot,1,k,now)+1,now);
}
int suc(int l,int r,int k,int now){
if(tr[now].l==l&&tr[now].r==r)return suc(root[now],k);
int mid=(tr[now].l+tr[now].r)>>1;
if(l>mid)return suc(l,r,k,now*2+1);
else if(r<=mid)return suc(l,r,k,now*2);
else return min(suc(l,mid,k,now*2),suc(mid+1,r,k,now*2+1));
}
int read(){
int sum=0,f=1;char ch=getchar();
while(ch<‘0‘||ch>‘9‘){if(ch==‘-‘)f=-1;ch=getchar();}
while(ch>=‘0‘&&ch<=‘9‘){sum=sum*10+ch-‘0‘;ch=getchar();}
return sum*f;
}
int main(){
n=read();m=read();
build(1,n,1);
for(int i=1;i<=n;i++)a[i]=read(),b[++tot]=a[i];
for(int i=1;i<=m;i++){
type[i]=read(),l[i]=read(),r[i]=read();
if(type[i]==3)b[++tot]=r[i];
else{
x[i]=read();
if(type[i]!=2)b[++tot]=x[i];
}
}
sort(b+1,b+1+tot);
tot=unique(b+1,b+1+tot)-b-1;
b[0]=-2147483647;b[tot+1]=2147483647;
for(int i=1;i<=n;i++){
a[i]=lower_bound(b+1,b+1+tot,a[i])-b;
add(i,a[i],1,1);
}
for(int i=1;i<=m;i++)
if(type[i]==3)r[i]=lower_bound(b+1,b+1+tot,r[i])-b;
else if(type[i]!=2)x[i]=lower_bound(b+1,b+1+tot,x[i])-b;
for(int i=1;i<=m;i++){
if(type[i]==1)printf("%d
",rank(l[i],r[i],x[i],1)+1);
else if(type[i]==2)printf("%d
",b[kth(l[i],r[i],x[i])]);
else if(type[i]==3)add(l[i],a[l[i]],-1,1),add(l[i],a[l[i]]=r[i],1,1);
else if(type[i]==4)printf("%d
",b[pre(l[i],r[i],x[i],1)]);
else if(type[i]==5)printf("%d
",b[suc(l[i],r[i],x[i],1)]);
}
return 0;
}
还有一种树状数组套权值线段树的方法,利用了k小的可减性。优化了常数和代码量。
但是我没写。。。
以上是关于树套树初探的主要内容,如果未能解决你的问题,请参考以下文章