Splay(伸展树分裂树):平衡二叉搜索树中功能最丰富的树

Posted aininot260

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Splay(伸展树分裂树):平衡二叉搜索树中功能最丰富的树相关的知识,希望对你有一定的参考价值。

这是我第一篇对高级数据结构的描述,如有不准确的地方还请指出,谢谢~

调这颗树的代码调的头皮发麻,和线段树根本不是一个难度的。

首先简单地介绍一下这棵平衡BST中的另类

这是一棵没有任何平衡因子的BST,它依靠均摊来达到O(logn)的插入查询和删除复杂度,常数比较大

而且,它的具有其他BST所不具备的,对于子树的任意分裂和合并的功能

下面我从定义讲起,剖析这棵树实现过程中的每一个细节

const int INF=1000000000;
const int maxn=1000005;
int n,m;
int a[maxn];
struct Tree
{
    int fa,ch[2];
    int size;
    int v,sum,mx,lx,rx;
    bool tag,rev;
}t[maxn];
int node[maxn];
int root;
int cnt=0;
queue<int> q;

定义部分,注意INF如果太大可能会有越界的危险,所以这里长了个教训

树中包含n个点,对于树中的每一个点,有如下定义:

fa里存父节点的ID,ch[]里存左右孩子的ID,size里存该节点所形成子树包含的总节点个数

v表示节点的值,sum表示该节点所形成子树的值的和

mx,lx,rx是需要维护的一个关系,其中:

mx是当前子树的最大子串和

lx是当前子树以左端点为起点,向右延伸的最大子串和,rx同理(这里的实现思路类似于SCOI2010的线段树的那道题)

tag是区间重置标记,rev是区间翻转标记,区间即子树

q是一个垃圾回收的缓冲数组,将在后面介绍

建树部分:

void buildtree(int l,int r,int f)
{
    if(l>r)
        return;
    int mid=(l+r)>>1;
    int o=node[mid];
    int last=node[f];
    if(l==r)
    {
        t[o].sum=a[l];
        t[o].size=1;
        t[o].tag=t[o].rev=0;
        if(a[l]>=0)
            t[o].lx=t[o].rx=t[o].mx=a[l];
        else
        {
            t[o].lx=t[o].rx=0;
            t[o].mx=a[l];
        }
    }
    else
    {
        buildtree(l,mid-1,mid);
        buildtree(mid+1,r,mid);
    }
    t[o].v=a[mid];
    t[o].fa=last;
    update(o);
    t[last].ch[mid>=f]=o;
}

依然是递归建树,由于是均摊的,根节点只要从中间随便找一个就好了

这里可以类比线段树的建树过程,比线段树稍微复杂一些

建树过程中用到了一个函数update:

void update(int o)
{
    int l=t[o].ch[0],r=t[o].ch[1];
    t[o].sum=t[l].sum+t[r].sum+t[o].v;
    t[o].size=t[l].size+t[r].size+1;
    t[o].mx=max(t[l].mx,t[r].mx);
    t[o].mx=max(t[o].mx,t[l].rx+t[o].v+t[r].lx);
    t[o].lx=max(t[l].lx,t[l].sum+t[o].v+t[r].lx);
    t[o].rx=max(t[r].rx,t[r].sum+t[o].v+t[l].rx);
}

它的作用是随着子树的建立随时去更新子树根节点所维护的那些值

接下来介绍插入操作,它可以在特定位置插入tot个元素:

void insert(int k,int tot)
{
    for(int i=1;i<=tot;i++)
        cin>>a[i];
    for(int i=1;i<=tot;i++)
    if(!q.empty())
    {
        node[i]=q.front();
        q.pop();
    }
    else
    {
        node[i]=++cnt;
    }

    buildtree(1,tot,0);
    int z=node[(1+tot)>>1];
    int x=find(root,k+1);
    int y=find(root,k+2);
    splay(x,root);
    splay(y,t[x].ch[1]);
    t[z].fa=y;
    t[y].ch[0]=z;
    update(y);
    update(x);
}

大致过程如下。用着tot个元素生成一棵新的子树,子树上随便找一个点作为根节点

