算法笔记--lca倍增算法
Posted Wisdom+.+
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了算法笔记--lca倍增算法相关的知识,希望对你有一定的参考价值。
算法笔记
模板:
// Binary-lifting LCA template. Paste it after a definition of `const int N`
// (the number of node slots).
// Usage: for every level k set anc[k][root] = root, then call dfs(root, root, 0).
//   g[u]      - neighbours of u
//   edge[u]   - edge[u][j] is the weight of the edge u -- g[u][j]
//   anc[k][u] - the 2^k-th ancestor of u; with the root initialised as above,
//               any jump that would pass the root lands on the root
//   deep[u]   - number of edges between the root and u
//   h[u]      - weighted distance from the root to u
std::vector<int> g[N];
std::vector<int> edge[N];
int anc[20][N];
int deep[N];
int h[N];

// Depth-first pass: u's parent is o and the edge o -- u has weight w.
// Fills deep/h for u and the ancestor table of every child of u before
// recursing into that child.
void dfs(int o, int u, int w)
{
    if (u != o) {  // the root is passed as its own parent and is left untouched
        deep[u] = deep[o] + 1;
        h[u] = h[o] + w;
    }
    for (std::size_t j = 0; j < g[u].size(); ++j) {
        const int child = g[u][j];
        if (child == o) {
            continue;  // never walk back to the parent
        }
        anc[0][child] = u;
        for (int k = 1; k < 20; ++k) {
            anc[k][child] = anc[k - 1][anc[k - 1][child]];
        }
        dfs(u, child, edge[u][j]);
    }
}

// Lowest common ancestor of u and v.
int lca(int u, int v)
{
    if (deep[u] < deep[v]) {
        std::swap(u, v);  // u is now the deeper node
    }
    // Lift u to v's depth, trying the largest jumps first.
    for (int k = 19; k >= 0; --k) {
        if (deep[anc[k][u]] >= deep[v]) {
            u = anc[k][u];
        }
    }
    if (u == v) {
        return u;
    }
    // Lift both while their ancestors differ; they stop just below the LCA.
    for (int k = 19; k >= 0; --k) {
        if (anc[k][u] == anc[k][v]) {
            continue;
        }
        u = anc[k][u];
        v = anc[k][v];
    }
    return anc[0][u];
}

// Weighted distance between u and v: both root paths minus the shared part.
int dis(int u, int v)
{
    return h[u] + h[v] - 2 * h[lca(u, v)];
}
例题:
带权lca
代码:
// Weighted LCA: distance queries on a tree with weighted edges.
// Input: n, then n-1 lines "u v c" (edge u -- v of weight c), then m, then m
// lines "u v"; prints the weighted distance between u and v for every query.
// The tree is rooted at node 0, so nodes are assumed to be numbered from 0
// (NOTE(review): confirm against the problem's input format).
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int N = 5e4 + 5;

vector<int> g[N];     // g[u]: neighbours of u
vector<int> edge[N];  // edge[u][j]: weight of the edge u -- g[u][j]
int anc[20][N];       // anc[k][u]: 2^k-th ancestor of u (the root once a jump passes it)
int deep[N];          // deep[u]: number of edges between the root and u
ll h[N];              // h[u]: weighted distance from the root; 64-bit because
                      // h[u] + h[v] in dis() can exceed the int range

// Depth-first pass: u's parent is o and the edge o -- u has weight w.
// Fills deep/h for u and the ancestor table of each child of u.
void dfs(int o, int u, int w)
{
    if (u != o) {  // the root is passed as its own parent and keeps deep = h = 0
        deep[u] = deep[o] + 1;
        h[u] = h[o] + w;
    }
    for (size_t j = 0; j < g[u].size(); j++) {
        int v = g[u][j];
        if (v == o) continue;  // skip the edge back to the parent
        anc[0][v] = u;
        for (int i = 1; i < 20; i++) anc[i][v] = anc[i - 1][anc[i - 1][v]];
        dfs(u, v, edge[u][j]);
    }
}

// Lowest common ancestor of u and v by binary lifting.
int lca(int u, int v)
{
    if (deep[u] < deep[v]) swap(u, v);  // make u the deeper node
    for (int i = 19; i >= 0; i--)       // lift u up to v's depth
        if (deep[anc[i][u]] >= deep[v]) u = anc[i][u];
    if (u == v) return u;
    for (int i = 19; i >= 0; i--)       // lift both to just below the LCA
        if (anc[i][u] != anc[i][v]) u = anc[i][u], v = anc[i][v];
    return anc[0][u];
}

// Weighted distance between u and v.
ll dis(int u, int v)
{
    int l = lca(u, v);
    return h[u] + h[v] - 2 * h[l];
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, u, v, c, m;
    cin >> n;
    for (int i = 0; i < n - 1; i++) {
        cin >> u >> v >> c;
        g[u].push_back(v);
        g[v].push_back(u);
        edge[u].push_back(c);
        edge[v].push_back(c);
    }
    cin >> m;
    for (int i = 0; i < 20; i++) anc[i][0] = 0;  // root 0 is its own ancestor at every level
    dfs(0, 0, 0);
    while (m--) {
        cin >> u >> v;
        cout << dis(u, v) << '\n';  // '\n' rather than endl: no flush per query
    }
    return 0;
}
普通lca
代码:
// Unweighted LCA: reads a tree on nodes 1..n and a sequence of m nodes, and
// prints the total number of edges walked when visiting the nodes in that
// order, i.e. the sum of dis(a_i, a_{i+1}).
#include <bits/stdc++.h>
using namespace std;

const int N = 1e5 + 5;

vector<int> g[N];  // g[u]: neighbours of u
int anc[20][N];    // anc[k][u]: 2^k-th ancestor of u
int deep[N];       // depth of u; the root (node 1) has depth 1 because deep[0] = 0

// Depth-first pass from u, whose parent is o.
void dfs(int o, int u)
{
    if (o != u) deep[u] = deep[o] + 1;
    for (size_t j = 0; j < g[u].size(); j++) {
        int v = g[u][j];
        if (v == o) continue;  // skip the edge back to the parent
        anc[0][v] = u;
        for (int i = 1; i < 20; i++) anc[i][v] = anc[i - 1][anc[i - 1][v]];
        dfs(u, v);
    }
}

// Lowest common ancestor of u and v by binary lifting.
int lca(int u, int v)
{
    if (deep[u] < deep[v]) swap(u, v);  // make u the deeper node
    for (int i = 19; i >= 0; i--)       // lift u up to v's depth
        if (deep[anc[i][u]] >= deep[v]) u = anc[i][u];
    if (u == v) return u;
    for (int i = 19; i >= 0; i--)       // lift both to just below the LCA
        if (anc[i][u] != anc[i][v]) u = anc[i][u], v = anc[i][v];
    return anc[0][u];
}

// Number of edges on the path between u and v.
int dis(int u, int v)
{
    return deep[u] + deep[v] - 2 * deep[lca(u, v)];
}

// Roots the tree at node 1. The root must be its own ancestor at EVERY level k
// so that jumps which would pass the root stop at it.
void init()
{
    for (int i = 0; i < 20; i++) anc[i][1] = 1;
    dfs(0, 1);
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> n;
    for (int i = 1; i < n; i++) {
        int a, b;
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    init();
    cin >> m;
    int a, b;
    long long ans = 0;  // m-1 distances of up to n-1 edges each: can exceed int
    cin >> a;
    for (int i = 1; i < m; i++) {
        cin >> b;
        ans += dis(a, b);
        a = b;
    }
    cout << ans << endl;
    return 0;
}
带权lca:只有当lca(a,b)==a(即a是b的祖先)时,韵韵才能参加;此时a到b的路程为h[b]-h[a],需统计能参加的次数以及总路程。
代码:
// Weighted LCA, ancestor-only queries. Reads a tree on nodes 1..n (rooted at 1)
// and m queries "a b". A query counts only when a is an ancestor of b
// (lca(a, b) == a); prints how many queries count and then the total distance
// h[b] - h[a] over the counted queries.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int N = 1e4 + 5;
const int LOG = 20;  // levels in the ancestor table

vector<int> g[N];    // g[u]: neighbours of u
vector<ll> edge[N];  // edge[u][j]: weight of the edge u -- g[u][j]
int anc[LOG][N];     // anc[k][u]: 2^k-th ancestor of u
int deep[N];         // depth; the root (node 1) gets depth 1
ll h[N];             // weighted distance from the root, plus the root's own offset

// Depth-first pass: u's parent is o and the edge o -- u has weight w.
void dfs(int o, int u, ll w)
{
    if (o != u) {
        deep[u] = deep[o] + 1;
        h[u] = h[o] + w;
    }
    for (size_t j = 0; j < g[u].size(); ++j) {
        const int to = g[u][j];
        if (to == o) continue;  // do not go back to the parent
        anc[0][to] = u;
        for (int k = 1; k < LOG; ++k) anc[k][to] = anc[k - 1][anc[k - 1][to]];
        dfs(u, to, edge[u][j]);
    }
}

// Lowest common ancestor of u and v.
int lca(int u, int v)
{
    if (deep[u] < deep[v]) swap(u, v);
    for (int k = LOG - 1; k >= 0; --k) {
        if (deep[anc[k][u]] >= deep[v]) u = anc[k][u];
    }
    if (u == v) return u;
    for (int k = LOG - 1; k >= 0; --k) {
        if (anc[k][u] == anc[k][v]) continue;
        u = anc[k][u];
        v = anc[k][v];
    }
    return anc[0][u];
}

// Weighted distance between u and v.
ll dis(int u, int v)
{
    return h[u] + h[v] - 2 * h[lca(u, v)];
}

// Roots the tree at node 1; the root is its own ancestor at every level.
void init()
{
    for (int k = 0; k < LOG; ++k) anc[k][1] = 1;
    dfs(0, 1, 1);  // the dummy weight 1 sets h[1] = 1; only differences of h are used
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    int n, m, a, b;
    ll t;
    cin >> n >> m;
    for (int i = 1; i < n; ++i) {
        cin >> a >> b >> t;
        g[a].push_back(b);
        edge[a].push_back(t);
        g[b].push_back(a);
        edge[b].push_back(t);
    }
    init();

    int cnt = 0;
    ll ans = 0;
    for (int i = 0; i < m; ++i) {
        cin >> a >> b;
        if (lca(a, b) != a) continue;  // a must be an ancestor of b
        ++cnt;
        ans += h[b] - h[a];
    }
    cout << cnt << endl << ans << endl;
    return 0;
}
带权lca
代码:
// Multi-case weighted LCA: t test cases, each a tree on nodes 1..n (rooted at 1)
// with weighted edges followed by m queries "a b"; prints the weighted distance
// between a and b for every query.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int N = 4e4 + 5;

vector<int> g[N];     // g[u]: neighbours of u
vector<int> edge[N];  // edge[u][j]: weight of the edge u -- g[u][j]
int deep[N];          // depth; deep[0] stays 0, so the root (node 1) has depth 1
ll h[N];              // weighted distance from the root; 64-bit so that
                      // h[u] + h[v] in dis() cannot overflow int
int anc[20][N];       // anc[k][u]: 2^k-th ancestor of u

// Depth-first pass: u's parent is o and the edge o -- u has weight w.
void dfs(int o, int u, int w)
{
    deep[u] = deep[o] + 1;
    h[u] = h[o] + w;
    for (size_t j = 0; j < g[u].size(); j++) {
        int v = g[u][j];
        if (v == o) continue;  // skip the edge back to the parent
        anc[0][v] = u;
        for (int i = 1; i < 20; i++) anc[i][v] = anc[i - 1][anc[i - 1][v]];
        dfs(u, v, edge[u][j]);
    }
}

// Lowest common ancestor of u and v by binary lifting.
int lca(int u, int v)
{
    if (deep[u] < deep[v]) swap(u, v);  // make u the deeper node
    for (int i = 19; i >= 0; i--)       // lift u up to v's depth
        if (deep[anc[i][u]] >= deep[v]) u = anc[i][u];
    if (u == v) return u;
    for (int i = 19; i >= 0; i--)       // lift both to just below the LCA
        if (anc[i][u] != anc[i][v]) u = anc[i][u], v = anc[i][v];
    return anc[0][u];
}

// Weighted distance between u and v.
ll dis(int u, int v)
{
    return h[u] + h[v] - 2 * h[lca(u, v)];
}

// Roots the tree at node 1, which is its own ancestor at every level.
void init()
{
    for (int i = 0; i < 20; i++) anc[i][1] = 1;
    dfs(0, 1, 1);  // dummy weight 1 gives h[1] = 1; the offset cancels in dis()
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int t;
    cin >> t;
    while (t--) {
        int n, m, a, b, k;
        cin >> n >> m;
        // The adjacency lists are global: reset them so edges of the previous
        // test case are not mixed into this tree.
        for (int i = 1; i <= n; i++) {
            g[i].clear();
            edge[i].clear();
        }
        for (int i = 1; i < n; i++) {
            cin >> a >> b >> k;
            g[a].push_back(b);
            g[b].push_back(a);
            edge[a].push_back(k);
            edge[b].push_back(k);
        }
        init();
        for (int i = 0; i < m; i++) {
            cin >> a >> b;
            cout << dis(a, b) << '\n';  // '\n' rather than endl: no flush per query
        }
    }
    return 0;
}
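带权lca:连接三点 $a,b,c$ 的最短路径(最小连通子树)长度为 $h_a+h_b+h_c-h_{\mathrm{lca}(a,b)}-h_{\mathrm{lca}(a,c)}-h_{\mathrm{lca}(b,c)}$,其中 $h_x$ 为根到 $x$ 的带权距离。由于 $\mathrm{dis}(x,y)=h_x+h_y-2h_{\mathrm{lca}(x,y)}$,该式恰好等于 $\frac{1}{2}\left(\mathrm{dis}(a,b)+\mathrm{dis}(a,c)+\mathrm{dis}(b,c)\right)$,即三对点两两距离之和的一半。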
代码:
// Weighted LCA, three-node queries. Test cases are read until EOF or n == 0;
// each gives a tree on nodes 0..n-1 (rooted at 0) followed by q queries "a b c".
// Each answer is the length of the shortest route connecting the three nodes:
//   h[a] + h[b] + h[c] - h[lca(a,b)] - h[lca(a,c)] - h[lca(b,c)].
// Outputs of consecutive test cases are separated by a blank line.
#include <bits/stdc++.h>
using namespace std;

const int N = 5e4 + 5;
const int LOG = 20;  // levels in the ancestor table

vector<int> g[N];     // g[u]: neighbours of u
vector<int> edge[N];  // edge[u][j]: weight of the edge u -- g[u][j]
int deep[N];          // number of edges between the root and u
int h[N];             // weighted distance from the root to u
int anc[LOG][N];      // anc[k][u]: 2^k-th ancestor of u

// Depth-first pass: u's parent is o and the edge o -- u has weight w.
void dfs(int o, int u, int w)
{
    if (u != o) {  // the root is its own parent and is left untouched
        deep[u] = deep[o] + 1;
        h[u] = h[o] + w;
    }
    for (size_t j = 0; j < g[u].size(); ++j) {
        const int to = g[u][j];
        if (to == o) continue;
        anc[0][to] = u;
        for (int k = 1; k < LOG; ++k) anc[k][to] = anc[k - 1][anc[k - 1][to]];
        dfs(u, to, edge[u][j]);
    }
}

// Lowest common ancestor of u and v.
int lca(int u, int v)
{
    if (deep[u] < deep[v]) swap(u, v);
    for (int k = LOG - 1; k >= 0; --k) {
        if (deep[anc[k][u]] >= deep[v]) u = anc[k][u];
    }
    if (u == v) return u;
    for (int k = LOG - 1; k >= 0; --k) {
        if (anc[k][u] == anc[k][v]) continue;
        u = anc[k][u];
        v = anc[k][v];
    }
    return anc[0][u];
}

// Weighted distance between u and v.
int dis(int u, int v)
{
    return h[u] + h[v] - 2 * h[lca(u, v)];
}

// Roots the tree at node 0, which is its own ancestor at every level.
void init()
{
    for (int k = 0; k < LOG; ++k) anc[k][0] = 0;
    dfs(0, 0, 0);
}

int main()
{
    int n, q, a, b, c;
    bool first = true;
    while (scanf("%d", &n) != EOF && n) {
        if (!first) printf("\n");  // blank line between test cases
        first = false;

        for (int i = 0; i < n; ++i) {
            g[i].clear();
            edge[i].clear();
        }
        for (int i = 1; i < n; ++i) {
            scanf("%d%d%d", &a, &b, &c);
            g[a].push_back(b);
            g[b].push_back(a);
            edge[a].push_back(c);
            edge[b].push_back(c);
        }
        init();

        scanf("%d", &q);
        while (q--) {
            scanf("%d%d%d", &a, &b, &c);
            const int ab = lca(a, b);
            const int ac = lca(a, c);
            const int bc = lca(b, c);
            printf("%d\n", h[a] + h[b] + h[c] - h[ab] - h[ac] - h[bc]);
        }
    }
    return 0;
}
以上是关于算法笔记--lca倍增算法的主要内容,如果未能解决你的问题,请参考以下文章