题意:给一棵树,有三个操作:①询问两点$(x,y)$之间的距离②把$x$和原来的父亲断开并连到它的$h$级祖先,作为新父亲最右的儿子③询问与根节点距离为$k$的点中最右的点是哪个点
用出栈入栈序$s_{1\cdots 2n}$来维护整棵树,入栈记$1$出栈记$-1$,那么一个节点$x$的深度就是$\sum\limits_{i=1}^{in_x}s_x$
每个平衡树节点记1.这个节点是出栈还是入栈2.子树和3.最大前缀和4.最小前缀和,那么我们就可以在平衡树上二分找到最右的深度为$d$的节点(注意如果找到的是出栈点应该返回父亲,因为有个$-1$)
对于操作①,把$(in_x,in_y)$提出来,那么这个区间内深度最小的节点就是$lca_{x,y}$
对于操作②,找到那个$h$级祖先,直接序列移动即可
对于操作③,直接找
为了使我的splay不残废就用splay写了一下
注意因为邻接表的性质,加边要倒着加
#include<stdio.h> int ch[200010][2],fa[200010],v[200010],s[200010],mx[200010],mn[200010],h[100010],nex[100010],to[100010],pa[100010],tmp[100010],p[200010],M,rt; void add(int a,int b){ M++; to[M]=b; nex[M]=h[a]; h[a]=M; } void dfs(int x){ M++; p[M]=(x<<1)-1; v[(x<<1)-1]=1; for(int i=h[x];i;i=nex[i]){ pa[to[i]]=x; dfs(to[i]); } M++; p[M]=x<<1; v[x<<1]=-1; } #define ls ch[x][0] #define rs ch[x][1] int max(int a,int b){return a>b?a:b;} int min(int a,int b){return a<b?a:b;} void pushup(int x){ s[x]=s[ls]+s[rs]+v[x]; mx[x]=max(mx[ls],s[ls]+v[x]+max(mx[rs],0)); mn[x]=min(mn[ls],s[ls]+v[x]+min(mn[rs],0)); } int build(int l,int r){ int mid=(l+r)>>1; int&x=p[mid]; if(l<mid){ ls=build(l,mid-1); fa[ls]=x; } if(mid<r){ rs=build(mid+1,r); fa[rs]=x; } pushup(x); return x; } void rot(int x){ int y,z,f,B; y=fa[x]; z=fa[y]; f=(ch[y][0]==x); B=ch[x][f]; fa[x]=z; fa[y]=x; if(B)fa[B]=y; ch[x][f]=y; ch[y][f^1]=B; if(ch[z][0]==y)ch[z][0]=x; if(ch[z][1]==y)ch[z][1]=x; pushup(y); pushup(x); } void splay(int x,int gl){ int y,z; while(fa[x]!=gl){ y=fa[x]; z=fa[y]; if(z!=gl)rot((ch[z][0]==y&&ch[y][0]==x)||(ch[z][1]==y&&ch[y][1]==x)?y:x); rot(x); } } int getdis(int x,int y){ x=(x<<1)-1; y=(y<<1)-1; int dx,dy,dl; splay(x,0); dx=s[ls]+v[x]; splay(y,0); dy=s[ch[y][0]]+v[y]; splay(x,0); splay(y,x); rt=x; dl=min(dx,dy); if(ls==y) dl=min(dl,s[ch[y][0]]+v[y]+mn[ch[y][1]]); else dl=min(dl,s[ls]+v[x]+mn[ch[y][0]]); return dx+dy-(dl<<1); } int find(int x,int d){ if(mx[rs]>=d-s[ls]-v[x]&&mn[rs]<=d-s[ls]-v[x])return find(rs,d-s[ls]-v[x]); if(s[ls]+v[x]==d)return(x&1)?(x+1)>>1:pa[x>>1]; return find(ls,d); } int pre(int x){ splay(x,0); for(x=ls;rs;x=rs); return x; } int nx(int x){ splay(x,0); for(x=rs;ls;x=ls); return x; } void change(int u,int h){ int x=(u<<1)-1,L,R,t; splay(x,0); pa[u]=find(ls,s[ls]+v[x]-h); L=pre(x); R=nx(u<<1); splay(L,0); splay(R,L); t=ch[R][0]; ch[R][0]=0; pushup(R); pushup(L); L=pre(pa[u]<<1); R=(pa[u]<<1); splay(L,0); splay(R,L); ch[R][0]=t; fa[t]=R; pushup(R); pushup(L); rt=L; } #define inf 1000000000 int main(){ mx[0]=-inf; mn[0]=inf; int n,m,i,x,y; scanf("%d%d",&n,&m); for(i=1;i<=n;i++){ scanf("%d",&y); for(x=1;x<=y;x++)scanf("%d",tmp+x); for(x=y;x>0;x--)add(i,tmp[x]); } M=0; dfs(1); rt=build(1,n<<1); while(m--){ scanf("%d%d",&i,&x); if(i!=3)scanf("%d",&y); if(i==1)printf("%d\n",getdis(x,y)); if(i==2)change(x,y); if(i==3)printf("%d\n",find(rt,x+1)); } }