算法模板-线段树

Posted 周先森爱吃素

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了算法模板-线段树相关的知识,希望对你有一定的参考价值。

本文为CSDN博主「maplezys」的原创文章,本文转载自https://blog.csdn.net/qq_41006629/article/details/124131368

简介

线段树(Segment Tree)是稍高级一点的数据结构,它一般用于维护区间信息。线段树是一棵平衡二叉树,其根结点代表着整个区间的信息,越往下的结点代表的区间越小,也就是说,线段树的每一个结点都对应着一条区间(线段)。

如果有一个数组是[1,2,3,4,5,6,7,8],那么它对应的线段树大致如下:

我们从下标1开始存储每个节点(较为方便),这样每个节点 x x x的左儿子节点为 2 x 2x 2x,右孩子为 2 x + 1 2 x + 1 2x+1。假设 x x x结点存储的是区间 [ l e f t , r i g h t ] [left, right] [left,right]的信息, m i d = ⌊ l e f t + r i g h t 2 ⌋ mid=\\left\\lfloor\\fracleft+right2\\right\\rfloor mid=2left+right,那么其左右儿子存储的分别是区间 [ l e f t , m i d ] [left, mid] [left,mid] [ m i d + 1 , r i g h t ] [mid+1, right] [mid+1,right]的信息。可以发现,由于mid的计算方式,因此左节点对应的区间长度要么和右节点相同,要么比之恰好多1。

详解

我们以一道简单的例题为例。

老师想知道从某某同学当中,分数最高的是多少,现在请你编程模拟老师的询问。当然,老师有时候需要更新某位同学的成绩。

输入描述:
输入包括多组测试数据。
每组输入第一行是两个正整数N和M(0 < N <= 30000,0 < M < 5000),分别代表学生的数目和操作的数目。
学生ID编号从1编到N。
第二行包含N个整数,代表这N个学生的初始成绩,其中第i个数代表ID为i的学生的成绩
接下来又M行,每一行有一个字符C(只取‘Q’或‘U’),和两个正整数A,B,当C为'Q'的时候, 表示这是一条询问操作,他询问ID从A到B(包括A,B)的学生当中,成绩最高的是多少
当C为‘U’的时候,表示这是一条更新操作,要求把ID为A的学生的成绩更改为B。

输出描述:
对于每一次询问操作,在一行里面输出最高成绩。

输入样例:
5 7
1 2 3 4 5
Q 1 5
U 3 6
Q 3 4
Q 4 5
U 4 5
U 2 9
Q 1 5

输出例子:
5
6
5
9

对于本题,我们可以分析出,线段树每个节点存储的就是当前区间的最大值,我们以线性表来存储这棵树,递归建树的代码如下。

    def build(pos, left, right):
        if left == right:
            self.value[pos] = self.data[left]
            return
        m = (left + right) >> 1
        l_child, r_child = pos << 1, pos << 1 | 1
        self.build(l_child, left, m)
        self.build(r_child, m + 1, right)
        self.value[pos] = max(self.value[l_child], self.value[r_child])

假如我要更新p结点的值,那么包含p结点的所有区间的结点均需被更新,也是采用递归的形式进行更新,代码如下。

    def update(idx, new_value, pos, left, right):
        if left == right and left == idx:
            self.value[pos] = new_value
            return
        m = (left +  right) >> 1
        if idx <= m: self.update(idx, new_value, pos << 1, left, m)
        if idx > m: self.update(idx, new_value, pos << 1 | 1, m + 1, right)
        self.value[pos] = max(self.value[pos<<1], self.value[pos<<1|1])

那此时如果我们需要查询指定区间的信息呢?依旧是递归查询,我们先贴出代码:

    def query(query_l, query_r, pos, cur_left, cur_right):
        if query_l <= cur_left and query_r >= cur_right: return self.value[pos]
        m = (cur_left + cur_right) >> 1
        l_ans, r_ans = -1, -1
        if query_l <= m: l_ans = self.query(query_l, query_r, pos << 1, cur_left, m)
        if query_r > m: r_ans = self.query(query_l, query_r, pos << 1 | 1, m + 1, cur_right)
        if l_ans == -1: return r_ans
        if r_ans == -1: return l_ans
        return max(l_ans, r_ans)

