树链剖分模板

Posted mudrobot

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了树链剖分模板相关的知识,希望对你有一定的参考价值。

下面我就来详细讲解一下关于树剖的一些重点,其实树剖的主要就是轻重链的判断,这一点我默认大家都懂,所以我就直接从两个dfs那里开始说。

dfs1

第一个dfs主要是处理一下基本的一些信息,就是我们更新每一个点的爸爸,处理一下他们当前的深度,然后更新他们的子树大小,然后我们再从他们的儿子中找到重儿子。具体代码实现如下:

void dfs1(long long v,long long fa,long long depth)
{
    f[v]=fa;
    d[v]=depth;
    size[v]=1;
    for(long long i=edge[v].size()-1;i>=0;--i)
    {
        if(edge[v][i]==fa) continue;
        dfs1(edge[v][i],v,depth+1);
        size[v]+=size[edge[v][i]];
        if(size[edge[v][i]]>size[so[v]])
        so[v]=edge[v][i];
    }
}

dfs2

这一遍的dfs非常的重要,我们更新的是每一个节点当前所在重链的顶点(top数组),然后把他们的访问次序打上标记(id数组),在把打上的标记映射回来(rk数组),然后为了保证同一条重链上的标记是连续的,所以我们再向下搜索他们的重儿子做同样的操作,对于不是他的重儿子的其他儿子,我们就重新把他们当成新的重链的起点,重新进行搜索。然后一个节点如果没有重儿子说明他是叶子节点,没有儿子那么就返回:

具体代码实现如下:

void dfs2(long long v,long long t)
{
    top[v]=t;
    id[v]=++cnt;
    rk[cnt]=v;
    if(!so[v]) return ;//说明我没有儿子,并不只是没有重儿子 
    dfs2(so[v],t);
    for(long long i=edge[v].size()-1;i>=0;--i)
    {
        if(edge[v][i]!=f[v]&&edge[v][i]!=so[v])
        dfs2(edge[v][i],edge[v][i]); 
    } 
}

special_modify和sum(special_query)

这两个东西是线段树中不是非常支持的操作,其实这个东西是另外一种求lca的方法,本人也不好讲,大概提一下思路吧,就是我们要查找x,y两个点之间的距离那么我们首先先找到x,y两个点所在的重链,如果x所在的重链的顶点比y重链所在的顶点深度要深一些的话,我们就先找到x重链上面的一条重链,然后把x原来所在的那条重链的值全部在线段树中用区间修改更改了!如果y更深我们就对y进行同样的操作,然后直到他们两个玩意儿在同一条重链上的时候,这个时候我们就可以确定他们的节点之间所连接的节点的编号是连续的,然后我们在进行最后一次线段树的区间加就可以了,查询就是把修改的过程改成查询的过程就可以了!

具体代码实现如下:

long long sum(long long x,long long y)
{
    long long ans=0;
    long long fx=top[x],fy=top[y];
    while(fx!=fy)
    {
        if(d[fx]>=d[fy])
        {
            ans=(ans+query(root,id[fx],id[x]))%p;
            x=f[fx];fx=top[x];
        }
        else
        {
            ans=(ans+query(root,id[fy],id[y]))%p;
            y=f[fy];fy=top[y];
        }
    }
    if(id[x]<=id[y])
    ans=(ans+query(root,id[x],id[y]))%p;
    else
    ans=(ans+query(root,id[y],id[x]))%p;
    return ans;
} 
void special_modify(int x,int y,int c)
{
	int fx=top[x],fy=top[y];
	while(fx!=fy)
	{
		if(d[fx]>=d[fy])
		{
			modify(root,id[fx],id[x],c);
			x=f[fx];fx=top[x];
		}
		else
		{
			modify(root,id[fy],id[y],c);
			y=f[fy];fy=top[y];
		}
	}
	if(id[x]<=id[y])
	{
		modify(root,id[x],id[y],c);
	}
	else modify(root,id[y],id[x],c);
} 

然后线段树部分其实是和普通线段树是一样的,我就不多说了!

全部代码如下:

