树上莫队和普通的序列莫队很像,我们把树进行dfs,然后存一个长度为2n的括号序列,就是一个点进去当作左括号,出来当作右括号,然后如果访问从u到v路径,我们可以转化成括号序列的区间,记录x进去的时候编号为f[x],出来时为g[x],然后分类讨论一下(f[u]<f[v]),如果u和v的lca不是u,那么就是从g[u]到f[v],否则就是lca的f到另一个点的f,(可以自己试一下,中间过程没有用的点正好就抵消掉了)这里要注意一下,从g[u]到f[v]的时候我们会少掉lca这个点,特殊处理一下即可,然后按照普通莫队排一下序,暴力就行了。 —— by VANE
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N=100005; int n,m,cnt1,cnt2,tot,clk,f[N],g[N]; vector<int> M[N]; int id[N<<1],blg[N<<1]; int bin[25],pos[N],fa[N][17],c[N],d[N]; int v[N],w[N],last[N],u[N]; bool vis[N]; struct node { int l,r,t,id; }a[N],b[N]; ll ans[N],sum; void dfs(int x) { f[x]=++clk;id[clk]=x; for(int i=1;bin[i]<=d[x];++i) fa[x][i]=fa[fa[x][i-1]][i-1]; for(int i=0;i<M[x].size();++i) { int y=M[x][i]; if(y!=fa[x][0]) { fa[y][0]=x; d[y]=d[x]+1; dfs(y); } } g[x]=++clk; id[clk]=x; } int lca(int x,int y) { if(d[x]<d[y]) swap(x,y); int tmp=d[x]-d[y]; for(int i=0;bin[i]<=tmp;++i) if(tmp&bin[i]) x=fa[x][i]; if(x==y) return x; for(int i=16;i>=0;--i) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0]; } bool cmp(node x,node y) { if(blg[x.l]<blg[y.l]) return 1; if(blg[x.l]==blg[y.l]&&blg[x.r]<blg[y.r]) return 1; if(blg[x.l]==blg[y.l]&&blg[x.r]==blg[y.r]) return x.t<y.t; return 0; } void modify(int x) { if(vis[x]) sum-=1ll*v[c[x]]*w[u[c[x]]--]; else sum+=1ll*v[c[x]]*w[++u[c[x]]]; vis[x]^=1; } void change(int x,int y) { if(vis[x]) {modify(x);c[x]=y;modify(x);} else c[x]=y; } int main() { int cas; scanf("%d%d%d",&n,&m,&cas); bin[0]=1;for(int i=1;i<=17;++i) bin[i]=bin[i-1]<<1; for(int i=1;i<=m;++i) scanf("%d",v+i); for(int i=1;i<=n;++i) scanf("%d",w+i); for(int i=1;i<n;++i) { int l,r;scanf("%d%d",&l,&r); M[l].push_back(r); M[r].push_back(l); } for(int i=1;i<=n;++i) scanf("%d",c+i),last[i]=c[i]; int sz=pow(n,2.0/3); dfs(1); for(int i=1;i<=clk;++i) blg[i]=(i-1)/sz; while(cas--) { int l,r,t; scanf("%d%d%d",&t,&l,&r); if(t) { if(f[l]>f[r]) swap(l,r); a[++cnt1].r=f[r];a[cnt1].t=cnt2; a[cnt1].id=cnt1; a[cnt1].l=(lca(l,r)==l)?f[l]:g[l]; } else { b[++cnt2].l=l;b[cnt2].t=last[l]; last[l]=b[cnt2].r=r; } } sort(a+1,a+1+cnt1,cmp); int l=1,r=0,t=1; for(int i=1;i<=cnt1;++i) { for(;t<=a[i].t;++t) change(b[t].l,b[t].r); for(;t>a[i].t;--t) change(b[t].l,b[t].t); while(l>a[i].l) modify(id[--l]); while(l<a[i].l) modify(id[l++]); while(r>a[i].r) modify(id[r--]); while(r<a[i].r) modify(id[++r]); int x=id[l],y=id[r],tmp=lca(x,y); if(x!=tmp&&y!=tmp) {modify(tmp);ans[a[i].id]=sum;modify(tmp);} else ans[a[i].id]=sum; } for(int i=1;i<=cnt1;++i) printf("%lld\n",ans[i]); }