查询区间的时候,有三种情况。假设需要查询的区间为 [ L , R ] [ L , R ] [L,R] ,若是目标区间覆盖了当前区间,那么当前区间的最大值是需要的,直接返回。若没有完全覆盖,且若 L L L m i d mid mid的左边,那么需要去左孩子处( [ c u r _ l e f t , m i d ] [cur\\_left, mid ] [cur_left,mid])查询目标区间需要的信息,不然该值取-1。同理,若 R R R m i d mid mid的右边,则需要去右孩子处( [ m i d + 1 , c u r _ r i g h t ] [mid+1, cur\\_right] [mid+1,cur_right])查询目标区间需要的信息,不然该值取-1。上述两者中,若有一者为-1,则直接返回另一者。如果都不是-1,则需要返回这二者之间的较大值。

本题的完整代码如下。

class SegmentTree(object):

    def __init__(self, n):
        self.max_num = n
        self.data = [0] * (self.max_num + 5)
        self.value = [0] * (self.max_num * 4 + 5)

    def build(self, pos, left, right):
        if left == right:
            self.value[pos] = self.data[left]
            return
        m = (left + right) >> 1
        l_child, r_child = pos << 1, pos << 1 | 1
        self.build(l_child, left, m)
        self.build(r_child, m + 1, right)
        self.value[pos] = max(self.value[l_child], self.value[r_child])

    def update(self, idx, new_value, pos, left, right):
        if left == right and left == idx:
            self.value[pos] = new_value
            return
        m = (left +  right) >> 1
        if idx <= m: self.update(idx, new_value, pos << 1, left, m)
        if idx > m: self.update(idx, new_value, pos << 1 | 1, m + 1, right)
        self.value[pos] = max(self.value[pos<<1], self.value[pos<<1|1])

    def query(self, query_l, query_r, pos, cur_left, cur_right):
        if query_l <= cur_left and query_r >= cur_right: return self.value[pos]
        m = (cur_left + cur_right) >> 1
        l_ans, r_ans = -1, -1
        if query_l <= m: l_ans = self.query(query_l, query_r, pos << 1, cur_left, m)
        if query_r > m: r_ans = self.query(query_l, query_r, pos << 1 | 1, m + 1, cur_right)
        if l_ans == -1: return r_ans
        if r_ans == -1: return l_ans
        return max(l_ans, r_ans)


class Solution():
    def solve(self):
        n, m = map(int, input().split())
        grades = list(map(int, input().split()))
        asks = [input().split() for _ in range(m)]
        st = SegmentTree(n)
        st.data = [0] + grades
        st.build(1, 1, n)
        for t, a, b in asks:
            a, b = int(a), int(b)
            if t == 'U':
                st.update(a, b, 1, 1, n)
            else:
                print(st.query(a, b, 1, 1, n))


Solution().solve()


代码中线段树申明了4倍的空间,是因为防止越界。比如区间长度为n nn,那么线段树的最后一层便为n nn个结点,那么该线段树的高度便为 ⌈ log ⁡ 2 n ⌉ \\left\\lceil\\log _2 n\\right\\rceil log2n,不难看出, ⌈ log ⁡ 2 n ⌉ ≤ log ⁡ 2 n + 1 \\left\\lceil\\log _2 n\\right\\rceil \\leq \\log _2 n+1 log2nlog2n+1,通过等比数列求和,可得整棵树的节点数量为 1 ∗ ( 1 − 2 x ) 1 − 2 \\frac1 *\\left(1-2^x\\right)1-2 121(12x)其中 x x x为高度,整理可得, 2 log ⁡ 2 n + 1 + 1 − 1 2^\\log _2 n+1+1-1 2log2n+1+11 − 1 −1 1忽略,可得 4 n 4n 4n

总结

线段树结构非常适用于大规模的区间查询问题。

以上是关于算法模板-线段树的主要内容,如果未能解决你的问题,请参考以下文章

模板可持久化线段树 1(主席树)

P3372 模板线段树 1

P3372 模板线段树 1

算法模板-线段树

算法模板——线段树

主席树(模板)