Segment Tree
Reference: AtCoder library https://atcoder.github.io/ac-library/production/document_en/index.html
A segment tree is a data structure for monoids \( (S, \cdot : S \times S \rightarrow S, e \in S) \). A monoid is an algebraic structure that satisfies the following conditions:
- \(\cdot\) is associative. That is, \( (a \cdot b) \cdot c = a \cdot (b \cdot c) \) for all \( a, b, c \in S \).
- There is an identity element \(e\) such that \( a \cdot e = e \cdot a = a \) for all \( a \in S \).
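To make the two conditions concrete, the following sketch spot-checks them on a handful of sample values. The helper name `check_monoid_laws` is hypothetical and written only for this illustration; passing such a finite check does not prove the laws in general, but a failure does show that an operator is unsuitable.

```rust
/// Spot-checks the monoid laws for `op` and `e` on the given samples.
/// `check_monoid_laws` is a hypothetical helper, not part of any library.
fn check_monoid_laws<T: Clone + PartialEq>(
    samples: &[T],
    e: T,
    op: impl Fn(T, T) -> T,
) -> bool {
    for a in samples {
        // Identity: a * e == e * a == a
        if op(a.clone(), e.clone()) != *a || op(e.clone(), a.clone()) != *a {
            return false;
        }
        for b in samples {
            for c in samples {
                // Associativity: (a * b) * c == a * (b * c)
                let left = op(op(a.clone(), b.clone()), c.clone());
                let right = op(a.clone(), op(b.clone(), c.clone()));
                if left != right {
                    return false;
                }
            }
        }
    }
    true
}
```

Under this check, integer addition with `e = 0`, `min` with `e = i64::MAX`, and string concatenation with the empty string all pass, while subtraction with `e = 0` fails because `0 - a` is not `a`.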
Given an array \(A\) of length \(n\) whose elements belong to the monoid \(S\) described above, a segment tree on it can process each of the following queries in \(O(\log{n})\) time:
- Update an element
- Calculate the product of the elements of an interval
assuming that calculating the product of two elements takes \(O(1)\) time.
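For comparison, the same two queries can be answered without any tree by folding over the interval, at \(O(n)\) time per product. The sketch below is a hypothetical baseline; the name `NaiveMonoidArray` and its methods are assumptions made for this illustration. It is mainly useful for pinning down what the queries mean, or for cross-checking a segment tree in tests.

```rust
/// Hypothetical brute-force baseline: it answers the same queries as a
/// segment tree, but `prod` folds over the interval in O(r - l) time.
struct NaiveMonoidArray<T, O> {
    data: Vec<T>,
    e: T,
    op: O,
}

impl<T: Copy, O: Fn(T, T) -> T> NaiveMonoidArray<T, O> {
    fn new(data: Vec<T>, e: T, op: O) -> Self {
        Self { data, e, op }
    }

    /// Update an element in O(1).
    fn set(&mut self, i: usize, value: T) {
        self.data[i] = value;
    }

    /// Product of the half-open interval [l, r), combined left to right.
    /// An empty interval yields the identity `e`.
    fn prod(&self, l: usize, r: usize) -> T {
        self.data[l..r]
            .iter()
            .fold(self.e, |acc, &x| (self.op)(acc, x))
    }
}
```

Note that the product of an empty interval is the identity element, which is one reason the structure is defined over monoids.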
Example
use segtree::*;
fn main() {
    // Product segment tree with size of 10 and all elements being 1
    let st = SegTree::new(10, std::iter::repeat(1), 1, |x, y| x * y % 1000000007);
    // Sum segment tree with initial values
    let st = SegTree::new(10, vec![1, 3, 2, 4, 3, 5, 4, 6, 5, 7], 0, |x, y| x + y);
    // Sum segment tree with size of 10 and all elements being 0
    let mut st = SegTree::new(10, None, 0, |x, y| x + y);
    st.update(2, |_| 3);
    println!("{}", st.get(2)); // 3
    let prev = st.get(4);
    st.update(4, |x| x + 2);
    let curr = st.get(4);
    println!("{prev} -> {curr}"); // 0 -> 2
    println!("{}", st.prod(..4)); // 3
    println!("{}", st.prod(4..)); // 2
    println!("{}", st.prod(..)); // 5
    println!("{}", st.prod(100..101)); // 0
    println!("{}", st.prod(4..2)); // 0
    for i in 1..=st.len() {
        print!("{} ", st.prod(1..i));
    }
    println!(); // 0 0 3 3 5 5 5 5 5 5
    println!("{}", st.partition_point(1, |x| x < 3)); // 3
    println!("{}", st.partition_point(1, |x| x < 5)); // 5
}
// `mod segtree { ... }` goes here; its contents are identical to the listing under Code.
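The monoid conditions require associativity but not commutativity, so the order in which elements are combined matters. The sketch below illustrates this with composition of affine maps. `MiniSegTree`, `Affine` and `then` are hypothetical names written for this illustration; the tree is a condensed stand-in with the same bottom-up layout, not the `SegTree` type of this page.

```rust
/// Hypothetical condensed tree (build, point set, range product only).
struct MiniSegTree<T, O> {
    size: usize,
    data: Vec<T>,
    e: T,
    op: O,
}

impl<T: Copy, O: Fn(T, T) -> T> MiniSegTree<T, O> {
    fn new(values: &[T], e: T, op: O) -> Self {
        let size = values.len().next_power_of_two();
        // Internal nodes live in 1..size, leaves in size..2 * size.
        let mut data = vec![e; 2 * size];
        data[size..size + values.len()].copy_from_slice(values);
        for k in (1..size).rev() {
            data[k] = op(data[2 * k], data[2 * k + 1]);
        }
        Self { size, data, e, op }
    }

    fn set(&mut self, i: usize, value: T) {
        let mut k = i + self.size;
        self.data[k] = value;
        while k > 1 {
            k >>= 1;
            self.data[k] = (self.op)(self.data[2 * k], self.data[2 * k + 1]);
        }
    }

    /// Product over the half-open interval [l, r), in left-to-right order.
    fn prod(&self, l: usize, r: usize) -> T {
        let (mut l, mut r) = (l + self.size, r + self.size);
        let (mut left, mut right) = (self.e, self.e);
        while l < r {
            if l & 1 == 1 {
                // New pieces on the left side are appended on the right.
                left = (self.op)(left, self.data[l]);
                l += 1;
            }
            if r & 1 == 1 {
                // New pieces on the right side are prepended on the left.
                r -= 1;
                right = (self.op)(self.data[r], right);
            }
            l >>= 1;
            r >>= 1;
        }
        (self.op)(left, right)
    }
}

/// An affine map x -> a * x + b, stored as (a, b).
type Affine = (i64, i64);

/// "Apply f first, then g": g(f(x)) = g.0 * (f.0 * x + f.1) + g.1.
fn then(f: Affine, g: Affine) -> Affine {
    (f.0 * g.0, f.1 * g.0 + g.1)
}
```

With the maps `2x + 1`, `3x` and `x + 5` stored in that order and `(1, 0)` as the identity, the product over the whole array is `(6, 8)`, that is `6x + 8`. Combining the same maps in reverse order would give `6x + 31`, which is why the two accumulators in `prod` are kept separate and merged only at the end.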
Code
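mod segtree {
    use std::ops::RangeBounds;
    /// A segment tree is a data structure for a monoid type `T`.
    ///
    /// Given the constraints listed in the documentation of `SegTree::new`, a segment tree can process each of the following queries in O(T log N) time, assuming every call to `op` runs in O(T).
    /// - Changing the value of an element
    /// - Calculating the product of the elements in an interval, combined with `op`
    pub struct SegTree<T, O> {
        // Number of elements in the represented array.
        n: usize,
        // 1-indexed binary tree; leaves occupy `size..2 * size`.
        data: Vec<T>,
        // Identity element of `op`.
        e: T,
        op: O,
        // Smallest power of two that is at least `n`.
        size: usize,
        // log2(size), i.e. the height of the tree.
        log: u32,
    }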
    impl<T: Copy, O: Fn(T, T) -> T> SegTree<T, O> {
        /// Converts `range` into a half-open pair `(l, r)` with `r` clamped to `self.len()`.
        /// A reversed or out-of-bounds range yields the empty pair `(l, l)`.
        fn get_bounds(&self, range: impl RangeBounds<usize>) -> (usize, usize) {
            use std::ops::Bound::*;
            let n = self.len();
            let l = match range.start_bound() {
                Included(&v) => v,
                Excluded(&v) => v + 1,
                Unbounded => 0,
            };
            let r = match range.end_bound() {
                Included(&v) => (v + 1).min(n),
                Excluded(&v) => v.min(n),
                Unbounded => n,
            };
            if l > r {
                return (l, l);
            }
            (l, r)
        }
        /// Recomputes node `k` from its two children.
        fn upd(&mut self, k: usize) {
            self.data[k] = (self.op)(self.data[k * 2], self.data[(k * 2) + 1]);
        }
        /// Returns a new segment tree of size `n` built from `iter`.
        ///
        /// The parameters and generic types have the following meanings.
        /// - `T` is the type of the values in the array the segment tree represents.
        /// - `n` is the number of elements in the array.
        /// - `iter` is an iterator yielding the initial values of the array.
        ///   - If `iter` yields fewer than `n` values, the rest is filled with `e`.
        ///   - If `iter` yields more than `n` values, the array is truncated to length `n`.
        /// - `op: impl Fn(T, T) -> T` is a binary operator on `T`.
        /// - `e` is the identity for `op`.
        ///
        /// The following notation is used from now on.
        /// - `op(a, b)` is denoted as `a*b`.
        ///
        /// The parameters must satisfy the following constraints.
        /// - `op` and `e` must make `T` a monoid. That is, they must satisfy the following conditions.
        ///   - `op` is associative on `T`. That is, `(a*b)*c == a*(b*c)` for all `a, b, c: T`.
        ///   - `e` is an identity element for `op`. That is, `a*e == e*a == a` for all `a: T`.
        ///
        /// For example, a range sum segment tree of length `n` with every value initialized to `0` can be constructed as follows.
        /// ```no_run
        /// let mut st = SegTree::new(n, None, 0i64, |x, y| x + y);
        /// ```
        pub fn new(n: usize, iter: impl IntoIterator<Item = T>, e: T, op: O) -> Self {
            let size = n.next_power_of_two();
            let log = size.trailing_zeros();
            // Internal nodes occupy `1..size`; leaves occupy `size..2 * size`.
            let mut data = vec![e; size];
            data.extend(iter.into_iter().take(n));
            data.resize(2 * size, e);
            let mut st = Self { n, data, e, op, size, log };
            // Build the internal nodes bottom-up.
            for i in (1..size).rev() {
                st.upd(i);
            }
            st
        }
        /// Returns the length of the array.
        pub fn len(&self) -> usize {
            self.n
        }
        /// Returns the `i`-th value of the array.
        pub fn get(&self, i: usize) -> T {
            self.data[i + self.size]
        }
        /// Assigns `upd_to(self.get(i))` to the `i`-th element.
        pub fn update(&mut self, i: usize, upd_to: impl Fn(T) -> T) {
            let i = i + self.size;
            self.data[i] = upd_to(self.data[i]);
            // Recompute every ancestor of the leaf, from its parent up to the root.
            for j in 1..=self.log {
                self.upd(i >> j);
            }
        }
        /// Returns the product of elements in `range`.
        pub fn prod(&self, range: impl RangeBounds<usize>) -> T {
            let (l, r) = self.get_bounds(range);
            // These shortcuts must be checked before shifting to leaf indices.
            if (l, r) == (0, self.n) {
                // The whole array: the root already holds the answer.
                return self.data[1];
            } else if l == r {
                return self.e;
            }
            let (mut l, mut r) = (l + self.size, r + self.size);
            // `sml` grows to the right and `smr` grows to the left, which keeps the operand order intact.
            let (mut sml, mut smr) = (self.e, self.e);
            while l < r {
                if l & 1 == 1 {
                    sml = (self.op)(sml, self.data[l]);
                    l += 1;
                }
                if r & 1 == 1 {
                    r -= 1;
                    smr = (self.op)(self.data[r], smr);
                }
                l >>= 1;
                r >>= 1;
            }
            (self.op)(sml, smr)
        }
        /// Let `f(r) := pred(self.prod(l..r))`. Assuming there is a nonnegative value `x` such that `f(r)` is `false` if and only if `x <= r`, `self.partition_point(l, pred)` returns `x`.
        /// That is, it returns the minimum `r` at which `f(r)` becomes `false`, or `self.len() + 1` if `f(r)` is `true` for every `r` in `l..=self.len()`.
        /// If `pred(self.e)` is `false`, then this function assumes that `f(r)` is `false` for every `r` in `l..=self.len()` and returns `l`.
        /// However, it's recommended to always make `pred(self.e)` `true` to avoid unnecessary casework.
        ///
        /// ## Constraints
        /// - `0 <= l <= self.len()`
        ///
        /// ## Examples
        /// `f(r) := pred(self.prod(l..r))`
        ///
        /// Given that `self.len() == 7`, calling `self.partition_point(0, pred)` returns the values marked below.
        /// ```text
        ///    r | 0     1     2     3     4     5     6     7     8
        ///
        /// f(r) | true  true  true  true  false false false false N/A
        ///                         returns^
        ///
        /// f(r) | false false false false false false false false N/A
        /// returns^
        ///
        /// f(r) | true  true  true  true  true  true  true  true  N/A
        ///                                                 returns^
        /// ```
        pub fn partition_point(&self, l: usize, pred: impl Fn(T) -> bool) -> usize {
            if !pred(self.e) {
                // `pred(self.prod(l..l))` is `false`, so `l` is returned.
                // The original implementation does not cover this case; it simply requires `pred(self.e)` to be `true`.
                return l;
            }
            if l == self.n {
                // `pred(self.e)` is `true` here and `l..l` is the only interval left,
                // so `pred` never becomes `false` and the "not found" value is returned.
                return self.n + 1;
            }
            let mut l = l + self.size;
            let mut sm = self.e;
            loop {
                // Move up to the largest node whose interval starts at `l`.
                l >>= l.trailing_zeros();
                if !pred((self.op)(sm, self.data[l])) {
                    // The answer lies inside this node: descend to the first failing leaf.
                    while l < self.size {
                        l <<= 1;
                        let tmp = (self.op)(sm, self.data[l]);
                        if pred(tmp) {
                            sm = tmp;
                            l += 1;
                        }
                    }
                    return l + 1 - self.size;
                }
                sm = (self.op)(sm, self.data[l]);
                l += 1;
                // `l` is a power of two: the scan has passed the right edge of the tree.
                if l & ((!l) + 1) == l {
                    break;
                }
            }
            self.n + 1
        }
        /// Let `f(l) := pred(self.prod(l..r))`. Assuming there is a value `x <= r` such that `f(l)` is `true` if and only if `x <= l`, `self.left_partition_point(r, pred)` returns `x`.
        /// That is, it returns the minimum `l` at which `f(l)` becomes `true`.
        /// If `pred(self.e)` is `false`, then this function assumes that `f(l)` is `false` for every `l` in `0..=r` and returns `r + 1`.
        /// However, it's recommended to always make `pred(self.e)` `true` to avoid unnecessary casework.
        ///
        /// ## Constraints
        /// - `0 <= r <= self.len()`
        ///
        /// ## Examples
        /// `f(l) := pred(self.prod(l..r))`
        ///
        /// Calling `self.left_partition_point(7, pred)` returns the values marked below.
        /// ```text
        ///    l | 0     1     2     3     4     5     6     7     8
        ///
        /// f(l) | false false false false true  true  true  true  N/A
        ///                         returns^
        ///
        /// f(l) | true  true  true  true  true  true  true  true  N/A
        /// returns^
        ///
        /// f(l) | false false false false false false false false N/A
        ///                                                 returns^
        /// ```
        pub fn left_partition_point(&self, r: usize, pred: impl Fn(T) -> bool) -> usize {
            if !pred(self.e) {
                return r + 1;
            }
            if r == 0 {
                // `pred(self.e)` is always `true` at this point
                return 0;
            }
            let mut r = r + self.size;
            let mut sm = self.e;
            loop {
                r -= 1;
                // Move up to the largest node whose interval ends at `r`.
                while r > 1 && r & 1 == 1 {
                    r >>= 1;
                }
                if !pred((self.op)(self.data[r], sm)) {
                    // The answer lies inside this node: descend to the last failing leaf.
                    while r < self.size {
                        r = (r << 1) + 1;
                        let tmp = (self.op)(self.data[r], sm);
                        if pred(tmp) {
                            sm = tmp;
                            r -= 1;
                        }
                    }
                    return r + 1 - self.size;
                }
                sm = (self.op)(self.data[r], sm);
                // `r` is a power of two: the scan has reached the left edge of the tree.
                if r & ((!r) + 1) == r {
                    break;
                }
            }
            0
        }
    }
}
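The return conventions of the two search functions are easy to get wrong by one, so it can help to state them as linear scans. The sketch below is a hypothetical reference, assuming a monotone `pred` as described in the doc comments; the function names and signatures are made up for this illustration and take the array, identity and operator explicitly instead of a `SegTree`.

```rust
/// Hypothetical linear-time reference for `partition_point`:
/// the smallest r in l..=n with pred(prod(l..r)) == false, or n + 1 if none.
fn partition_point_naive<T: Copy>(
    a: &[T],
    l: usize,
    e: T,
    op: impl Fn(T, T) -> T,
    pred: impl Fn(T) -> bool,
) -> usize {
    let mut acc = e; // product of the empty interval l..l
    if !pred(acc) {
        return l;
    }
    for r in l..a.len() {
        acc = op(acc, a[r]); // acc is now prod(l..r + 1)
        if !pred(acc) {
            return r + 1;
        }
    }
    a.len() + 1
}

/// Hypothetical linear-time reference for `left_partition_point`:
/// the smallest l in 0..=r with pred(prod(l..r)) == true, or r + 1 if none.
fn left_partition_point_naive<T: Copy>(
    a: &[T],
    r: usize,
    e: T,
    op: impl Fn(T, T) -> T,
    pred: impl Fn(T) -> bool,
) -> usize {
    let mut acc = e; // product of the empty interval r..r
    if !pred(acc) {
        return r + 1;
    }
    let mut l = r;
    while l > 0 {
        let next = op(a[l - 1], acc); // prod(l - 1..r)
        if !pred(next) {
            return l;
        }
        acc = next;
        l -= 1;
    }
    0
}
```

On the array `[0, 0, 3, 0, 2, 0, 0, 0, 0, 0]` from the example with addition, `partition_point_naive` starting at `l = 1` returns 3 for `|x| x < 3` and 5 for `|x| x < 5`, matching the values printed there.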
Last modified on 231007