之后对原树伸展操作,首先找到k+1与k+2位置的节点ID

然后将k+1位置的节点伸展到根节点位置,将k+2位置的节点伸展到新根节点(就是刚才伸展过来的)的右子树位置

然后直接把新生产的子树的根节点接在k+2位置节点的左子树位置上就好了,形成之字形结构,有三层,LRL,分别是XYZ

有了这种插入的思路,我们可以把任意的东西接在任意的位置,这就是合并操作的原理

在介绍删除操作之前,我们先介绍一下split操作:

int split(int k,int tot)
{
    int x=find(root,k),y=find(root,k+tot+1);
    splay(x,root);
    splay(y,t[x].ch[1]);
    return t[y].ch[0];
}

其作用是把k位置之后的tot个节点形成的子树拎出来,放在之前所描述的之字形结构的Z位置,然后就能随意处理这个子树了,你可以把它删掉,也可以再把它接到其他位置

就比如刚才在描述插入操作的时候的那个Z位置,这样就实现了把区间的一段挪到区间的另一个位置去。分裂合并操作也就隐式实现了

既然现在这个split出来的子树已经和原来的树没有任何关系了,我就开一个新变量存一下它的根节点,然后把它和原树的联系切断,为了介绍切断,我们这里引出删除函数:

void erase(int k,int tot)
{
    int x=split(k,tot);
    int y=t[x].fa;
    rec(x);
    t[y].ch[0]=0;
    update(y);
    update(t[y].fa);
}

其删除的实现中有一个回收节点空间的函数rec,我们如果此时不做回收,而是像刚才那样,开一个新变量存这个拎出来的节点,就能既切断联系又拎出来了一棵子树。美滋滋

接下来我们稍微改动一下insert函数,把其中z替换成我们拎出来的这个子树

综上所述,区间的任意分裂合并都可以实现了

其实除了分裂和合并操作之外,就是BST的最基本的套路了

我们接着介绍,刚才引出了好几个陌生的函数,先说离咱们最近的那个,rec,回收子树节点函数:

void rec(int x)
{
    if(!x)
        return;
    int l=t[x].ch[0],r=t[x].ch[1];
    rec(l);rec(r);q.push(x);
    t[x].fa=t[x].ch[0]=t[x].ch[1]=0;
    t[x].tag=t[x].rev=0;
}

这个函数存在的意义就是重复利用Tree结构的空间,避免因数组连续性造成的空间浪费,这里学了一招,应该很实用

然后我们介绍find函数,这个函数就是用来找指定位置的节点ID的

int find(int o,int rk)
{
    pushdown(o);
    int l=t[o].ch[0],r=t[o].ch[1];
    if(t[l].size+1==rk)
        return o;
    if(t[l].size>=rk)
        return find(l,rk);
    return find(r,rk-t[l].size-1);
}

在o形成的子树中找第k个节点的ID

这里又引出了pushdown函数

void pushdown(int x)
{
    int l=t[x].ch[0];
    int r=t[x].ch[1];
    if(t[x].tag)
    {
        t[x].rev=t[x].tag=0;
        if(l)
        {
            t[l].tag=1;
            t[l].v=t[x].v;
            t[l].sum=t[x].v*t[l].size;
        }
        if(r)
        {
            t[r].tag=1;
            t[r].v=t[x].v;
            t[r].sum=t[x].v*t[r].size;
        }
        if(t[x].v>=0)
        {
            if(l)
                t[l].lx=t[l].rx=t[l].mx=t[l].sum;
            if(r)
                t[r].lx=t[r].rx=t[r].mx=t[r].sum;
        }
        else
        {
            if(l)
                t[l].lx=t[l].rx=0,t[l].mx=t[x].v;
            if(r)
                t[r].lx=t[r].rx=0,t[r].mx=t[x].v;
        }
    }
    if(t[x].rev)
    {
        t[x].rev^=1;t[l].rev^=1;t[r].rev^=1;
        swap(t[l].lx,t[l].rx),swap(t[r].lx,t[r].rx);
        swap(t[l].ch[0],t[l].ch[1]);
        swap(t[r].ch[0],t[r].ch[1]);
    }
}

