树链剖分
Posted cold-cold
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了树链剖分相关的知识,希望对你有一定的参考价值。
首先,在学树链剖分之前最好先把 LCA、树形DP、DFS序 这三个知识点学了还有必备的 链式前向星、线段树 也要先学了。如果这些个知识点没掌握好的话,树链剖分难以理解也是当然的
树链剖分 就是对一棵树分成几条链,把树形变为线性,减少处理难度
需要处理的问题:
- 将树从x到y结点最短路径上所有节点的值都加上z
- 求树从x到y结点最短路径上所有节点的值之和
- 将以x为根节点的子树内所有节点值都加上z
- 求以x为根节点的子树内所有节点值之和
首先我们要了解这些概念
- 重儿子:对于每一个非叶子节点,它的儿子中 以那个儿子为根的子树节点数最大的儿子 为该节点的重儿子
- 轻儿子:对于每一个非叶子节点,它的儿子中 非重儿子 的剩下所有儿子即为轻儿子
- 叶子节点没有重儿子也没有轻儿子(因为它没有儿子)
- 重边:连接任意两个重儿子的边叫做重边
- 轻边:剩下的即为轻边
- 重链:相邻重边连起来的 连接一条重儿子 的链叫重链
- 对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为1的链
- 每一条重链以轻儿子为起点
dfs1()
这个dfs要处理几件事情:
- 标记每个点的深度dep[]
- 标记每个点的父亲fa[]
- 标记每个非叶子节点的子树大小(含它自己)
- 标记每个非叶子节点的重儿子编号son[]
void dfs1(int u,int ff)
{
dep[u]=dep[ff]+1;fa[u]=ff;size[u]=1;
int maxn=0;
for(int i=head[u],v=e[i].to;i;i=e[i].nxt,v=e[i].to) if(v!=ff)
{
dfs1(v,u);
size[u]+=size[v];
if(size[v]>maxn) maxn=size[v],son[u]=v;
}
}
dfs2()
这个dfs2也要预处理几件事情
- 标记每个点的新编号
- 赋值每个点的初始值到新编号上
- 处理每个点所在链的顶端
- 处理每条链
顺序:先处理重儿子再处理轻儿子,理由后面说
void dfs2(int u,int topf)
{
id[u]=++tim;wn[tim]=w[u];
top[u]=topf;
if(!son[u]) return;
dfs2(son[u],topf);
for(int i=head[u],v=e[i].to;i;i=e[i].nxt,v=e[i].to) if(v!=fa[u]&&v!=son[u])
dfs2(v,v);
}
处理问题
Attention 重要的来了!!!
前面说到dfs2的顺序是先处理重儿子再处理轻儿子 我们来模拟一下:
- 因为顺序是先重再轻,所以每一条重链的新编号是连续的
- 因为是dfs,所以每一个子树的新编号也是连续的
现在回顾一下我们要处理的问题
- 处理任意两点间路径上的点权和
- 处理一点及其子树的点权和
- 修改任意两点间路径上的点权
- 修改一点及其子树的点权
1、当我们要处理任意两点间路径时:
设所在链顶端的深度更深的那个点为x点
- ans加上x点到x所在链顶端 这一段区间的点权和
- 把x跳到x所在链顶端的那个点的上面一个点
不停执行这两个步骤,直到两个点处于一条链上,这时再加上此时两个点的区间和即可
这时我们注意到,我们所要处理的所有区间均为连续编号(新编号),于是想到线段树,用线段树处理连续编号区间和每次查询时间复杂度为O(log2n)
void update_tree(int u,int v,int val) { while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]]) swap(u,v); update(id[top[u]],id[u],val,1,n,1); u=fa[top[u]]; } if(dep[u]<dep[v]) swap(u,v); update(id[v],id[u],val,1,n,1); } int query_tree(int u,int v) { int ans=0; while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]]) swap(u,v); ans=(ans+query(id[top[u]],id[u],1,n,1))%mod; u=fa[top[u]]; } if(dep[u]<dep[v]) swap(u,v); return (ans+query(id[v],id[u],1,n,1))%mod; }
总体代码如下:
#include <iostream> #include <cstdio> #include <cstring> #include <cmath> #include <map> #include <cstdlib> #include <algorithm> #include <queue> #include <stack> #define lson l,mid,o<<1 #define rson mid+1,r,o<<1|1 using namespace std; inline int read() { int a=0,q=0; char ch=getchar(); while((ch<‘0‘||ch>‘9‘)&&ch!=‘-‘) ch=getchar(); if(ch==‘-‘) q=1,ch=getchar(); while(ch>=‘0‘&&ch<=‘9‘) a=(a<<3)+(a<<1)+ch-48,ch=getchar(); return q?-a:a; } const int N=200100; int n,m,r,mod,u,v,w[N],wn[N],head[N],dep[N],fa[N],top[N],size[N],son[N],id[N],sum[N<<2],tag[N<<2],cnt=0,tim=0,op,x,y,z; struct EDGE{int to,nxt;}e[N<<1]; void add(int u,int v) { e[++cnt]=(EDGE){v,head[u]}; head[u]=cnt; } void pushdown(int lnum,int rnum,int o) { if(tag[o]) { tag[o<<1]+=tag[o]; tag[o<<1|1]+=tag[o]; (sum[o<<1]+=tag[o]*lnum)%=mod; (sum[o<<1|1]+=tag[o]*rnum)%=mod; tag[o]=0; } } void build(int l,int r,int o) { if(l==r){sum[o]=wn[l]%mod;return;} int mid=(l+r)>>1; build(lson); build(rson); sum[o]=(sum[o<<1]+sum[o<<1|1])%mod; } void update(int L,int R,int val,int l,int r,int o) { if(L<=l&&r<=R){tag[o]+=val,(sum[o]+=(r-l+1)*val)%=mod;return;} int mid=(l+r)>>1; pushdown(mid-l+1,r-mid,o); if(L<=mid) update(L,R,val,lson); if(R> mid) update(L,R,val,rson); sum[o]=(sum[o<<1]+sum[o<<1|1])%mod; } int query(int L,int R,int l,int r,int o) { if(L<=l&&r<=R) {return sum[o]%mod;} int mid=(l+r)>>1,ans=0; pushdown(mid-l+1,r-mid,o); if(L<=mid) (ans+=query(L,R,lson))%=mod; if(R> mid) (ans+=query(L,R,rson))%=mod; return ans; } void dfs1(int u,int ff) { dep[u]=dep[ff]+1;fa[u]=ff;size[u]=1; int maxn=0; for(int i=head[u],v=e[i].to;i;i=e[i].nxt,v=e[i].to) if(v!=ff) { dfs1(v,u); size[u]+=size[v]; if(size[v]>maxn) maxn=size[v],son[u]=v; } } void dfs2(int u,int topf) { id[u]=++tim;wn[tim]=w[u]; top[u]=topf; if(!son[u]) return; dfs2(son[u],topf); for(int i=head[u],v=e[i].to;i;i=e[i].nxt,v=e[i].to) if(v!=fa[u]&&v!=son[u]) dfs2(v,v); } void update_tree(int u,int v,int val) { while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]]) swap(u,v); update(id[top[u]],id[u],val,1,n,1); u=fa[top[u]]; } if(dep[u]<dep[v]) swap(u,v); update(id[v],id[u],val,1,n,1); } int query_tree(int u,int v) { int ans=0; while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]]) swap(u,v); ans=(ans+query(id[top[u]],id[u],1,n,1))%mod; u=fa[top[u]]; } if(dep[u]<dep[v]) swap(u,v); return (ans+query(id[v],id[u],1,n,1))%mod; } int main() { n=read(),m=read(),r=read(),mod=read(); for(register int i=1;i<=n;i++) scanf("%d",&w[i]); for(register int i=1;i<n;i++) { u=read(),v=read(); add(u,v);add(v,u); } dfs1(r,r);dfs2(r,r); build(1,n,1); while(m--) { op=read(); if(op==1) { x=read(),y=read(),z=read(); update_tree(x,y,z); } else if(op==2) { x=read(),y=read(); printf("%d ",query_tree(x,y)%mod); } else if(op==3) { x=read(),y=read(); update(id[x],id[x]+size[x]-1,y,1,n,1); } else { x=read(); printf("%d ",query(id[x],id[x]+size[x]-1,1,n,1)); } } return 0; }
以上是关于树链剖分的主要内容,如果未能解决你的问题,请参考以下文章