sol
用set维护有宝物的点集。
可以证明行走路径\(a[1],a[2]...a[n]\)一定是按照点的dfs序排列。
因为\(dist(u,v)=dep[u]+dep[v]+2*dep[lca(u,v)]\),dfs序相邻可以最小化\(dep[lca(u,v)]\)。
插入要加上两个新贡献并减去一个旧贡献,删除也是减去两个旧贡献并加上一个新贡献。
code
#include<cstdio>
#include<algorithm>
#include<set>
using namespace std;
#define ll long long
int gi()
{
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
const int N = 1e5+5;
struct edge{int to,next,w;}a[N<<1];
int n,m,head[N],cnt,fa[N],dep[N],sz[N],son[N],top[N],dfn[N];ll dis[N],ans;
struct node{
int u;
bool operator < (const node &b) const
{return dfn[u]<dfn[b.u];}
};
set<node>S;
void link(int u,int v,int w){a[++cnt]=(edge){v,head[u],w};head[u]=cnt;}
void dfs1(int u,int f)
{
fa[u]=f;dep[u]=dep[f]+1;sz[u]=1;
for (int e=head[u];e;e=a[e].next)
{
int v=a[e].to;if (v==f) continue;
dis[v]=dis[u]+a[e].w;
dfs1(v,u);
sz[u]+=sz[v];if (sz[v]>sz[son[u]]) son[u]=v;
}
}
void dfs2(int u,int up)
{
top[u]=up;dfn[u]=++cnt;
if (son[u]) dfs2(son[u],up);
for (int e=head[u];e;e=a[e].next)
if (a[e].to!=fa[u]&&a[e].to!=son[u])
dfs2(a[e].to,a[e].to);
}
int lca(int u,int v)
{
while (top[u]^top[v])
{
if (dep[top[u]]<dep[top[v]]) swap(u,v);
u=fa[top[u]];
}
return dep[u]<dep[v]?u:v;
}
ll dist(int u,int v){return dis[u]+dis[v]-dis[lca(u,v)]*2;}
set<node>::iterator pre(set<node>::iterator t)
{
if (t==S.begin()) t=S.end();
t--;
return t;
}
set<node>::iterator nxt(set<node>::iterator t)
{
t++;
if (t==S.end()) t=S.begin();
return t;
}
set<node>::iterator t,tl,tr;
int main()
{
n=gi();m=gi();
for (int i=1,u,v,w;i<n;++i)
{
u=gi();v=gi();w=gi();
link(u,v,w);link(v,u,w);
}
dfs1(1,0);cnt=0;dfs2(1,1);
while (m--)
{
int u=gi();
if (S.find((node){u})==S.end())
{
S.insert((node){u});
t=S.find((node){u});
tl=pre(t);tr=nxt(t);
ans+=dist(u,(*tl).u)+dist(u,(*tr).u);
ans-=dist((*tl).u,(*tr).u);
}
else
{
if (S.size()==1) {S.erase((node){u});puts("0");continue;}
t=S.find((node){u});
tl=pre(t);tr=nxt(t);
ans-=dist(u,(*tl).u)+dist(u,(*tr).u);
ans+=dist((*tl).u,(*tr).u);
S.erase((node){u});
}
printf("%lld\n",ans);
}
return 0;
}