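// 题解：把当前所有标记点按 DFS 序排成一个环，从任一点出发按 DFS 序依次走过所有点再回到起点，答案即环上相邻两点（含首尾相接）的树上距离之和。
// 用 Splay 按 DFS 序维护当前的标记点集合即可。
// 插入（删除）一个点时，找出它在环上的前驱和后继（没有前驱时取最后一个点，没有后继时取第一个点），把它插在两者之间（或从两者之间移除），并相应地增减距离。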
#include<iostream>
#include<cstdio>
#include<cstring>
#include<utility>
using namespace std;

// Capacity for vertex-indexed arrays. The tree is stored as an undirected
// forward-star list, so edge arrays need room for 2*(n-1) directed edges.
const int maxn=200009;
const int LOG=20;   // binary-lifting levels; supports tree depths below 2^LOG
typedef long long Lint;

int n,m;
int htr[maxn];   // htr[v]==1 iff vertex v is currently marked
Lint ans;        // sum of tree distances between cyclically consecutive marked vertices, ordered by DFS preorder

// ---------------- adjacency list (forward star) ----------------
int cntedge;
int head[maxn];
int to[2*maxn],nex[2*maxn],dist[2*maxn];

// Appends directed edge x -> y with weight z.
void Addedge(int x,int y,int z){
    nex[++cntedge]=head[x];
    to[cntedge]=y;
    dist[cntedge]=z;
    head[x]=cntedge;
}

// ---------------- DFS: preorder number, parent, depth, weighted depth ----------------
int dfsclock;
int idt[maxn],father[maxn],depth[maxn];
Lint d[maxn];   // d[v] = weighted distance from the DFS root to v

// Fills father[], depth[], idt[] (preorder index) and d[] for the tree rooted at rt.
// Iterative with an explicit stack so a path-shaped tree cannot overflow the
// call stack; children are visited in adjacency-list order, exactly as the
// recursive formulation would, so the preorder numbering is the same.
void Dfs(int rt,int rtfa){
    static int stk[maxn];
    static int cur[maxn];   // next adjacency entry to examine for each vertex on the stack
    int top=0;
    father[rt]=rtfa;
    depth[rt]=depth[rtfa]+1;
    idt[rt]=++dfsclock;
    cur[rt]=head[rt];
    stk[++top]=rt;
    while(top){
        int now=stk[top];
        int i=cur[now];
        while(i&&to[i]==father[now])i=nex[i];   // never walk back to the parent
        if(!i){ --top; continue; }              // every child of `now` is finished
        cur[now]=nex[i];
        int v=to[i];
        d[v]=d[now]+dist[i];
        father[v]=now;
        depth[v]=depth[now]+1;
        idt[v]=++dfsclock;
        cur[v]=head[v];
        stk[++top]=v;
    }
}

// ---------------- LCA by binary lifting ----------------
int f[maxn][LOG];   // f[v][j] = 2^j-th ancestor of v (0 above the root)

void LCAinit(){
    for(int i=1;i<=n;++i)f[i][0]=father[i];
    for(int j=1;j<LOG;++j)
        for(int i=1;i<=n;++i)
            f[i][j]=f[f[i][j-1]][j-1];
}

int Getlca(int u,int v){
    if(depth[u]<depth[v])swap(u,v);
    // Lift u to v's depth; depth[0]==0 keeps jumps past the root from being taken.
    for(int j=LOG-1;j>=0;--j)
        if(depth[f[u][j]]>=depth[v])u=f[u][j];
    if(u==v)return u;
    for(int j=LOG-1;j>=0;--j)
        if(f[u][j]!=f[v][j]){ u=f[u][j]; v=f[v][j]; }
    return f[u][0];
}

// Weighted tree distance between u and v.
Lint Getd(int u,int v){
    int lca=Getlca(u,v);
    return d[u]+d[v]-2*d[lca];
}

// ---------------- splay tree of marked vertices, keyed by idt[] ----------------
// Vertex ids double as node ids; node 0 is the null sentinel.
int root;
int fa[maxn],ch[maxn][2];

// 1 if x is the right child of its parent, otherwise 0.
inline int son(int x){
    return ch[fa[x]][1]==x?1:0;
}

// Rotates x above its parent, updating `root` when x becomes the top node.
inline void Rotate(int x){
    int y=fa[x];
    int z=fa[y];
    int b=son(x),c=son(y);
    int a=ch[x][b^1];   // subtree that changes parent from x to y
    if(z)ch[z][c]=x;
    else root=x;
    fa[x]=z;
    if(a)fa[a]=y;
    ch[y][b]=a;
    fa[y]=x;
    ch[x][b^1]=y;
}

// Splays x until its parent is i (i==0 makes x the root).
void Splay(int x,int i){
    while(fa[x]!=i){
        int y=fa[x];
        int z=fa[y];
        if(z==i){
            Rotate(x);
        }else if(son(x)==son(y)){
            Rotate(y); Rotate(x);   // zig-zig
        }else{
            Rotate(x); Rotate(x);   // zig-zag
        }
    }
}

// Inserts vertex p as a BST leaf ordered by idt[], then splays it to the root.
void Ins(int p){
    int x=root,y=0;
    while(x){
        y=x;
        x=(idt[p]>idt[x])?ch[x][1]:ch[x][0];
    }
    x=p;
    fa[x]=y;
    ch[x][0]=ch[x][1]=0;
    if(y){
        if(idt[x]>idt[y])ch[y][1]=x;
        else ch[y][0]=x;
    }else{
        root=x;
    }
    Splay(x,0);
}

// Removes vertex x from the tree and clears its links.
void Del(int x){
    Splay(x,0);
    if(ch[x][0]==0&&ch[x][1]==0){
        root=0;
    }else if(ch[x][0]==0){
        root=ch[x][1]; fa[ch[x][1]]=0;
    }else if(ch[x][1]==0){
        root=ch[x][0]; fa[ch[x][0]]=0;
    }else{
        // Bring x's in-order predecessor up to be x's left child; it then has
        // no right child, so x's right subtree can hang there.
        int p=ch[x][0];
        while(ch[p][1])p=ch[p][1];
        Splay(p,x);
        ch[p][1]=ch[x][1]; fa[ch[x][1]]=p;
        root=p; fa[p]=0;
    }
    fa[x]=ch[x][0]=ch[x][1]=0;
}

// In-order predecessor of the root (0 if none).
int Getpre(){
    int x=ch[root][0];
    if(x==0)return 0;
    while(ch[x][1])x=ch[x][1];
    return x;
}

// In-order successor of the root (0 if none).
int Getsuf(){
    int x=ch[root][1];
    if(x==0)return 0;
    while(ch[x][0])x=ch[x][0];
    return x;
}

// Marked vertex with the smallest idt (0 if the tree is empty).
int Gettop(){
    int x=root;
    if(x==0)return 0;
    while(ch[x][0])x=ch[x][0];
    return x;
}

// Marked vertex with the largest idt (0 if the tree is empty).
int Getbot(){
    int x=root;
    if(x==0)return 0;
    while(ch[x][1])x=ch[x][1];
    return x;
}

// Contribution of x to `ans` given its cyclic neighbours in DFS order:
// d(pre,x) + d(x,suf) - d(pre,suf). Requires x to be the splay root.
// The first vertex's predecessor is the last one and vice versa; when x is
// the only marked vertex both neighbours are x itself and the result is 0.
Lint CycleDelta(int x){
    int pre=Getpre();
    if(!pre)pre=Getbot();
    int suf=Getsuf();
    if(!suf)suf=Gettop();
    return Getd(pre,x)+Getd(x,suf)-Getd(pre,suf);
}

int main(){
    if(scanf("%d%d",&n,&m)!=2)return 0;
    for(int i=1;i<=n-1;++i){
        int x,y,z;
        if(scanf("%d%d%d",&x,&y,&z)!=3)return 0;
        Addedge(x,y,z);
        Addedge(y,x,z);
    }
    Dfs(1,0);
    LCAinit();
    while(m--){
        int x;
        if(scanf("%d",&x)!=1)break;
        if(htr[x]){
            // Unmark: x is removed from between its cyclic neighbours.
            htr[x]=0;
            Splay(x,0);
            ans-=CycleDelta(x);
            Del(x);
        }else{
            // Mark: x is inserted between its cyclic neighbours (Ins leaves it at the root).
            htr[x]=1;
            Ins(x);
            ans+=CycleDelta(x);
        }
        printf("%lld\n",ans);
    }
    return 0;
}