#include<bits/stdc++.h>
using namespace std;
const long long N=5e5+10;
vector<long long> edge[N];
struct sd{
    long long sum,add,l,r,son[2];
}node[N*2];
long long root,n,m,rt,p,a[N],cnt,top[N],f[N],d[N],size[N],so[N],rk[N],id[N];
void pushdown(long long k)
{
    node[node[k].son[0]].add+=node[k].add;
    node[node[k].son[1]].add+=node[k].add;
    node[node[k].son[0]].sum+=node[k].add*(node[node[k].son[0]].r-node[node[k].son[0]].l+1);
    node[node[k].son[1]].sum+=node[k].add*(node[node[k].son[1]].r-node[node[k].son[1]].l+1);
    node[node[k].son[0]].sum=node[node[k].son[0]].sum%p;
    node[node[k].son[1]].sum=node[node[k].son[1]].sum%p;
    node[k].add=0;
} 
void check()
{
	printf("


%lld
",node[3].sum);
	printf("%lld %lld
",node[node[3].son[0]].sum,node[node[3].son[1]].sum);
}
void update(long long k)
{
    node[k].sum=node[node[k].son[0]].sum+node[node[k].son[1]].sum;
} 
void modify(long long k,long long l,long long r,long long val)
{
    if(node[k].l==l&&node[k].r==r)
    {
    	pushdown(k);
        node[k].add+=val;
        node[k].sum+=val*(node[k].r-node[k].l+1);
        node[k].sum=node[k].sum%p;
    }
    else
    {
        pushdown(k);
        long long mid=(node[k].l+node[k].r)/2;
        if(r<=mid) modify(node[k].son[0],l,r,val);
        else if(l>mid) modify(node[k].son[1],l,r,val);
        else 
        {
            modify(node[k].son[0],l,mid,val);modify(node[k].son[1],mid+1,r,val);
        } 
        update(k);
    }
} 
void Build_tree(long long &k,long long l,long long r)
{
    cnt++;k=cnt; 
    node[k].l=l;node[k].r=r;
    if(node[k].l==node[k].r)
    {
        node[k].sum=a[rk[node[k].l]];
    }
    else
    {
        long long mid=(l+r)/2;
        Build_tree(node[k].son[0],l,mid);Build_tree(node[k].son[1],mid+1,r);
        update(k); 
    }
}

long long query(long long k,long long l,long long r)
{
    if(node[k].l==l&&node[k].r==r)
    {
    	pushdown(k);
        return node[k].sum%p;
    }
    else
    {
        pushdown(k);
        long long mid=(node[k].l+node[k].r)/2;
        if(r<=mid) return query(node[k].son[0],l,r);
        else if(l>mid) return query(node[k].son[1],l,r);
        else 
        {
            return (query(node[k].son[0],l,mid)+query(node[k].son[1],mid+1,r))%p;//这个模一定要取在外面 !!! 
        } 
    }
}
long long sum(long long x,long long y)
{
    long long ans=0;
    long long fx=top[x],fy=top[y];
    while(fx!=fy)
    {
        if(d[fx]>=d[fy])
        {
            ans=(ans+query(root,id[fx],id[x]))%p;
            x=f[fx];fx=top[x];
        }
        else
        {
            ans=(ans+query(root,id[fy],id[y]))%p;
            y=f[fy];fy=top[y];
        }
    }
    if(id[x]<=id[y])
    ans=(ans+query(root,id[x],id[y]))%p;
    else
    ans=(ans+query(root,id[y],id[x]))%p;
    return ans;
} 
//long long lca(int x,int y){
//	while(top[x]^top[y]){
//		if(d[top[x]]>d[top[y]])x=f[top[x]];
//		else y=f[top[y]];
//	}return d[x]>d[y]?y:x;
//}
void special_modify(long long x,long long y,long long c)
{
    long long fx=top[x],fy=top[y];
    while(fx!=fy)//当我们不在同一条重链上时 
    {
        if(d[fx]>=d[fy])
        {
            modify(root,id[fx],id[x],c);//先把我锁在的重链的权值改了
            x=f[fx],fx=top[x]; 
        } 
        else
        {
            modify(root,id[fy],id[y],c);
            y=f[fy],fy=top[y];
        }
    } 
    if(id[x]<=id[y])
    {
        modify(root,id[x],id[y],c);
    }
    else
    {
        modify(root,id[y],id[x],c);
    }
}

void dfs1(long long v,long long fa,long long depth)
{
    f[v]=fa;
    d[v]=depth;
    size[v]=1;
    for(long long i=edge[v].size()-1;i>=0;--i)
    {
        if(edge[v][i]==fa) continue;
        dfs1(edge[v][i],v,depth+1);
        size[v]+=size[edge[v][i]];
        if(size[edge[v][i]]>size[so[v]])
        so[v]=edge[v][i];
    }
}
void dfs2(long long v,long long t)
{
    top[v]=t;
    id[v]=++cnt;
    rk[cnt]=v;
    if(!so[v]) return ;//说明我没有儿子,并不只是没有重儿子 
    dfs2(so[v],t);
    for(long long i=edge[v].size()-1;i>=0;--i)
    {
        if(edge[v][i]!=f[v]&&edge[v][i]!=so[v])
        dfs2(edge[v][i],edge[v][i]); 
    } 
}
void pt(long long k)
{
    if(node[k].l==node[k].r){printf("%lld: %lld  NO:%lld
",node[k].l,node[k].sum,rk[node[k].l]);}
    else {pt(node[k].son[0]);pt(node[k].son[1]);}
}
int main()
{
    scanf("%lld%lld%lld%lld",&n,&m,&rt,&p);
    for(long long i=1;i<=n;++i)
    {
        scanf("%lld",&a[i]);
    }
    for(long long i=1;i<n;++i)
    {
        long long x,y;
        scanf("%lld%lld",&x,&y);
        edge[x].push_back(y);edge[y].push_back(x);
        
    }
    cnt=0;
    dfs1(rt,0,1);
    dfs2(rt,rt);
    cnt=0;
    Build_tree(root,1,n);
    for(long long i=1;i<=m;++i)
    {
        long long op,xx,xy,xz;
        scanf("%lld",&op);
        if(op==1)
        {
            scanf("%lld%lld%lld",&xx,&xy,&xz);
            special_modify(xx,xy,xz);
        }
        else if(op==2)
        {
            scanf("%lld%lld",&xx,&xy);
            //printf("

");
            //pt(root);
            //check();
            printf("%lld
",sum(xx,xy));
        }
        else if(op==3)
        {
            scanf("%lld%lld",&xx,&xz);
            modify(root,id[xx],id[xx]+size[xx]-1,xz); 
        }
        else if(op==4)
        {
            scanf("%lld",&xx);
            printf("%lld
",query(root,id[xx],id[xx]+size[xx]-1));
        }
    }
    return 0;
} 

下面这一篇树链剖分要稍微清真一些:

#include<bits/stdc++.h>
#define LL long long
#define N 100004
using namespace std;
struct sd{
    LL sum,l,r,son[2],add;
}node[N*2];
struct line{
    LL to,next;
}edge[N*2];
LL n,m,rt,p,head[N],ini[N],qnt,cnt,root;
LL dep[N],x[N],size[N],maxson[N],top[N],id[N],rk[N];
void add(LL a,LL b){edge[++qnt].next=head[a];edge[qnt].to=b;head[a]=qnt;}
//______________________________________________________________________________________
void update(LL k){node[k].sum=node[node[k].son[0]].sum+node[node[k].son[1]].sum;}
void add_add(LL k,LL val)
{
    node[k].add=(node[k].add+val)%p;
    node[k].sum=(node[k].sum+(node[k].r-node[k].l+1)*val)%p;
}
void pushdown(LL k)
{
    add_add(node[k].son[0],node[k].add);add_add(node[k].son[1],node[k].add);
    node[k].add=0;
}
void Buildtree(LL &k,LL l,LL r)
{
    k=++cnt;node[k].l=l;node[k].r=r;
    if(l==r) node[k].sum=ini[rk[l]];
    else {LL mid=(l+r)/2;Buildtree(node[k].son[0],l,mid);
    Buildtree(node[k].son[1],mid+1,r);update(k);}
}
void modify(LL k,LL l,LL r,LL val)
{
    if(node[k].l==l&&node[k].r==r) add_add(k,val);
    else
    {
        pushdown(k);
        LL mid=(node[k].l+node[k].r)/2;
        if(mid>=r) modify(node[k].son[0],l,r,val);
        else if(mid<l)modify(node[k].son[1],l,r,val);
        else modify(node[k].son[0],l,mid,val),modify(node[k].son[1],mid+1,r,val);
        update(k);
    }
}
LL query(LL k,LL l,LL r)
{
    if(node[k].l==l&&node[k].r==r) return node[k].sum%p;
    else
    {
        pushdown(k);
        LL mid=(node[k].l+node[k].r)/2;
        if(mid>=r) return query(node[k].son[0],l,r)%p;
        else if(mid<l) return query(node[k].son[1],l,r)%p;
        else return (query(node[k].son[0],l,mid)+query(node[k].son[1],mid+1,r))%p;
    }
}
//______________________________________________________________________________________
void dfs1(LL v,LL fa,LL depth)
{
    x[v]=fa;dep[v]=depth;size[v]=1;
    for(LL i=head[v];i;i=edge[i].next)
    {
        if(fa==edge[i].to)continue;
        dfs1(edge[i].to,v,depth+1);
        size[v]+=size[edge[i].to];
        if(size[edge[i].to]>size[maxson[v]])
        maxson[v]=edge[i].to;
    }
}
void dfs2(LL v,LL sign)
{
    top[v]=sign;
    id[v]=++cnt;
    rk[cnt]=v;
    if(!maxson[v]) return;
    dfs2(maxson[v],sign);
    for(LL i=head[v];i;i=edge[i].next)
    if(edge[i].to!=x[v]&&edge[i].to!=maxson[v])dfs2(edge[i].to,edge[i].to);
}
LL special_query(LL xx,LL yy)
{
    LL ans=0;
    LL fx=top[xx],fy=top[yy];
    while(fx!=fy)
    {
        if(dep[fx]>=dep[fy]) ans=(ans+query(root,id[fx],id[xx]))%p,xx=x[fx],fx=top[xx];
        else  ans=(ans+query(root,id[fy],id[yy]))%p,yy=x[fy],fy=top[yy];
    }
    ans=(ans+query(root,min(id[xx],id[yy]),max(id[xx],id[yy])))%p;
    return ans;
}
LL special_modify(LL xx,LL yy,LL val)
{
    LL fx=top[xx],fy=top[yy];
    while(fx!=fy)
    {
        if(dep[fx]>=dep[fy]) modify(root,id[fx],id[xx],val),xx=x[fx],fx=top[xx];
        else modify(root,id[fy],id[yy],val),yy=x[fy],fy=top[yy];
    }
    modify(root,min(id[xx],id[yy]),max(id[xx],id[yy]),val);
}
//______________________________________________________________________________________

int main()
{
    scanf("%lld%lld%lld%lld",&n,&m,&rt,&p);
    for(LL i=1;i<=n;++i) scanf("%lld",&ini[i]);
    for(LL a,b,i=1;i<n;++i) scanf("%lld%lld",&a,&b),add(a,b),add(b,a);
    dfs1(rt,0,1); dfs2(rt,rt);cnt=0; Buildtree(root,1,n); 
    for(int i=1;i<=m;++i)
    {
        LL op,a,b,c;
        scanf("%lld",&op);
        if(op==1) scanf("%lld%lld%lld",&a,&b,&c),special_modify(a,b,c);
        if(op==2) scanf("%lld%lld",&a,&b),printf("%lld
",special_query(a,b)%p);
        if(op==3) scanf("%lld%lld",&a,&b),/*printf("%lld %lld
",id[a],size[a]),*/modify(root,id[a],id[a]+size[a]-1,b);
        if(op==4) scanf("%lld",&a),printf("%lld
",query(root,id[a],id[a]+size[a]-1)%p);
    }
    return 0;
}

以上是关于树链剖分模板的主要内容,如果未能解决你的问题,请参考以下文章

模板时间◆模板·II◆ 树链剖分

树链剖分模板题(luogu3384 模板树链剖分)

树链剖分模板

BZOJ 2243--染色(树链剖分)

模板树链剖分

luoguP3384 模板树链剖分