算法模板-线段树
Posted 周先森爱吃素
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了算法模板-线段树相关的知识,希望对你有一定的参考价值。
本文为CSDN博主「maplezys」的原创文章,本文转载自https://blog.csdn.net/qq_41006629/article/details/124131368。
简介
线段树(Segment Tree)是稍高级一点的数据结构,它一般用于维护区间信息。线段树是一棵平衡二叉树,其根结点代表着整个区间的信息,越往下的结点代表的区间越小,也就是说,线段树的每一个结点都对应着一条区间(线段)。
如果有一个数组是[1,2,3,4,5,6,7,8],那么它对应的线段树大致如下:
我们从下标1开始存储每个节点(较为方便),这样每个节点
x
x
x的左儿子节点为
2
x
2x
2x,右孩子为
2
x
+
1
2 x + 1
2x+1。假设
x
x
x结点存储的是区间
[
l
e
f
t
,
r
i
g
h
t
]
[left, right]
[left,right]的信息,
m
i
d
=
⌊
l
e
f
t
+
r
i
g
h
t
2
⌋
mid=\\left\\lfloor\\fracleft+right2\\right\\rfloor
mid=⌊2left+right⌋,那么其左右儿子存储的分别是区间
[
l
e
f
t
,
m
i
d
]
[left, mid]
[left,mid]和
[
m
i
d
+
1
,
r
i
g
h
t
]
[mid+1, right]
[mid+1,right]的信息。可以发现,由于mid的计算方式,因此左节点对应的区间长度要么和右节点相同,要么比之恰好多1。
详解
我们以一道简单的例题为例。
老师想知道从某某同学当中,分数最高的是多少,现在请你编程模拟老师的询问。当然,老师有时候需要更新某位同学的成绩。
输入描述:
输入包括多组测试数据。
每组输入第一行是两个正整数N和M(0 < N <= 30000,0 < M < 5000),分别代表学生的数目和操作的数目。
学生ID编号从1编到N。
第二行包含N个整数,代表这N个学生的初始成绩,其中第i个数代表ID为i的学生的成绩
接下来又M行,每一行有一个字符C(只取‘Q’或‘U’),和两个正整数A,B,当C为'Q'的时候, 表示这是一条询问操作,他询问ID从A到B(包括A,B)的学生当中,成绩最高的是多少
当C为‘U’的时候,表示这是一条更新操作,要求把ID为A的学生的成绩更改为B。
输出描述:
对于每一次询问操作,在一行里面输出最高成绩。
输入样例:
5 7
1 2 3 4 5
Q 1 5
U 3 6
Q 3 4
Q 4 5
U 4 5
U 2 9
Q 1 5
输出例子:
5
6
5
9
对于本题,我们可以分析出,线段树每个节点存储的就是当前区间的最大值,我们以线性表来存储这棵树,递归建树的代码如下。
def build(pos, left, right):
if left == right:
self.value[pos] = self.data[left]
return
m = (left + right) >> 1
l_child, r_child = pos << 1, pos << 1 | 1
self.build(l_child, left, m)
self.build(r_child, m + 1, right)
self.value[pos] = max(self.value[l_child], self.value[r_child])
假如我要更新p结点的值,那么包含p结点的所有区间的结点均需被更新,也是采用递归的形式进行更新,代码如下。
def update(idx, new_value, pos, left, right):
if left == right and left == idx:
self.value[pos] = new_value
return
m = (left + right) >> 1
if idx <= m: self.update(idx, new_value, pos << 1, left, m)
if idx > m: self.update(idx, new_value, pos << 1 | 1, m + 1, right)
self.value[pos] = max(self.value[pos<<1], self.value[pos<<1|1])
那此时如果我们需要查询指定区间的信息呢?依旧是递归查询,我们先贴出代码:
def query(query_l, query_r, pos, cur_left, cur_right):
if query_l <= cur_left and query_r >= cur_right: return self.value[pos]
m = (cur_left + cur_right) >> 1
l_ans, r_ans = -1, -1
if query_l <= m: l_ans = self.query(query_l, query_r, pos << 1, cur_left, m)
if query_r > m: r_ans = self.query(query_l, query_r, pos << 1 | 1, m + 1, cur_right)
if l_ans == -1: return r_ans
if r_ans == -1: return l_ans
return max(l_ans, r_ans)
查询区间的时候,有三种情况。假设需要查询的区间为 [ L , R ] [ L , R ] [L,R] ,若是目标区间覆盖了当前区间,那么当前区间的最大值是需要的,直接返回。若没有完全覆盖,且若 L L L 在 m i d mid mid的左边,那么需要去左孩子处( [ c u r _ l e f t , m i d ] [cur\\_left, mid ] [cur_left,mid])查询目标区间需要的信息,不然该值取-1。同理,若 R R R在 m i d mid mid的右边,则需要去右孩子处( [ m i d + 1 , c u r _ r i g h t ] [mid+1, cur\\_right] [mid+1,cur_right])查询目标区间需要的信息,不然该值取-1。上述两者中,若有一者为-1,则直接返回另一者。如果都不是-1,则需要返回这二者之间的较大值。
本题的完整代码如下。
class SegmentTree(object):
def __init__(self, n):
self.max_num = n
self.data = [0] * (self.max_num + 5)
self.value = [0] * (self.max_num * 4 + 5)
def build(self, pos, left, right):
if left == right:
self.value[pos] = self.data[left]
return
m = (left + right) >> 1
l_child, r_child = pos << 1, pos << 1 | 1
self.build(l_child, left, m)
self.build(r_child, m + 1, right)
self.value[pos] = max(self.value[l_child], self.value[r_child])
def update(self, idx, new_value, pos, left, right):
if left == right and left == idx:
self.value[pos] = new_value
return
m = (left + right) >> 1
if idx <= m: self.update(idx, new_value, pos << 1, left, m)
if idx > m: self.update(idx, new_value, pos << 1 | 1, m + 1, right)
self.value[pos] = max(self.value[pos<<1], self.value[pos<<1|1])
def query(self, query_l, query_r, pos, cur_left, cur_right):
if query_l <= cur_left and query_r >= cur_right: return self.value[pos]
m = (cur_left + cur_right) >> 1
l_ans, r_ans = -1, -1
if query_l <= m: l_ans = self.query(query_l, query_r, pos << 1, cur_left, m)
if query_r > m: r_ans = self.query(query_l, query_r, pos << 1 | 1, m + 1, cur_right)
if l_ans == -1: return r_ans
if r_ans == -1: return l_ans
return max(l_ans, r_ans)
class Solution():
def solve(self):
n, m = map(int, input().split())
grades = list(map(int, input().split()))
asks = [input().split() for _ in range(m)]
st = SegmentTree(n)
st.data = [0] + grades
st.build(1, 1, n)
for t, a, b in asks:
a, b = int(a), int(b)
if t == 'U':
st.update(a, b, 1, 1, n)
else:
print(st.query(a, b, 1, 1, n))
Solution().solve()
代码中线段树申明了4倍的空间,是因为防止越界。比如区间长度为n nn,那么线段树的最后一层便为n nn个结点,那么该线段树的高度便为 ⌈ log 2 n ⌉ \\left\\lceil\\log _2 n\\right\\rceil ⌈log2n⌉,不难看出, ⌈ log 2 n ⌉ ≤ log 2 n + 1 \\left\\lceil\\log _2 n\\right\\rceil \\leq \\log _2 n+1 ⌈log2n⌉≤log2n+1,通过等比数列求和,可得整棵树的节点数量为 1 ∗ ( 1 − 2 x ) 1 − 2 \\frac1 *\\left(1-2^x\\right)1-2 1−21∗(1−2x)其中 x x x为高度,整理可得, 2 log 2 n + 1 + 1 − 1 2^\\log _2 n+1+1-1 2log2n+1+1−1, − 1 −1 −1忽略,可得 4 n 4n 4n。
总结
线段树结构非常适用于大规模的区间查询问题。
以上是关于算法模板-线段树的主要内容,如果未能解决你的问题,请参考以下文章