它是用来处理节点上的翻转和重置标记的

说了半天都没有说splay函数,下面介绍

void splay(int x,int &k)
{
    while(x!=k)
    {
        int y=t[x].fa;
        int z=t[y].fa;
        if(y!=k)
        {
            if(t[y].ch[0]==x^t[z].ch[0]==y)
                rotate(x,k);
            else
                rotate(y,k);
        }
        rotate(x,k);
    }
}

这个函数可以把一个节点伸展到指定根节点的位置(可以是子树根),这个时候这个节点就是新的根节点了

这是整个伸展树最核心的函数,一定要理解

然后就是旋转操作,其实刚开始旋转我是不会的,后来我是去看了AVL的四种旋转方式才过来看的Splay

我们先回顾一下AVL的四种旋转方式,什么是AVL(一种特别特别正经的平衡树,不想用)

LL,RR,LR,RL,左旋就是往左转,右旋就是往右转,转的时候要搞清楚转哪条边,把谁转过去把谁转过来,转完之后有时候要拆接,怎么处理。搞清楚之后,理解Splay的一字形旋转和之字形旋转就容易多了

Splay双旋(只有双旋才叫Splay)的目的不是为了平衡,而是为了协助完成Splay操作,一次Splay就把沿途的所有的点都转了一遍

下面给出旋转的函数,写的比较硬核:

void rotate(int x,int &k)
{
    int y=t[x].fa;
    int z=t[y].fa;
    int l=(t[y].ch[1]==x);
    int r=l^1;
    if(y==k)
        k=x;
    else
        t[z].ch[t[z].ch[1]==y]=x;
    t[t[x].ch[r]].fa=y;
    t[y].fa=x;
    t[x].fa=z;
    t[y].ch[l]=t[x].ch[r];
    t[x].ch[r]=y;
    update(y);
    update(x);
}

后面的就比较水了,重点都已经介绍完了,再说说修改函数,把tot个树都修改为指定的值

void modify(int k,int tot,int val)
{
    int x=split(k,tot);
    int y=t[x].fa;
    t[x].v=val,t[x].tag=1,t[x].sum=t[x].size*val;
    if(val>=0)
        t[x].lx=t[x].rx=t[x].mx=t[x].sum;
    else
        t[x].lx=t[x].rx=0,t[x].mx=val;
    update(y);
    update(t[y].fa);
}

原理很简单,打标记,更新节点参数

然后是区间翻转:

void rever(int k,int tot)
{
    int x=split(k,tot);
    int y=t[x].fa;
    if(!t[x].tag)
    {
        t[x].rev^=1;
        swap(t[x].ch[0],t[x].ch[1]);
        swap(t[x].lx,t[x].rx);
        update(y);
        update(t[y].fa);
    }
}

也是打标记,反正有处理标记的函数,不怕,这里和线段树的lazy Tag的原理类似,只不过这里的树实在是很复杂

再说说查询:

void query(int k,int tot)
{
    int x=split(k,tot);
    cout<<t[x].sum<<endl;
}

只查个sum太水了,我们可以中序遍历是不是,BST家族都可以的

特别特别特别要注意的一点,在执行分裂合并操作的时候,如果你没有保证分裂合并能够满足BST性质,那么因为你的分裂合并,这棵Splay可能就不再是一棵BST了

你要特别小心这一点,仔细审题,如果只让你分裂合并,那好说,如果涉及到了查询什么的,一定要看清楚题意才行

最后,我们给出debug了三个小时的Template,感谢黄学长~

  1 #include<queue>
  2 #include<iostream>
  3 #include<cstdlib>
  4 #include<algorithm>
  5 using namespace std;
  6 const int INF=1000000000;
  7 const int maxn=1000005;
  8 int n,m;
  9 int a[maxn];
 10 struct Tree
 11 {
 12     int fa,ch[2];
 13     int size;
 14     int v,sum,mx,lx,rx;
 15     bool tag,rev;
 16 }t[maxn];
 17 int node[maxn];
 18 int root;
 19 int cnt=0;
 20 queue<int> q;
 21 void update(int o)
 22 {
 23     int l=t[o].ch[0],r=t[o].ch[1];
 24     t[o].sum=t[l].sum+t[r].sum+t[o].v;
 25     t[o].size=t[l].size+t[r].size+1;
 26     t[o].mx=max(t[l].mx,t[r].mx);
 27     t[o].mx=max(t[o].mx,t[l].rx+t[o].v+t[r].lx);
 28     t[o].lx=max(t[l].lx,t[l].sum+t[o].v+t[r].lx);
 29     t[o].rx=max(t[r].rx,t[r].sum+t[o].v+t[l].rx);
 30 }
 31 void pushdown(int x)
 32 {
 33     int l=t[x].ch[0];
 34     int r=t[x].ch[1];
 35     if(t[x].tag)
 36     {
 37         t[x].rev=t[x].tag=0;
 38         if(l)
 39         {
 40             t[l].tag=1;
 41             t[l].v=t[x].v;
 42             t[l].sum=t[x].v*t[l].size;
 43         }
 44         if(r)
 45         {
 46             t[r].tag=1;
 47             t[r].v=t[x].v;
 48             t[r].sum=t[x].v*t[r].size;
 49         }
 50         if(t[x].v>=0)
 51         {
 52             if(l)
 53                 t[l].lx=t[l].rx=t[l].mx=t[l].sum;
 54             if(r)
 55                 t[r].lx=t[r].rx=t[r].mx=t[r].sum;
 56         }
 57         else
 58         {
 59             if(l)
 60                 t[l].lx=t[l].rx=0,t[l].mx=t[x].v;
 61             if(r)
 62                 t[r].lx=t[r].rx=0,t[r].mx=t[x].v;
 63         }
 64     }
 65     if(t[x].rev)
 66     {
 67         t[x].rev^=1;t[l].rev^=1;t[r].rev^=1;
 68         swap(t[l].lx,t[l].rx),swap(t[r].lx,t[r].rx);
 69         swap(t[l].ch[0],t[l].ch[1]);
 70         swap(t[r].ch[0],t[r].ch[1]);
 71     }
 72 }
 73 void rotate(int x,int &k)
 74 {
 75     int y=t[x].fa;
 76     int z=t[y].fa;
 77     int l=(t[y].ch[1]==x);
 78     int r=l^1;
 79     if(y==k)
 80         k=x;
 81     else
 82         t[z].ch[t[z].ch[1]==y]=x;
 83     t[t[x].ch[r]].fa=y;
 84     t[y].fa=x;
 85     t[x].fa=z;
 86     t[y].ch[l]=t[x].ch[r];
 87     t[x].ch[r]=y;
 88     update(y);
 89     update(x);
 90 }
 91 void splay(int x,int &k)
 92 {
 93     while(x!=k)
 94     {
 95         int y=t[x].fa;
 96         int z=t[y].fa;
 97         if(y!=k)
 98         {
 99             if(t[y].ch[0]==x^t[z].ch[0]==y)
100                 rotate(x,k);
101             else
102                 rotate(y,k);
103         }
104         rotate(x,k);
105     }
106 }
107 int find(int o,int rk)
108 {
109     pushdown(o);
110     int l=t[o].ch[0],r=t[o].ch[1];
111     if(t[l].size+1==rk)
112         return o;
113     if(t[l].size>=rk)
114         return find(l,rk);
115     return find(r,rk-t[l].size-1);
116 }
117 void rec(int x)
118 {
119     if(!x)
120         return;
121     int l=t[x].ch[0],r=t[x].ch[1];
122     rec(l);rec(r);q.push(x);
123     t[x].fa=t[x].ch[0]=t[x].ch[1]=0;
124     t[x].tag=t[x].rev=0;
125 }
126 int split(int k,int tot)
127 {
128     int x=find(root,k),y=find(root,k+tot+1);
129     splay(x,root);
130     splay(y,t[x].ch[1]);
131     return t[y].ch[0];
132 }
133 void query(int k,int tot)
134 {
135     int x=split(k,tot);
136     cout<<t[x].sum<<endl;
137 }
138 void modify(int k,int tot,int val)
139 {
140     int x=split(k,tot);
141     int y=t[x].fa;
142     t[x].v=val,t[x].tag=1,t[x].sum=t[x].size*val;
143     if(val>=0)
144         t[x].lx=t[x].rx=t[x].mx=t[x].sum;
145     else
146         t[x].lx=t[x].rx=0,t[x].mx=val;
147     update(y);
148     update(t[y].fa);
149 }
150 void rever(int k,int tot)
151 {
152     int x=split(k,tot);
153     int y=t[x].fa;
154     if(!t[x].tag)
155     {
156         t[x].rev^=1;
157         swap(t[x].ch[0],t[x].ch[1]);
158         swap(t[x].lx,t[x].rx);
159         update(y);
160         update(t[y].fa);
161     }
162 }
163 void erase(int k,int tot)
164 {
165     int x=split(k,tot);
166     int y=t[x].fa;
167     rec(x);
168     t[y].ch[0]=0;
169     update(y);
170     update(t[y].fa);
171 }
172 void buildtree(int l,int r,int f)
173 {
174     if(l>r)
175         return;
176     int mid=(l+r)>>1;
177     int o=node[mid];
178     int last=node[f];
179     if(l==r)
180     {
181         t[o].sum=a[l];
182         t[o].size=1;
183         t[o].tag=t[o].rev=0;
184         if(a[l]>=0)
185             t[o].lx=t[o].rx=t[o].mx=a[l];
186         else
187         {
188             t[o].lx=t[o].rx=0;
189             t[o].mx=a[l];
190         }
191     }
192     else
193     {
194         buildtree(l,mid-1,mid);
195         buildtree(mid+1,r,mid);
196     }
197     t[o].v=a[mid];
198     t[o].fa=last;
199     update(o);
200     t[last].ch[mid>=f]=o;
201 }
202 void insert(int k,int tot)
203 {
204     for(int i=1;i<=tot;i++)
205         cin>>a[i];
206     for(int i=1;i<=tot;i++)
207     if(!q.empty())
208     {
209         node[i]=q.front();
210         q.pop();
211     }
212     else
213     {
214         node[i]=++cnt;
215     }
216 
217     buildtree(1,tot,0);
218     int z=node[(1+tot)>>1];
219     int x=find(root,k+1);
220     int y=find(root,k+2);
221     splay(x,root);
222     splay(y,t[x].ch[1]);
223     t[z].fa=y;
224     t[y].ch[0]=z;
225     update(y);
226     update(x);
227 }
228 int main()
229 {
230     std::ios::sync_with_stdio(false);
231     cin>>n>>m;
232     t[0].mx=a[1]=a[n+2]=-INF;
233     for(int i=1;i<=n;i++)
234         cin>>a[i+1];
235     for(int i=1;i<=n+2;i++)
236         node[i]=i;
237     buildtree(1,n+2,0);
238     root=(n+3)>>1;
239     cnt=n+2;
240     int k,tot,val;
241     char ch[15];
242     while(m--)
243     {
244         cin>>ch;
245         if(ch[0]!=M||ch[2]!=X)
246             cin>>k>>tot;
247         if(ch[0]==I)
248             insert(k,tot);
249         if(ch[0]==D)
250             erase(k,tot);
251         if(ch[0]==M)
252         {
253             if(ch[2]==X)
254                 cout<<t[root].mx<<endl;
255             else
256                 cin>>val,modify(k,tot,val);    
257         }
258         if(ch[0]==R)
259             rever(k,tot);
260         if(ch[0]==G)
261             query(k,tot);
262     }
263     return 0;
264 }

最后吐槽一下,这道题是NOI2005年的数列操作,也是我有生以来做的第二道NOI题目(这时候立刻想到了第一道是什么)

能够在考场上完美实现一个这么个东西,平时要付出多少努力可想而知,不是谁都可以去清华的

以上是关于Splay(伸展树分裂树):平衡二叉搜索树中功能最丰富的树的主要内容,如果未能解决你的问题,请参考以下文章

Splay伸展树入门(单点操作,区间维护)

Splay Tree(伸展树)

Codeforces 675D Tree Construction Splay伸展树

算法学习:伸展树(splay)

伸展树(Splay Tree)

伸展树(splay tree)