Introduction

This book is a collection of snippets for competitive programming and problem solving with Rust. To use a snippet, simply copy it and paste it into your code.

Usage

Each snippet can be easily copied by clicking on the copy icon at the top right corner of a code block.

Clicking the search icon in the menu bar, or pressing the S key, opens an input box for search terms. Any keyword in this book can be found by typing it into the box.

Other Resources

General

Rust

C++

Python

Disclaimer

None of the code in this document should be used in any field other than competitive programming! Every snippet here is designed strictly for CP, and none of it is intended for production use.

License

Unless stated otherwise, or unless another license is included in the code, every snippet here is released under the Unlicense.

This is free and unencumbered software released into the public domain.

Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means.

In jurisdictions that recognize copyright laws, the author or authors of this software dedicate any and all copyright interest in the software to the public domain. We make this dedication for the benefit of the public at large and to the detriment of our heirs and successors. We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

For more information, please refer to http://unlicense.org/

Sieve

Sieve algorithms find all prime numbers up to a specified integer. They can also efficiently compute the values of multiplicative functions for all integers up to that limit.

Finding primes

sieve returns a vector containing all prime numbers that are less than or equal to max_val.

This function runs in \(O(N)\) time, where \(N\) is the value of max_val. This is achieved with a linear sieve, which crosses out each composite exactly once (via its smallest prime factor). The function is further optimized by skipping multiples of \(2\) and \(3\) in advance, so only numbers of the form \(6k \pm 1\) are stored.

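fn sieve(max_val: usize) -> Vec<usize> {
	// 2 and 3 are handled separately; only numbers of the form 6k ± 1 are sieved.
	let mut primes = vec![2, 3];
	primes.retain(|&p| p <= max_val);
	// is_prime[i] represents the i-th number of the form 6k ± 1, starting from 5 (5, 7, 11, 13, ...).
	let mut is_prime = vec![true; max_val / 3 + 1];

	for i in 0..is_prime.len() {
		let j = 6 * (i >> 1) + 5 + ((i & 1) << 1);
		// The last entries of is_prime may represent numbers greater than max_val.
		if j > max_val {
			break;
		}
		if is_prime[i] {
			primes.push(j);
		}
		// Linear sieve: cross out j * p for every prime p >= 5 up to the smallest prime factor of j.
		for &p in primes[2..].iter() {
			let v = j * p;
			if v > max_val {
				break;
			}
			is_prime[v / 3 - 1] = false;
			if j % p == 0 {
				break;
			}
		}
	}

	primes
}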
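The index arithmetic in `sieve` packs only the numbers of the form \(6k \pm 1\) (5, 7, 11, 13, ...) into `is_prime`. The sketch below checks that the two formulas used in the snippet are inverses of each other. The helper names `index_to_value` and `value_to_index` are illustrative and not part of the snippet.

```rust
/// Maps index i of `is_prime` to the number it represents (5, 7, 11, 13, 17, 19, ...).
fn index_to_value(i: usize) -> usize {
    6 * (i >> 1) + 5 + ((i & 1) << 1)
}

/// Inverse mapping, used when crossing out a composite v that is coprime to 6.
fn value_to_index(v: usize) -> usize {
    v / 3 - 1
}

fn main() {
    for i in 0..1000 {
        let v = index_to_value(i);
        // Every represented value is 1 or 5 modulo 6, i.e. coprime to 6.
        assert!(v % 6 == 1 || v % 6 == 5);
        // The two formulas round-trip.
        assert_eq!(value_to_index(v), i);
    }
    let first: Vec<usize> = (0..6).map(index_to_value).collect();
    println!("{:?}", first); // [5, 7, 11, 13, 17, 19]
}
```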

With Euler Phi Function

fn phi_sieve(max_val: usize) -> (Vec<bool>, Vec<usize>, Vec<usize>) {
    let mut primes = vec![];
    let mut is_prime = vec![true; max_val + 1];
    let mut phi = vec![0; max_val + 1];
    is_prime[0] = false;
    if max_val >= 1 {
        is_prime[1] = false;
        phi[1] = 1; // phi(1) = 1 by definition
    }

    for i in 2..=max_val {
        if is_prime[i] {
            primes.push(i);
            phi[i] = i - 1;
        }
        for &p in primes.iter() {
            let v = i * p;
            if v > max_val {
                break;
            }
            is_prime[v] = false;
            if i % p == 0 {
                // p divides i, so phi(i * p) = phi(i) * p
                phi[v] = phi[i] * p;
                break;
            } else {
                // p and i are coprime, so phi(i * p) = phi(i) * (p - 1)
                phi[v] = phi[i] * phi[p];
            }
        }
    }

    (is_prime, phi, primes)
}
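The recurrence in `phi_sieve` relies on \(\varphi(ip) = \varphi(i)\,p\) when \(p \mid i\) and \(\varphi(ip) = \varphi(i)\,(p-1)\) otherwise. For cross-checking the sieve on small inputs, \(\varphi\) can also be computed directly by trial division from \(\varphi(n) = n \prod_{p \mid n} (1 - 1/p)\). The helper `phi_naive` below is a hypothetical test aid, not part of the snippet.

```rust
/// Computes Euler's phi function by trial division in O(sqrt(n)).
fn phi_naive(mut n: u64) -> u64 {
    let mut result = n;
    let mut p = 2;
    while p * p <= n {
        if n % p == 0 {
            // Remove every factor p, then apply result *= (1 - 1/p).
            while n % p == 0 {
                n /= p;
            }
            result -= result / p;
        }
        p += 1;
    }
    // A leftover n > 1 is a prime factor larger than the square root of the original n.
    if n > 1 {
        result -= result / n;
    }
    result
}

fn main() {
    let phis: Vec<u64> = (1..=10).map(phi_naive).collect();
    println!("{:?}", phis); // [1, 1, 2, 2, 4, 2, 6, 4, 6, 4]
}
```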

With Möbius Function

fn mobius_sieve(max_val: usize) -> (Vec<i8>, Vec<usize>) {
    let mut primes = vec![];
    // 2 is a placeholder for "not visited yet"; mu itself only takes the values -1, 0 and 1.
    let mut mu = vec![2i8; max_val + 1];
    (mu[0], mu[1]) = (0, 1);

    for i in 2..=max_val {
        if mu[i] == 2 {
            primes.push(i);
            mu[i] = -1;
        }
        for &p in primes.iter() {
            let v = i * p;
            if v > max_val {
                break;
            }
            if i % p == 0 {
                mu[v] = 0;
                break;
            } else {
                mu[v] = -mu[i];
            }
        }
    }

    (mu, primes)
}
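`mobius_sieve` sets \(\mu(ip) = 0\) when \(p \mid i\), because \(p^2\) then divides \(ip\), and sets \(\mu(ip) = -\mu(i)\) otherwise, because \(ip\) gains one more distinct prime. For cross-checking on small inputs, \(\mu\) can also be computed directly by trial division. The helper `mobius_naive` below is hypothetical and not part of the snippet.

```rust
/// Computes the Möbius function by trial division in O(sqrt(n)).
fn mobius_naive(mut n: u64) -> i8 {
    let mut mu = 1i8;
    let mut p = 2;
    while p * p <= n {
        if n % p == 0 {
            n /= p;
            // A squared prime factor makes mu(n) = 0.
            if n % p == 0 {
                return 0;
            }
            mu = -mu;
        }
        p += 1;
    }
    // A leftover n > 1 is one more distinct prime factor.
    if n > 1 {
        mu = -mu;
    }
    mu
}

fn main() {
    let mus: Vec<i8> = (1..=10).map(mobius_naive).collect();
    println!("{:?}", mus); // [1, -1, -1, 0, -1, 1, -1, 0, 0, 1]
}
```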

Last modified on 231203.

GCD, LCM

gcd(x, y) returns the greatest common divisor (GCD) of x and y.
lcm(x, y) returns the least common multiple (LCM) of x and y.

gcd is implemented using the Euclidean algorithm, whose time complexity is \(O(\log_{\phi} \min(x, y))\), where \(\phi\) is the golden ratio.
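The \(\log_{\phi}\) bound comes from Lamé's theorem. The worst case for the Euclidean algorithm is a pair of consecutive Fibonacci numbers, where each step only reduces the pair to the previous two Fibonacci numbers. The sketch below counts the remainder steps of an iterative Euclidean algorithm on such pairs. The name `gcd_steps` is illustrative.

```rust
/// Returns (gcd(x, y), number of modulo steps performed by the Euclidean algorithm).
fn gcd_steps(mut x: u64, mut y: u64) -> (u64, u32) {
    let mut steps = 0;
    while y != 0 {
        (x, y) = (y, x % y);
        steps += 1;
    }
    (x, steps)
}

fn main() {
    // Consecutive Fibonacci numbers: (1, 2), (2, 3), (3, 5), (5, 8), ...
    let (mut a, mut b) = (1u64, 1u64);
    for _ in 0..10 {
        (a, b) = (b, a + b);
        let (g, steps) = gcd_steps(b, a);
        println!("gcd({}, {}) = {} in {} steps", b, a, g, steps);
    }
}
```

Each further Fibonacci index adds exactly one step, while the inputs grow by a factor of about \(\phi\). This is why the step count is logarithmic in the inputs.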

Example

fn main() {
    let (x, y) = (10, 25);

    let g = gcd(x, y);
    println!("{}", g); // 5

    let l = lcm(x, y);
    println!("{}", l); // 50
}

fn gcd(x: u64, y: u64) -> u64 {
    if y == 0 {
        x
    } else {
        gcd(y, x % y)
    }
}

fn lcm(x: u64, y: u64) -> u64 {
    x / gcd(x, y) * y
}

Code

use std::ops::{Div, Mul, Rem};

fn gcd<T>(x: T, y: T) -> T
where T: Copy + PartialEq + PartialOrd + Rem<Output = T> + From<u8> {
	if y == 0.into() {
		x
	} else {
		let v = x % y;
		gcd(y, v)
	}
}

fn lcm<T>(x: T, y: T) -> T
where T: Copy + PartialEq + PartialOrd + Rem<Output = T> + Div<Output = T> + Mul<Output = T> + From<u8> {
	x / gcd(x, y) * y
}

Last modified on 231203.

Extended Euclidean Algorithm

egcd(a, b) returns \(g, s, t\) such that \(g = \gcd(a, b)\) and \(as+bt=g\).
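`egcd` maintains the invariant that the current pair \((a, b)\) can always be written as \(a = a_0 s_a + b_0 t_a\) and \(b = a_0 s_b + b_0 t_b\) for the original inputs \(a_0, b_0\). When \(b\) reaches zero, \(a = \gcd(a_0, b_0)\) and \((s_a, t_a)\) are the Bézout coefficients. The following sketch is the same loop with the invariant checked by `assert_eq!` at every step. The name `egcd_checked` is illustrative.

```rust
/// Extended Euclidean algorithm with the loop invariant asserted on every iteration.
fn egcd_checked(a0: i64, b0: i64) -> (i64, i64, i64) {
    let (mut a, mut b) = (a0, b0);
    let (mut sa, mut ta, mut sb, mut tb) = (1, 0, 0, 1);
    while b != 0 {
        // Invariant: both a and b are integer combinations of a0 and b0.
        assert_eq!(a0 * sa + b0 * ta, a);
        assert_eq!(a0 * sb + b0 * tb, b);
        let (q, r) = (a / b, a % b);
        (sa, ta, sb, tb) = (sb, tb, sa - q * sb, ta - q * tb);
        (a, b) = (b, r);
    }
    (a, sa, ta)
}

fn main() {
    let (g, s, t) = egcd_checked(20, 35);
    println!("{} = 20*({}) + 35*({})", g, s, t); // 5 = 20*(2) + 35*(-1)
}
```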

Example

fn main() {
    for (a, b) in [(2, 5), (11, 17), (20, 35)] {
        let (g, s, t) = egcd(a, b);
        println!("gcd({a}, {b}) = {g}");
        println!("{a}*({s}) + {b}*({t}) = {g}");
    }
}

/// Returns `(g, s, t)` such that `g == gcd(a, b)` and `a*s + t*b == g`.
fn egcd(mut a: i64, mut b: i64) -> (i64, i64, i64) {
    let (mut sa, mut ta, mut sb, mut tb) = (1, 0, 0, 1);
    while b != 0 {
        let (q, r) = (a / b, a % b);
        (sa, ta, sb, tb) = (sb, tb, sa - q * sb, ta - q * tb);
        (a, b) = (b, r);
    }
    (a, sa, ta)
}

Code

/// Returns `(g, s, t)` such that `g == gcd(a, b)` and `a*s + t*b == g`.
fn egcd(mut a: i64, mut b: i64) -> (i64, i64, i64) {
    let (mut sa, mut ta, mut sb, mut tb) = (1, 0, 0, 1);
    while b != 0 {
        let (q, r) = (a / b, a % b);
        (sa, ta, sb, tb) = (sb, tb, sa - q * sb, ta - q * tb);
        (a, b) = (b, r);
    }
    (a, sa, ta)
}

Last modified on 231008.

Chinese Remainder Theorem

crt(r, m) returns Some(x), where x is the smallest non-negative integer such that \(x \equiv r_i \pmod {m_i}\) for all \(i\). If no such \(x\) exists, it returns None.
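For small moduli, the result of `crt` can be cross-checked by brute force. All solutions are congruent modulo \(\operatorname{lcm}(m_1, m_2, \ldots)\), so scanning \(0 \le x < \operatorname{lcm}\) is enough. The helper `crt_brute` below is a hypothetical test aid and is practical only when the lcm is small.

```rust
fn gcd(x: i64, y: i64) -> i64 {
    if y == 0 { x } else { gcd(y, x % y) }
}

/// Returns the smallest non-negative x with x ≡ r[i] (mod m[i]) for all i, by exhaustive search.
fn crt_brute(r: &[i64], m: &[i64]) -> Option<i64> {
    // Any solution repeats with period lcm(m), so scanning one period is enough.
    let l = m.iter().fold(1, |acc, &mi| acc / gcd(acc, mi) * mi);
    (0..l).find(|&x| r.iter().zip(m).all(|(&ri, &mi)| (x - ri).rem_euclid(mi) == 0))
}

fn main() {
    println!("{:?}", crt_brute(&[1, 2, 3], &[3, 5, 7])); // Some(52)
    println!("{:?}", crt_brute(&[2, 5], &[10, 25])); // None
}
```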

Example

fn main() {
    let r = vec![1, 2, 3];
    let m = vec![3, 5, 7];
    let x = crt(&r, &m);
    println!("{:?}", x); // Some(52)

    let r = vec![2, 5];
    let m = vec![10, 25];
    let x = crt(&r, &m);
    println!("{:?}", x); // None
}

fn gcd(x: i64, y: i64) -> i64 {
    if y == 0 {
        x
    } else {
        gcd(y, x % y)
    }
}

/// Returns `(g, s, t)` such that `g == gcd(a, b)` and `a*s + t*b == g`.
fn egcd(mut a: i64, mut b: i64) -> (i64, i64, i64) {
    let (mut sa, mut ta, mut sb, mut tb) = (1, 0, 0, 1);
    while b != 0 {
        let (q, r) = (a / b, a % b);
        (sa, ta, sb, tb) = (sb, tb, sa - q * sb, ta - q * tb);
        (a, b) = (b, r);
    }
    (a, sa, ta)
}

/// Returns the smallest non-negative x such that x ≡ a_i (mod m_i) for all i, or None if no such x exists.
/// Reference: PyRival https://github.com/cheran-senthil/PyRival/blob/master/pyrival/algebra/chinese_remainder.py
fn crt(a: &[i64], m: &[i64]) -> Option<i64> {
    use std::iter::zip;
    let (mut x, mut m_prod) = (0, 1);
    for (&ai, &mi) in zip(a, m) {
        let (g, s, _) = egcd(m_prod, mi);
        if (ai - x).rem_euclid(g) != 0 {
            return None;
        }
        // Reducing the step modulo mi / g keeps x below 2 * lcm(m_prod, mi), avoiding overflow for large moduli.
        x += m_prod * (s * (ai - x).rem_euclid(mi)).div_euclid(g).rem_euclid(mi / g);
        m_prod = m_prod / gcd(m_prod, mi) * mi;
        x = x.rem_euclid(m_prod);
    }
    Some(x)
}

Code

fn gcd(x: i64, y: i64) -> i64 {
    if y == 0 {
        x
    } else {
        gcd(y, x % y)
    }
}

/// Returns `(g, s, t)` such that `g == gcd(a, b)` and `a*s + t*b == g`.
fn egcd(mut a: i64, mut b: i64) -> (i64, i64, i64) {
    let (mut sa, mut ta, mut sb, mut tb) = (1, 0, 0, 1);
    while b != 0 {
        let (q, r) = (a / b, a % b);
        (sa, ta, sb, tb) = (sb, tb, sa - q * sb, ta - q * tb);
        (a, b) = (b, r);
    }
    (a, sa, ta)
}

/// Returns the smallest non-negative x such that x ≡ a_i (mod m_i) for all i, or None if no such x exists.
/// Reference: PyRival https://github.com/cheran-senthil/PyRival/blob/master/pyrival/algebra/chinese_remainder.py
fn crt(a: &[i64], m: &[i64]) -> Option<i64> {
    use std::iter::zip;
    let (mut x, mut m_prod) = (0, 1);
    for (&ai, &mi) in zip(a, m) {
        let (g, s, _) = egcd(m_prod, mi);
        if (ai - x).rem_euclid(g) != 0 {
            return None;
        }
        // Reducing the step modulo mi / g keeps x below 2 * lcm(m_prod, mi), avoiding overflow for large moduli.
        x += m_prod * (s * (ai - x).rem_euclid(mi)).div_euclid(g).rem_euclid(mi / g);
        m_prod = m_prod / gcd(m_prod, mi) * mi;
        x = x.rem_euclid(m_prod);
    }
    Some(x)
}

Last modified on 231008.

Deterministic Miller-Rabin Primality Test

The deterministic Miller-Rabin primality test determines whether an unsigned integer \(n\) is prime in \(O(\log n)\) time. The fixed sets of bases used here make the test deterministic only for integers below \(2^{64}\).

x.is_prime() uses trial division when x is at most a per-type threshold, and the Miller-Rabin test otherwise, whichever is roughly faster for that range. It returns true if x is prime and false otherwise.
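Each Miller-Rabin round writes \(n - 1 = d \cdot 2^s\) with \(d\) odd, as the snippet does with `trailing_zeros`, and then checks whether base \(a\) proves \(n\) composite. The sketch below runs a single round for `u64` using 128-bit intermediate products. The helper names are illustrative. It also shows why one base is not enough: \(2047 = 23 \times 89\) passes the round for base 2 but fails it for base 3, which is why the snippet tests a fixed set of bases.

```rust
/// (a * b) mod m without overflow, via a 128-bit intermediate product.
fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    ((a as u128 * b as u128) % m as u128) as u64
}

/// base^exp mod m by binary exponentiation.
fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let mut ret = 1 % m;
    base %= m;
    while exp != 0 {
        if exp & 1 != 0 {
            ret = mul_mod(ret, base, m);
        }
        base = mul_mod(base, base, m);
        exp >>= 1;
    }
    ret
}

/// One Miller-Rabin round: returns false if `a` proves that the odd number n > 2 is composite.
fn passes_round(n: u64, a: u64) -> bool {
    let s = (n - 1).trailing_zeros();
    let d = (n - 1) >> s;
    let mut x = pow_mod(a, d, n);
    if x == 1 || x == n - 1 {
        return true;
    }
    // Square up to s - 1 more times, looking for n - 1.
    for _ in 1..s {
        x = mul_mod(x, x, n);
        if x == n - 1 {
            return true;
        }
    }
    false
}

fn main() {
    println!("{}", passes_round(2047, 2)); // true  (2047 is a strong pseudoprime to base 2)
    println!("{}", passes_round(2047, 3)); // false (base 3 is a witness)
    println!("{}", passes_round(1_000_000_007, 2)); // true  (primes pass every round)
}
```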

Example

use millerrabin::*;

fn main() {
    println!("{}", 407521u64.is_prime()); // true
    println!("{}", 3284729387909u64.is_prime()); // true
    println!("{}", 3284729387911u64.is_prime()); // false (53 × 61976026187)
}

mod millerrabin {
    pub trait Primality {
        fn is_prime(self) -> bool;
    }

    macro_rules! impl_primality {
        ($t:ty, $u:ty, $thres:expr, $bcnt:expr, $($basis:expr),+) => {
            impl Primality for $t {
                fn is_prime(self) -> bool {
                    if self <= 1 {
                        return false;
                    } else if self & 1 == 0 {
                        return self == 2;
                    }

                    const THRES: $t = $thres;
                    const TEST: [$t; $bcnt] = [$($basis,)+];

                    if self <= THRES {
                        // `p <= self / p` is `p * p <= self` without overflow (e.g. 16 * 16 in u8).
                        for p in (2..).take_while(|&p| p <= self / p) {
                            if self % p == 0 {
                                return false;
                            }
                        }
                        return true;
                    }

                    let pow = |base: $t, mut exp: $t| -> $t {
                        let mut base = base as $u;
                        let mut ret = 1 as $u;
                        while exp != 0 {
                            if exp & 1 != 0 {
                                ret = (ret * base) % self as $u;
                            }
                            exp >>= 1;
                            base = (base * base) % self as $u;
                        }
                        ret as $t
                    };

                    let s = (self - 1).trailing_zeros();
                    let d = (self - 1) >> s;

                    for &a in TEST.iter().take_while(|&&a| a < self - 1) {
                        let mut x = pow(a, d);
                        for _ in 0..s {
                            let y = ((x as $u).pow(2) % self as $u) as $t;
                            if y == 1 && x != 1 && x != self - 1 {
                                return false;
                            }
                            x = y;
                        }
                        if x != 1 {
                            return false;
                        }
                    }

                    true
                }
            }
        };
    }

    impl_primality!(u8, u16, 255, 1, 2);
    impl_primality!(u16, u32, 2000, 2, 2, 3);
    impl_primality!(u32, u64, 7000, 3, 2, 7, 61);
    impl_primality!(u64, u128, 300000, 7, 2, 325, 9375, 28178, 450775, 9780504, 1795265022);
}

Code

mod millerrabin {
    pub trait Primality {
        fn is_prime(self) -> bool;
    }

    macro_rules! impl_primality {
        ($t:ty, $u:ty, $thres:expr, $bcnt:expr, $($basis:expr),+) => {
            impl Primality for $t {
                fn is_prime(self) -> bool {
                    if self <= 1 {
                        return false;
                    } else if self & 1 == 0 {
                        return self == 2;
                    }

                    const THRES: $t = $thres;
                    const TEST: [$t; $bcnt] = [$($basis,)+];

                    if self <= THRES {
                        // `p <= self / p` is `p * p <= self` without overflow (e.g. 16 * 16 in u8).
                        for p in (2..).take_while(|&p| p <= self / p) {
                            if self % p == 0 {
                                return false;
                            }
                        }
                        return true;
                    }

                    let pow = |base: $t, mut exp: $t| -> $t {
                        let mut base = base as $u;
                        let mut ret = 1 as $u;
                        while exp != 0 {
                            if exp & 1 != 0 {
                                ret = (ret * base) % self as $u;
                            }
                            exp >>= 1;
                            base = (base * base) % self as $u;
                        }
                        ret as $t
                    };

                    let s = (self - 1).trailing_zeros();
                    let d = (self - 1) >> s;

                    for &a in TEST.iter().take_while(|&&a| a < self - 1) {
                        let mut x = pow(a, d);
                        for _ in 0..s {
                            let y = ((x as $u).pow(2) % self as $u) as $t;
                            if y == 1 && x != 1 && x != self - 1 {
                                return false;
                            }
                            x = y;
                        }
                        if x != 1 {
                            return false;
                        }
                    }

                    true
                }
            }
        };
    }

    impl_primality!(u8, u16, 255, 1, 2);
    impl_primality!(u16, u32, 2000, 2, 2, 3);
    impl_primality!(u32, u64, 7000, 3, 2, 7, 61);
    impl_primality!(u64, u128, 300000, 7, 2, 325, 9375, 28178, 450775, 9780504, 1795265022);
}

Last modified on 231008.

Pollard's Rho Algorithm

Pollard's rho algorithm is a randomized algorithm that factorizes a number in an expected time of roughly \(O(n^{1/4})\).

x.factorize(&mut rng) factorizes x and returns a vector of its prime factors, with multiplicity. The order of the factors in the vector is unspecified.

This snippet depends on the Miller-Rabin primality test, so the millerrabin module must also be included in the code.
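Pollard's rho works by iterating a pseudo-random map such as \(f(x) = x^2 + c \bmod n\) and detecting a cycle modulo an unknown prime factor \(p\) of \(n\) through \(\gcd(|x - y|, n)\). The snippet uses \(f(x) = x^2 - 1\) with Brent-style power-of-two checkpoints and random restarts. The simpler sketch below uses Floyd's tortoise-and-hare cycle detection instead and tries \(c = 1, 2, \ldots\) on failure. It assumes \(n\) is an odd composite below \(2^{63}\), so the addition cannot overflow. The name `rho_factor` is illustrative.

```rust
fn gcd(mut a: u64, mut b: u64) -> u64 {
    while b != 0 {
        (a, b) = (b, a % b);
    }
    a
}

fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    ((a as u128 * b as u128) % m as u128) as u64
}

/// Returns a nontrivial factor of n. Assumes n is an odd composite below 2^63;
/// for a prime n this loops forever.
fn rho_factor(n: u64) -> u64 {
    for c in 1.. {
        let f = |x: u64| (mul_mod(x, x, n) + c) % n;
        // The tortoise x moves one step per iteration, the hare y moves two.
        let (mut x, mut y, mut d) = (2u64, 2u64, 1u64);
        while d == 1 {
            x = f(x);
            y = f(f(y));
            d = gcd(x.abs_diff(y), n);
        }
        // d == n means the cycle closed without exposing a factor; retry with another c.
        if d != n {
            return d;
        }
    }
    unreachable!()
}

fn main() {
    let n = 8051; // 83 * 97
    let d = rho_factor(n);
    println!("{} = {} * {}", n, d, n / d);
}
```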

Example

use factorization::*;

fn main() {
    let mut rng = RNG::new(15163487);
    let a: u64 = 484387724796727379;
    let mut factors = a.factorize(&mut rng);
    factors.sort_unstable();
    println!("{:?}", factors); // [165551, 2925912406429]
    println!("{}", factors.iter().product::<u64>()); // 484387724796727379
    use millerrabin::Primality;
    println!("{}", factors.iter().all(|&x| x.is_prime()));
}

mod factorization {
    use super::millerrabin::Primality;
    use std::ops::*;

    pub trait PollardRho: Primality + From<u8> + PartialOrd + ShrAssign + BitAnd<Output = Self> + Clone {
        fn rho(self, arr: &mut Vec<Self>, rng: &mut rng::RNG);
        fn factorize(mut self, rng: &mut rng::RNG) -> Vec<Self> {
            let mut arr: Vec<Self> = Vec::new();
            if self <= 1.into() {
                return arr;
            }
            while self.clone() & 1.into() == 0.into() {
                self >>= 1.into();
                arr.push(2.into());
            }
            self.rho(&mut arr, rng);
            arr
        }
    }

    macro_rules! impl_pollardrho {
        ($t:ty, $u:ty, $reset:expr) => {
            impl PollardRho for $t {
                fn rho(self, arr: &mut Vec<Self>, rng: &mut rng::RNG) {
                    if self <= 1 {
                        return;
                    } else if self.is_prime() {
                        arr.push(self);
                        return;
                    }

                    let mut i: u64 = 0;
                    let mut x: $t = (rng.next_u64() % self as u64) as $t;
                    let mut y: $t = x;
                    let mut k: u64 = 2;
                    let mut d: $t;
                    let mut reset_limit: u64 = $reset;

                    loop {
                        i += 1;
                        x = (((x as $u * x as $u % self as $u) + (self - 1) as $u) % self as $u) as $t;
                        d = gcd(y.abs_diff(x), self);
                        if d == self || i >= reset_limit {
                            // Reset
                            reset_limit = reset_limit * 3 / 2;
                            i = 0;
                            x = (rng.next_u64() % self as u64) as $t;
                            y = x;
                        }
                        if d != 1 {
                            break;
                        }
                        if i == k {
                            y = x;
                            k <<= 1;
                        }
                    }

                    if d != self {
                        d.rho(arr, rng);
                        (self / d).rho(arr, rng);
                        return;
                    }

                    let mut i = 3;
                    while i * i <= self {
                        if self % i == 0 {
                            i.rho(arr, rng);
                            (d / i).rho(arr, rng);
                            return;
                        }
                        i += 2;
                    }
                }
            }
        };
    }

    impl_pollardrho!(u8, u16, 100000);
    impl_pollardrho!(u16, u32, 100000);
    impl_pollardrho!(u32, u64, 100000);
    impl_pollardrho!(u64, u128, 100000);

    pub fn gcd<T>(x: T, y: T) -> T
    where
        T: Copy + PartialEq + PartialOrd + core::ops::Rem<Output = T> + From<u8>,
    {
        if y == 0.into() {
            x
        } else {
            let v = x % y;
            gcd(y, v)
        }
    }

    pub mod rng {
        pub struct RNG {
            val: u64,
        }
        impl RNG {
            pub fn new(seed: u64) -> Self {
                Self { val: seed }
            }
            pub fn next_u64(&mut self) -> u64 {
                let mut x = self.val;
                x ^= x << 13;
                x ^= x >> 7;
                x ^= x << 17;
                self.val = x;
                x
            }
        }
    }

    pub use rng::*;
}

mod millerrabin {
    pub trait Primality {
        fn is_prime(self) -> bool;
    }

    macro_rules! impl_primality {
        ($t:ty, $u:ty, $thres:expr, $bcnt:expr, $($basis:expr),+) => {
            impl Primality for $t {
                fn is_prime(self) -> bool {
                    if self <= 1 {
                        return false;
                    } else if self & 1 == 0 {
                        return self == 2;
                    }

                    const THRES: $t = $thres;
                    const TEST: [$t; $bcnt] = [$($basis,)+];

                    if self <= THRES {
                        // `p <= self / p` is `p * p <= self` without overflow (e.g. 16 * 16 in u8).
                        for p in (2..).take_while(|&p| p <= self / p) {
                            if self % p == 0 {
                                return false;
                            }
                        }
                        return true;
                    }

                    let pow = |base: $t, mut exp: $t| -> $t {
                        let mut base = base as $u;
                        let mut ret = 1 as $u;
                        while exp != 0 {
                            if exp & 1 != 0 {
                                ret = (ret * base) % self as $u;
                            }
                            exp >>= 1;
                            base = (base * base) % self as $u;
                        }
                        ret as $t
                    };

                    let s = (self - 1).trailing_zeros();
                    let d = (self - 1) >> s;

                    for &a in TEST.iter().take_while(|&&a| a < self - 1) {
                        let mut x = pow(a, d);
                        for _ in 0..s {
                            let y = ((x as $u).pow(2) % self as $u) as $t;
                            if y == 1 && x != 1 && x != self - 1 {
                                return false;
                            }
                            x = y;
                        }
                        if x != 1 {
                            return false;
                        }
                    }

                    true
                }
            }
        };
    }

    impl_primality!(u8, u16, 255, 1, 2);
    impl_primality!(u16, u32, 2000, 2, 2, 3);
    impl_primality!(u32, u64, 7000, 3, 2, 7, 61);
    impl_primality!(u64, u128, 300000, 7, 2, 325, 9375, 28178, 450775, 9780504, 1795265022);
}

Code

mod factorization {
    use super::millerrabin::Primality;
    use std::ops::*;

    pub trait PollardRho: Primality + From<u8> + PartialOrd + ShrAssign + BitAnd<Output = Self> + Clone {
        fn rho(self, arr: &mut Vec<Self>, rng: &mut rng::RNG);
        fn factorize(mut self, rng: &mut rng::RNG) -> Vec<Self> {
            let mut arr: Vec<Self> = Vec::new();
            if self <= 1.into() {
                return arr;
            }
            while self.clone() & 1.into() == 0.into() {
                self >>= 1.into();
                arr.push(2.into());
            }
            self.rho(&mut arr, rng);
            arr
        }
    }

    macro_rules! impl_pollardrho {
        ($t:ty, $u:ty, $reset:expr) => {
            impl PollardRho for $t {
                fn rho(self, arr: &mut Vec<Self>, rng: &mut rng::RNG) {
                    if self <= 1 {
                        return;
                    } else if self.is_prime() {
                        arr.push(self);
                        return;
                    }

                    let mut i: u64 = 0;
                    let mut x: $t = (rng.next_u64() % self as u64) as $t;
                    let mut y: $t = x;
                    let mut k: u64 = 2;
                    let mut d: $t;
                    let mut reset_limit: u64 = $reset;

                    loop {
                        i += 1;
                        x = (((x as $u * x as $u % self as $u) + (self - 1) as $u) % self as $u) as $t;
                        d = gcd(y.abs_diff(x), self);
                        if d == self || i >= reset_limit {
                            // Reset
                            reset_limit = reset_limit * 3 / 2;
                            i = 0;
                            x = (rng.next_u64() % self as u64) as $t;
                            y = x;
                        }
                        if d != 1 {
                            break;
                        }
                        if i == k {
                            y = x;
                            k <<= 1;
                        }
                    }

                    if d != self {
                        d.rho(arr, rng);
                        (self / d).rho(arr, rng);
                        return;
                    }

                    let mut i = 3;
                    while i * i <= self {
                        if self % i == 0 {
                            i.rho(arr, rng);
                            (d / i).rho(arr, rng);
                            return;
                        }
                        i += 2;
                    }
                }
            }
        };
    }

    impl_pollardrho!(u8, u16, 100000);
    impl_pollardrho!(u16, u32, 100000);
    impl_pollardrho!(u32, u64, 100000);
    impl_pollardrho!(u64, u128, 100000);

    pub fn gcd<T>(x: T, y: T) -> T
    where
        T: Copy + PartialEq + PartialOrd + core::ops::Rem<Output = T> + From<u8>,
    {
        if y == 0.into() {
            x
        } else {
            let v = x % y;
            gcd(y, v)
        }
    }

    pub mod rng {
        pub struct RNG {
            val: u64,
        }
        impl RNG {
            pub fn new(seed: u64) -> Self {
                Self { val: seed }
            }
            pub fn next_u64(&mut self) -> u64 {
                let mut x = self.val;
                x ^= x << 13;
                x ^= x >> 7;
                x ^= x << 17;
                self.val = x;
                x
            }
        }
    }

    pub use self::rng::*;
}

Last modified on 231008.

Number Theoretic Transform (NTT)

Number Theoretic Transform (NTT) is an alternative to FFT that works over \(\mathbb{Z}_p\), where \(p\) is a prime of the form \(p = a \times 2^b + 1\).
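For \(p = a \times 2^b + 1\) with primitive root \(w\), the element \(w^a\) has multiplicative order exactly \(2^b\). Raising it to \(2^{b-n}\) therefore yields a primitive \(2^n\)-th root of unity, which is what `unity` computes. The sketch below checks this for \(p = 998244353 = 119 \times 2^{23} + 1\) with \(w = 3\), deriving \(a\) and \(b\) with `trailing_zeros` as `ntt_a` and `ntt_b` do. The helper `pow_mod` is illustrative.

```rust
/// base^exp mod p. Since p < 2^32, the product of two residues fits in u64.
fn pow_mod(mut base: u64, mut exp: u64, p: u64) -> u64 {
    let mut ret = 1;
    base %= p;
    while exp != 0 {
        if exp & 1 != 0 {
            ret = ret * base % p;
        }
        base = base * base % p;
        exp >>= 1;
    }
    ret
}

fn main() {
    const P: u64 = 998_244_353;
    const W: u64 = 3; // primitive root of P
    let b = (P - 1).trailing_zeros(); // 23
    let a = (P - 1) >> b; // 119
    println!("{} = {} * 2^{} + 1", P, a, b);

    // root = W^a has order exactly 2^b: root^(2^b) == 1 but root^(2^(b-1)) == P - 1.
    let root = pow_mod(W, a, P);
    assert_eq!(pow_mod(root, 1 << b, P), 1);
    assert_eq!(pow_mod(root, 1 << (b - 1), P), P - 1);

    // A primitive 2^n-th root of unity for n <= b, as in `unity(n, 1)`; here n = 3.
    let n = 3;
    let w8 = pow_mod(root, 1 << (b - n), P);
    assert_eq!(pow_mod(w8, 8, P), 1);
    assert_ne!(pow_mod(w8, 4, P), 1);
}
```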

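ntt::convolute(&[u64], &[u64]) -> Vec<u64> computes the convolution of two slices by running NTTs modulo two different primes (2281701377 and 998244353) and combining the results with CRT. The result is exact as long as every coefficient of the true convolution is less than the product of the two primes (about \(2.28 \times 10^{18}\)), which fits in u64. The returned vector is zero-padded to a power-of-two length.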
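On small inputs, the output of `ntt::convolute` can be checked against the \(O(nm)\) definition \(c_k = \sum_{i+j=k} a_i b_j\). The helper `convolute_naive` below is hypothetical. It returns exactly `a.len() + b.len() - 1` coefficients, whereas `ntt::convolute` pads its result with zeros, so only that prefix should be compared.

```rust
/// Direct O(n*m) convolution: c[k] = sum of a[i] * b[j] over i + j == k.
/// Assumes every coefficient of the result fits in u64.
fn convolute_naive(a: &[u64], b: &[u64]) -> Vec<u64> {
    if a.is_empty() || b.is_empty() {
        return vec![];
    }
    let mut c = vec![0u64; a.len() + b.len() - 1];
    for (i, &x) in a.iter().enumerate() {
        for (j, &y) in b.iter().enumerate() {
            c[i + j] += x * y;
        }
    }
    c
}

fn main() {
    let a: Vec<u64> = vec![1, 2];
    let b: Vec<u64> = vec![3, 4, 5];
    println!("{:?}", convolute_naive(&a, &b)); // [3, 10, 13, 10]
}
```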

Example

fn main() {
    let a: Vec<u64> = vec![1, 2];
    let b: Vec<u64> = vec![3, 4, 5];
    println!("{:?}", ntt::convolute(&a, &b)); // [3, 10, 13, 10, 0, 0, 0, 0]
}

mod ntt {
    // FFT_constname convention following https://algoshitpo.github.io/2020/05/20/fft-ntt/
    // p: prime for modulo
    // w: primitive root of p
    // p = a * 2^b + 1

    //             p  ntt_a ntt_b   ntt_w
    //   998,244,353    119    23       3
    // 2,281,701,377    17     27       3
    // 2,483,027,969    37     26       3
    // 2,113,929,217    63     25       5
    //   104,857,601    25     22       3
    // 1,092,616,193    521    21       3

    fn ceil_pow2(n: usize) -> u32 {
        n.next_power_of_two().trailing_zeros()
    }

    /// Reverses k trailing bits of n. Assumes that the rest of usize::BITS-k bits are all zero.
    const fn reverse_trailing_bits(n: usize, k: u32) -> usize {
        n.reverse_bits() >> (usize::BITS - k)
    }

    #[derive(Clone, Debug)]
    pub struct Ntt<const P: u64> {
        pub arr: Vec<u64>,
    }

    impl<const P: u64> Ntt<P> {
        pub const fn ntt_a() -> u64 {
            let mut p = P - 1;
            while p & 1 == 0 {
                p >>= 1;
            }
            p
        }

        pub const fn ntt_b() -> u32 {
            let mut p = P - 1;
            let mut ret = 0;
            while p & 1 == 0 {
                p >>= 1;
                ret += 1;
            }
            ret
        }

        pub const fn ntt_w() -> u64 {
            match P {
                998244353 | 2281701377 | 2483027969 | 104857601 | 1092616193 => 3,
                2113929217 => 5,
                _ => todo!(),
            }
        }

        const fn pow(base: u64, exp: u64) -> u64 {
            let mut base = base;
            let mut exp = exp;
            let mut ret = 1;
            while exp != 0 {
                if exp & 1 != 0 {
                    ret = ret * base % P;
                }
                base = base * base % P;
                exp >>= 1;
            }
            ret
        }

        // unity(n, 1) ^ (1<<n) == 1
        const fn unity(n: u32, k: u64) -> u64 {
            Self::pow(Self::pow(Self::ntt_w(), Self::ntt_a()), k << (Self::ntt_b() - n))
        }

        const fn recip(x: u64) -> u64 {
            Self::pow(x, P - 2)
        }

        pub fn new(arr: Vec<u64>) -> Self {
            Self { arr }
        }

        pub fn ntt(&mut self) {
            let n: usize = self.arr.len();
            let k = n.trailing_zeros();
            debug_assert_eq!(n, 1 << k);

            for i in 0..n {
                let j = reverse_trailing_bits(i, k);
                if i < j {
                    self.arr.swap(i, j);
                }
            }

            for x in 0..k {
                let base: u64 = Self::unity(x + 1, 1);
                let s = 1 << x;
                for i in (0..n).step_by(s << 1) {
                    let mut mult: u64 = 1;
                    for j in 0..s {
                        let tmp = (self.arr[i + j + s] * mult) % P;
                        self.arr[i + j + s] = (self.arr[i + j] + P - tmp) % P;
                        self.arr[i + j] = (self.arr[i + j] + tmp) % P;
                        mult = mult * base % P;
                    }
                }
            }
        }

        pub fn intt(&mut self) {
            let n: usize = self.arr.len();
            let k = n.trailing_zeros();
            debug_assert_eq!(n, 1 << k);

            for i in 0..n {
                let j = reverse_trailing_bits(i, k);
                if i < j {
                    self.arr.swap(i, j);
                }
            }

            for x in 0..k {
                let base: u64 = Self::recip(Self::unity(x + 1, 1));
                let s = 1 << x;
                for i in (0..n).step_by(s << 1) {
                    let mut mult: u64 = 1;
                    for j in 0..s {
                        let tmp = (self.arr[i + j + s] * mult) % P;
                        self.arr[i + j + s] = (self.arr[i + j] + P - tmp) % P;
                        self.arr[i + j] = (self.arr[i + j] + tmp) % P;
                        mult = mult * base % P;
                    }
                }
            }

            let r = Self::recip(n as u64);
            for f in self.arr.iter_mut() {
                *f *= r;
                *f %= P;
            }
        }

        pub fn convolute(a: &[u64], b: &[u64]) -> Self {
            let nlen = 1 << ceil_pow2(a.len() + b.len());
            let pad = |a: &[u64]| a.iter().copied().chain(std::iter::repeat(0)).take(nlen).collect();
            let arr = pad(a);
            let brr = pad(b);

            let mut arr = Self::new(arr);
            let mut brr = Self::new(brr);
            arr.ntt();
            brr.ntt();

            let crr: Vec<_> = arr.arr.iter().zip(brr.arr.iter()).map(|(&a, &b)| a * b % P).collect();
            let mut crr = Self::new(crr);
            crr.intt();
            crr
        }
    }

    fn merge<const P: u64, const Q: u64>(one: &[u64], two: &[u64]) -> Vec<u64> {
        let p: u64 = Ntt::<Q>::recip(P);
        let q: u64 = Ntt::<P>::recip(Q);
        let r: u64 = P * Q;
        one.iter()
            .zip(two.iter())
            .map(|(&a1, &a2)| ((a1 as u128 * q as u128 * Q as u128 + a2 as u128 * p as u128 * P as u128) % r as u128) as u64)
            .collect()
    }

    pub fn convolute(a: &[u64], b: &[u64]) -> Vec<u64> {
        const P: u64 = 2281701377;
        const Q: u64 = 998244353;
        let a: Vec<u64> = a.iter().copied().collect();
        let b: Vec<u64> = b.iter().copied().collect();

        let arr = Ntt::<P>::convolute(&a, &b);
        let brr = Ntt::<Q>::convolute(&a, &b);
        merge::<P, Q>(&arr.arr, &brr.arr)
    }
}

Code

mod ntt {
	// FFT_constname convention following https://algoshitpo.github.io/2020/05/20/fft-ntt/
	// p: prime for modulo
	// w: primitive root of p
	// p = a * 2^b + 1

	//             p  ntt_a ntt_b   ntt_w
	//   998,244,353    119    23       3
	// 2,281,701,377    17     27       3
	// 2,483,027,969    37     26       3
	// 2,113,929,217    63     25       5
	//   104,857,601    25     22       3
	// 1,092,616,193    521    21       3

	fn ceil_pow2(n: usize) -> u32 { n.next_power_of_two().trailing_zeros() }

	/// Reverses k trailing bits of n. Assumes that the rest of usize::BITS-k bits are all zero.
	const fn reverse_trailing_bits(n: usize, k: u32) -> usize { n.reverse_bits() >> (usize::BITS - k) }

	#[derive(Clone, Debug)]
	pub struct Ntt<const P: u64> {
		pub arr: Vec<u64>,
	}

	impl<const P: u64> Ntt<P> {
		pub const fn ntt_a() -> u64 {
			let p = P - 1;
			p >> p.trailing_zeros()
		}

		pub const fn ntt_b() -> u32 { (P - 1).trailing_zeros() }

		/// Primitive root of P
		pub const fn ntt_w() -> u64 {
			match P {
				998244353 | 2281701377 | 2483027969 | 104857601 | 1092616193 => 3,
				2113929217 => 5,
				_ => todo!(),
			}
		}

		const fn pow(mut base: u64, mut exp: u64) -> u64 {
			let mut ret = 1;
			while exp != 0 {
				if exp & 1 != 0 {
					ret = ret * base % P;
				}
				base = base * base % P;
				exp >>= 1;
			}
			ret
		}

		/// Returns w^k, where w is a primitive (2^n)-th root of unity modulo P.
		/// In particular, unity(n, 1)^(2^n) == 1 mod P. Requires n <= ntt_b().
		const fn unity(n: u32, k: u64) -> u64 { Self::pow(Self::pow(Self::ntt_w(), Self::ntt_a()), k << (Self::ntt_b() - n)) }

		const fn recip(x: u64) -> u64 { Self::pow(x, P - 2) }

		pub fn new(arr: Vec<u64>) -> Self { Self { arr } }

		pub fn ntt(&mut self) {
			let n: usize = self.arr.len();
			if n <= 1 {
				// A transform of length 0 or 1 leaves the array unchanged.
				return;
			}
			let k = n.trailing_zeros();
			debug_assert_eq!(n, 1 << k);

			for i in 0..n {
				let j = reverse_trailing_bits(i, k);
				if i < j {
					self.arr.swap(i, j);
				}
			}

			let mut basis = vec![Self::unity(k, 1)];
			for i in 1..k as usize {
				basis.push(basis[i - 1] * basis[i - 1] % P);
			}
			for (x, &base) in basis.iter().rev().enumerate() {
				let s = 1 << x;
				for i in (0..n).step_by(s << 1) {
					let mut mult: u64 = 1;
					for j in i..i + s {
						let tmp = (self.arr[j + s] * mult) % P;
						self.arr[j + s] = (self.arr[j] + P - tmp) % P;
						self.arr[j] = (self.arr[j] + tmp) % P;
						mult = mult * base % P;
					}
				}
			}
		}

		pub fn intt(&mut self) {
			let n: usize = self.arr.len();
			if n <= 1 {
				// A transform of length 0 or 1 leaves the array unchanged.
				return;
			}
			let k = n.trailing_zeros();
			debug_assert_eq!(n, 1 << k);

			for i in 0..n {
				let j = reverse_trailing_bits(i, k);
				if i < j {
					self.arr.swap(i, j);
				}
			}

			let mut basis = vec![Self::recip(Self::unity(k, 1))];
			for i in 1..k as usize {
				basis.push(basis[i - 1] * basis[i - 1] % P);
			}
			for (x, &base) in basis.iter().rev().enumerate() {
				let s = 1 << x;
				for i in (0..n).step_by(s << 1) {
					let mut mult: u64 = 1;
					for j in i..i + s {
						let tmp = (self.arr[j + s] * mult) % P;
						self.arr[j + s] = (self.arr[j] + P - tmp) % P;
						self.arr[j] = (self.arr[j] + tmp) % P;
						mult = mult * base % P;
					}
				}
			}

			let r = Self::recip(n as u64);
			for f in self.arr.iter_mut() {
				*f *= r;
				*f %= P;
			}
		}

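		/// Returns the convolution of `a` and `b` modulo P, zero-padded to length
		/// `(a.len() + b.len()).next_power_of_two()`.
		pub fn convolute(a: &[u64], b: &[u64]) -> Self {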
			let nlen = 1 << ceil_pow2(a.len() + b.len());
			let pad = |a: &[u64]| a.iter().copied().chain(std::iter::repeat(0)).take(nlen).collect();
			let arr = pad(a);
			let brr = pad(b);

			let mut arr = Self::new(arr);
			let mut brr = Self::new(brr);
			arr.ntt();
			brr.ntt();

			let crr: Vec<_> = arr.arr.iter().zip(brr.arr.iter()).map(|(&a, &b)| a * b % P).collect();
			let mut crr = Self::new(crr);
			crr.intt();
			crr
		}
	}

	fn merge<const P: u64, const Q: u64>(one: &[u64], two: &[u64]) -> Vec<u64> {
		let p = Ntt::<Q>::recip(P) as u128;
		let q = Ntt::<P>::recip(Q) as u128;
		let [pp, qq] = [P, Q].map(|x| x as u128);
		let r = (P * Q) as u128;

		one.iter()
			.zip(two.iter())
			.map(|(&a1, &a2)| {
				let [a, b] = [a1, a2].map(|x| x as u128);
				(a * q * qq + b * p * pp) % r
			})
			.map(|x| x as u64)
			.collect()
	}

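	/// Returns the convolution of `a` and `b` by running the NTT modulo two primes and merging
	/// the results with the CRT. Each coefficient is exact as long as its true value is below
	/// 2281701377 * 998244353 (about 2.28e18). The result is zero-padded like `Ntt::convolute`.
	pub fn convolute(a: &[u64], b: &[u64]) -> Vec<u64> {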
		const P: u64 = 2281701377;
		const Q: u64 = 998244353;

		let arr = Ntt::<P>::convolute(a, b);
		let brr = Ntt::<Q>::convolute(a, b);
		merge::<P, Q>(&arr.arr, &brr.arr)
	}
}
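Both `merge` functions recombine each coefficient with the Chinese remainder theorem. With \(p = P^{-1} \bmod Q\) and \(q = Q^{-1} \bmod P\), the value \(a_1 q Q + a_2 p P \bmod PQ\) is congruent to \(a_1\) modulo \(P\) and to \(a_2\) modulo \(Q\), so it equals the true coefficient whenever that coefficient is below \(PQ \approx 2.28 \times 10^{18}\). The standalone sketch below checks this identity on a single value; `pow_mod` and `crt` are hypothetical helpers standing in for `Ntt::recip` and `merge`, not part of the snippet.

```rust
// Standalone sketch of the CRT step in `merge` (helper names are hypothetical).
const P: u64 = 2281701377;
const Q: u64 = 998244353;

/// base^exp mod m, using u128 intermediates so any u64 modulus is safe.
fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let mut ret = 1u64;
    base %= m;
    while exp != 0 {
        if exp & 1 != 0 {
            ret = (ret as u128 * base as u128 % m as u128) as u64;
        }
        base = (base as u128 * base as u128 % m as u128) as u64;
        exp >>= 1;
    }
    ret
}

/// Returns the unique x in [0, P*Q) with x ≡ a1 (mod P) and x ≡ a2 (mod Q).
fn crt(a1: u64, a2: u64) -> u64 {
    let q = pow_mod(Q, P - 2, P) as u128; // Q^{-1} mod P (Fermat, P is prime)
    let p = pow_mod(P, Q - 2, Q) as u128; // P^{-1} mod Q (Fermat, Q is prime)
    let (pp, qq) = (P as u128, Q as u128);
    ((a1 as u128 * q * qq + a2 as u128 * p * pp) % (pp * qq)) as u64
}

fn main() {
    let x: u64 = 1_234_567_890_123_456_789; // below P * Q
    assert_eq!(crt(x % P, x % Q), x);
    println!("{}", crt(x % P, x % Q)); // 1234567890123456789
}
```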

Last modified on 231203.

Linear Recurrence

Algorithms related to linear recurrences.

Berlekamp-Massey

berlekamp_massey(A, m) returns the shortest vector \(C\), of length \(k\), which satisfies \[ \begin{aligned} A_x &= \sum_{i=0}^{k-1} {C_i A_{x-k+i}} \\ &= C_0 A_{x-k} + C_1 A_{x-k+1} + \cdots + C_{k-1} A_{x-1} \end{aligned} \] for every \(x \ge k\), under a prime modulus \(m\). To be safe, provide at least \(3k\) terms of \(A\).

Example

fn main() {
    // vals[x] = vals[x-3] + 2*vals[x-2] + 3*vals[x-1]
    let m: u64 = 1000000007;
    let mut vals: Vec<u64> = vec![1, 2, 3];
    for x in 3..20 {
        vals.push((vals[x - 3] + 2 * vals[x - 2] + 3 * vals[x - 1]) % m);
    }

    let rec = berlekamp_massey(&vals, m);
    println!("{:?}", rec); // [1, 2, 3]
}

Code

// Berlekamp-Massey
// References
// https://blog.naver.com/jinhan814/222140081932
// https://koosaga.com/231

fn rem_pow(mut base: i64, mut exp: i64, m: i64) -> i64 {
    let mut result = 1;
    while exp != 0 {
        if exp & 1 != 0 {
            result = (result * base) % m;
        }
        exp >>= 1;
        base = (base * base) % m;
    }
    result
}

/// Returns the shortest `rec` (say of length d) which satisfies
/// vals[i+d] = rec[0]vals[i] + rec[1]vals[i+1] + ... + rec[d-1]vals[i+d-1]
/// for every valid i, under a prime modulus m.
fn berlekamp_massey(vals: &[u64], m: u64) -> Vec<u64> {
    let m = m as i64;
    let mut cur: Vec<i64> = Vec::new();
    let (mut lf, mut ld) = (0, 0);
    let mut ls: Vec<i64> = Vec::new();
    for i in 0..vals.len() {
        let mut t = 0;
        for (j, v) in cur.iter().enumerate() {
            t = (t + vals[i - j - 1] as i64 * v) % m;
        }

        if (t - vals[i] as i64) % m == 0 {
            continue;
        }

        if cur.len() == 0 {
            cur = vec![0; i + 1];
            lf = i;
            ld = (t - vals[i] as i64) % m;
            continue;
        }

        let k = -(vals[i] as i64 - t) * rem_pow(ld, m - 2, m) % m;
        let mut c: Vec<i64> = vec![0; i - lf + ls.len()];
        c[i - lf - 1] = k;
        for (p, j) in ls.iter().enumerate() {
            c[i - lf + p] = -j * k % m;
        }

        if c.len() < cur.len() {
            c.resize(cur.len(), 0);
        }

        for j in 0..cur.len() {
            c[j] = (c[j] + cur[j]) % m;
        }

        if i - lf + ls.len() >= cur.len() {
            ls = cur;
            lf = i;
            ld = (t - vals[i] as i64) % m;
        }

        cur = c;
    }

    for i in 0..cur.len() {
        cur[i] = (cur[i] % m + m) % m;
    }

    cur.into_iter().rev().map(|x| x as u64).collect()
}
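Because berlekamp_massey only guarantees a recurrence that fits the terms it was given, it can be worth re-checking the result against the data. The sketch below is a hypothetical helper, not part of the snippet: it recomputes every \(A_x\) with \(x \ge k\) from \(C\), using the same coefficient order that berlekamp_massey returns, and reports whether all of them match modulo \(m\). It assumes \(m < 2^{32}\) so that the products fit in u64.

```rust
/// Hypothetical helper (not part of the snippet): returns true if every term
/// vals[x] with x >= k equals C_0 vals[x-k] + ... + C_{k-1} vals[x-1] modulo m,
/// where k = rec.len(). Assumes m < 2^32 so the products fit in u64.
fn satisfies_recurrence(vals: &[u64], rec: &[u64], m: u64) -> bool {
    let k = rec.len();
    (k..vals.len()).all(|x| {
        let predicted = (0..k).fold(0u64, |acc, i| (acc + rec[i] % m * (vals[x - k + i] % m)) % m);
        predicted == vals[x] % m
    })
}

fn main() {
    let m: u64 = 1_000_000_007;
    // Terms of A_x = A_{x-3} + 2A_{x-2} + 3A_{x-1}, as in the example above.
    let vals: Vec<u64> = vec![1, 2, 3, 14, 50, 181, 657, 2383];
    println!("{}", satisfies_recurrence(&vals, &[1, 2, 3], m)); // true
    println!("{}", satisfies_recurrence(&vals, &[1, 1], m)); // false
}
```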

Kitamasa

kitamasa(C, A, n, m) returns \(A_n\) modulo \(m\), where the initial terms \(A_0, \ldots, A_{k-1}\) are given and \[ \begin{aligned} A_x &= \sum_{i=0}^{k-1} {C_i A_{x-k+i}} \\ &= C_0 A_{x-k} + C_1 A_{x-k+1} + \cdots + C_{k-1} A_{x-1}, \end{aligned} \] in \( O(T(k) \log{n}) \) time, where \(T(k)\) is the time needed to multiply two polynomials of degree less than \(k\), and \(k\) is the length of \(C\).

Example

fn main() {
    // vals[x] = vals[x-3] + 2*vals[x-2] + 3*vals[x-1]
    // 1, 2, 3, 14, 50, 181, 657, 2383, 8644, 31355, 113736, 412562, 1496513, 5428399, 19690785, 71425666, ...
    let m: u64 = 1000000007;
    let vals: Vec<u64> = vec![1, 2, 3];
    let rec: Vec<u64> = vec![1, 2, 3];

    let v = kitamasa(&rec, &vals, 15, m);
    println!("{}", v); // 71425666
}

\(O(k^2 \log{n})\) Implementation

The implementation below uses naive polynomial multiplication.

// Kitamasa
// Reference: JusticeHui's Blog: <https://justicehui.github.io/hard-algorithm/2021/03/13/kitamasa/>

fn poly_mul(v: &[u64], w: &[u64], rec: &[u64], m: u64) -> Vec<u64> {
    let mut t = vec![0; 2 * v.len()];

    for j in 0..v.len() {
        for k in 0..w.len() {
            t[j + k] += v[j] * w[k] % m;
            if t[j + k] >= m {
                t[j + k] -= m;
            }
        }
    }

    for j in (v.len()..2 * v.len()).rev() {
        for k in 1..=v.len() {
            t[j - k] += t[j] * rec[k - 1] % m;
            if t[j - k] >= m {
                t[j - k] -= m;
            }
        }
    }

    t[..v.len()].iter().map(|x| *x).collect()
}

/// Finds arr[n] of the sequence whose first d terms are `vals` and which satisfies
/// arr[i+d] = rec[0]arr[i] + rec[1]arr[i+1] + rec[2]arr[i+2] + ... + rec[d-1]arr[i+d-1]
/// under modulo m, where d = rec.len() = vals.len()
fn kitamasa(rec: &[u64], vals: &[u64], mut n: u64, m: u64) -> u64 {
    let recurr: Vec<_> = rec.iter().rev().copied().collect();
    let (mut s, mut t) = (vec![0u64; recurr.len()], vec![0u64; recurr.len()]);
    s[0] = 1;
    if recurr.len() != 1 {
        t[1] = 1;
    } else {
        t[0] = recurr[0];
    }

    while n != 0 {
        if n & 1 != 0 {
            s = poly_mul(&s, &t, &recurr, m);
        }
        t = poly_mul(&t, &t, &recurr, m);
        n >>= 1;
    }

    let mut ret = 0u64;
    for i in 0..recurr.len() {
        ret += s[i] * vals[i] % m;
        if ret >= m {
            ret -= m;
        }
    }
    ret
}
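For testing, it can help to compare kitamasa against a direct evaluation of the recurrence on small \(n\). The sketch below is such a reference (naive_nth is a hypothetical name, not part of the snippet): it extends the sequence term by term in \(O(nk)\) time, uses the same argument order and coefficient convention as kitamasa, and assumes \(m < 2^{32}\) so the products fit in u64.

```rust
/// Hypothetical reference implementation (not part of the snippet): computes A_n
/// by extending the sequence one term at a time, in O(nk) time.
/// Same conventions as `kitamasa`: A_x = C_0 A_{x-k} + ... + C_{k-1} A_{x-1} (mod m).
/// Assumes m < 2^32 so the products fit in u64.
fn naive_nth(rec: &[u64], vals: &[u64], n: u64, m: u64) -> u64 {
    let k = rec.len();
    let n = n as usize;
    let mut a: Vec<u64> = vals.iter().map(|&v| v % m).collect();
    while a.len() <= n {
        let x = a.len();
        let next = (0..k).fold(0u64, |acc, i| (acc + rec[i] % m * a[x - k + i]) % m);
        a.push(next);
    }
    a[n]
}

fn main() {
    // Same input as the Kitamasa example: kitamasa(&rec, &vals, 15, m) should agree.
    let m: u64 = 1000000007;
    println!("{}", naive_nth(&[1, 2, 3], &[1, 2, 3], 15, m)); // 71425666
}
```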

Bostan-Mori

Reference - Alin Bostan, Ryuhei Mori. A Simple and Fast Algorithm for Computing the N-th Term of a Linearly Recurrent Sequence. SOSA’21 (SIAM Symposium on Simplicity in Algorithms), Jan 2021, Alexandria, United States. hal-02917827v2

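Given \(k\) initial values \(f_0,\ f_1,\ \cdots,\ f_{k-1}\) and a linear recurrence of length \(k\) with coefficients \(c_0,\ c_1,\ \cdots,\ c_{k-1}\): \[ f_{n+k} = c_0f_n + c_1f_{n+1} + \cdots + c_{k-1}f_{n+k-1}, \] bostan_mori::<P>(c, f, n) calculates \(f_n \bmod P\). Here, \(P\) can be any integer larger than \(1\), since no modular inverse is needed, as long as \(P \le 2^{32}\) so that the products in poly_mul fit in u64.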

The time complexity of this algorithm is \(O(\mathrm{M}(k) \log n)\) where \(\mathrm{M}(k)\) represents the time complexity for polynomial multiplication. This depends on the implementation of poly_mul. In this snippet, naive polynomial multiplication is used, resulting in a time complexity of \(O(k^2 \log n)\). By replacing it with an implementation using FFT, the time complexity can be improved to \(O(k \log k \log n)\).
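The loop in bostan_mori rests on one identity. Let \(p(x)/q(x)\) be the generating function of \(f\), where \(q(x) = 1 - c_{k-1}x - c_{k-2}x^2 - \cdots - c_0x^k\) and \(p(x)\) is the product of the initial terms with \(q(x)\), truncated to degree less than \(k\) (the variables p and q in the code). Multiplying the numerator and the denominator by \(q(-x)\) makes the denominator even, so only every other coefficient of the numerator matters:

\[ \frac{p(x)}{q(x)} = \frac{p(x)\,q(-x)}{q(x)\,q(-x)} = \frac{u_0(x^2) + x\,u_1(x^2)}{v(x^2)}, \qquad [x^n]\,\frac{p(x)}{q(x)} = [x^{\lfloor n/2 \rfloor}]\,\frac{u_{n \bmod 2}(x)}{v(x)} \]

In the code, mq holds \(q(-x)\), the new p keeps the coefficients of \(p(x)\,q(-x)\) at indices congruent to \(n \bmod 2\), and the new q keeps the even coefficients of \(q(x)\,q(-x)\). When \(n\) reaches \(0\), the answer is \(p_0 / q_0 = p_0\), since \(q_0 = 1\) throughout.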

Example

This example calculates the terms \(f_n\) for \(0 \le n \le 10\), where \(f_0 = 0\), \(f_1 = 1\), and \(f_{n+2} = 2f_n + f_{n+1}\).

fn main() {
    for i in 0..=10 {
        println!("{}", bostan_mori::<12345>(&[2, 1], &[0, 1], i));
    }
    // Prints 0, 1, 1, 3, 5, 11, 21, 43, 85, 171, 341 (one per line)
}

/// Finds arr[n] of the sequence whose first d terms are `vals` and which satisfies
/// arr[i+d] = rec[0]arr[i] + rec[1]arr[i+1] + rec[2]arr[i+2] + ... + rec[d-1]arr[i+d-1]
/// under mod P, where d = rec.len() = vals.len()
fn bostan_mori<const P: u64>(rec: &[u64], vals: &[u64], mut n: u64) -> u64 {
    if vals.len() as u64 > n {
        return vals[n as usize] % P;
    }

    let x = (0..rec.len()).find(|&i| rec[i] != 0).unwrap_or(rec.len());
    if x == rec.len() {
        return 0;
    }

    let vals: Vec<u64> = vals.iter().map(|&v| v % P).collect();
    let rec: Vec<u64> = rec.iter().map(|&v| v % P).collect();
    let mut q = vec![1];
    q.extend(rec.iter().rev().map(|&v| (P - v) % P));
    let mut p = poly_mul::<P>(&vals, &q);
    p.truncate(rec.len());

    while n >= 1 {
        let mq: Vec<_> = q
            .iter()
            .enumerate()
            .map(|(i, &v)| if i % 2 == 0 { v } else { (P - v) % P })
            .collect();
        let u = poly_mul::<P>(&p, &mq);
        p = u
            .iter()
            .copied()
            .skip((n % 2) as usize)
            .step_by(2)
            .collect();
        let a = poly_mul::<P>(&q, &mq);
        q = a.iter().copied().step_by(2).collect();
        n /= 2;
    }

    p[0] % P
}

fn poly_mul<const P: u64>(a: &[u64], b: &[u64]) -> Vec<u64> {
    let mut ret = vec![0; a.len() + b.len() - 1];
    for (i, &av) in a.iter().enumerate() {
        for (j, &bv) in b.iter().enumerate() {
            ret[i + j] = (ret[i + j] + av * bv) % P;
        }
    }
    ret
}

Snippet

/// Finds arr[n] of the sequence whose first d terms are `vals` and which satisfies
/// arr[i+d] = rec[0]arr[i] + rec[1]arr[i+1] + rec[2]arr[i+2] + ... + rec[d-1]arr[i+d-1]
/// under mod P, where d = rec.len() = vals.len()
fn bostan_mori<const P: u64>(rec: &[u64], vals: &[u64], mut n: u64) -> u64 {
	if vals.len() as u64 > n {
		return vals[n as usize] % P;
	}

	let x = (0..rec.len()).find(|&i| rec[i] != 0).unwrap_or(rec.len());
	if x == rec.len() {
		return 0;
	}

	let vals: Vec<u64> = vals.iter().map(|&v| v % P).collect();
	let rec: Vec<u64> = rec.iter().map(|&v| v % P).collect();
	let mut q = vec![1];
	q.extend(rec.iter().rev().map(|&v| (P - v) % P));
	let mut p = poly_mul::<P>(&vals, &q);
	p.truncate(rec.len());

	while n >= 1 {
		let mq: Vec<_> = q.iter().enumerate().map(|(i, &v)| if i % 2 == 0 { v } else { (P - v) % P }).collect();
		let u = poly_mul::<P>(&p, &mq);
		p = u.iter().copied().skip((n % 2) as usize).step_by(2).collect();
		let a = poly_mul::<P>(&q, &mq);
		q = a.iter().copied().step_by(2).collect();
		n /= 2;
	}

	p[0] % P
}

/// Naive O(k^2) multiplication
fn poly_mul<const P: u64>(a: &[u64], b: &[u64]) -> Vec<u64> {
	let mut ret = vec![0; a.len() + b.len() - 1];
	for (i, &av) in a.iter().enumerate() {
		for (j, &bv) in b.iter().enumerate() {
			ret[i + j] = (ret[i + j] + av * bv) % P;
		}
	}
	ret
}

Last modified on 231203.

Integer Square Root

isqrt(s) returns \( \left\lfloor \sqrt{s} \right\rfloor \) using integer Newton iterations. It runs much faster than the typical binary search method, but slower than casting the result of f64::sqrt. If the value can be represented exactly as an f64 and the memory limit is not too tight, it is usually better to use f64::sqrt from std, keeping in mind that the rounded result can be off by one for large inputs.

Example

fn main() {
    let x: u64 = 10002;
    let sq = isqrt(x);
    println!("{}", sq); // 100
}

fn isqrt(s: u64) -> u64 {
    let mut x0 = s >> 1;
    if x0 != 0 {
        let mut x1 = (x0 + s / x0) >> 1;
        while x1 < x0 {
            x0 = x1;
            x1 = (x0 + s / x0) >> 1
        }
        x0
    } else {
        s
    }
}

Code

use std::ops::{Add, Div, Shr};

fn isqrt<T>(s: T) -> T
where T: Copy + Shr<Output = T> + Add<Output = T> + Div<Output = T> + PartialOrd + From<u8> {
	let mut x0 = s >> 1.into();
	if x0 != 0.into() {
		let mut x1 = (x0 + s / x0) >> 1.into();
		while x1 < x0 {
			x0 = x1;
			x1 = (x0 + s / x0) >> 1.into();
		}
		x0
	} else {
		s
	}
}
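When speed matters, a common alternative (an assumption here, not the author's snippet) is to take the f64::sqrt estimate and then correct it with exact integer comparisons. Because `s as f64` rounds and the square root is rounded again, the estimate can be off by one for large inputs; the two loops below fix that, comparing squares in u128 so nothing overflows. isqrt_f64 is a hypothetical name.

```rust
/// Floor square root: start from the f64 estimate, then correct it with exact integer checks.
fn isqrt_f64(s: u64) -> u64 {
    // Square in u128 so that (2^32)^2 does not overflow.
    let square = |x: u64| (x as u128) * (x as u128);
    // The float estimate is at most 2^32, so r + 1 below cannot overflow.
    let mut r = (s as f64).sqrt() as u64;
    while square(r) > s as u128 {
        r -= 1;
    }
    while square(r + 1) <= s as u128 {
        r += 1;
    }
    r
}

fn main() {
    println!("{}", isqrt_f64(10002)); // 100
    println!("{}", isqrt_f64(u64::MAX)); // 4294967295
}
```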

Last modified on 231203.

2D Matrix

So far, only the most basic functionality is implemented: construction and row indexing, so that an element is accessed as m[i][j].

Code

use std::ops::{Index, IndexMut};

pub struct Mat<T> {
	row: usize,
	col: usize,
	data: Vec<T>,
}

impl<T> Mat<T> {
	pub fn new(row: usize, col: usize) -> Self
	where T: Default {
		Self { row, col, data: (0..row * col).map(|_| Default::default()).collect() }
	}

	pub fn resize_iter(row: usize, col: usize, data: impl IntoIterator<Item = T>) -> Self {
		let data: Vec<T> = data.into_iter().take(row * col).collect();
		debug_assert_eq!(data.len(), row * col);
		Self { row, col, data }
	}
}

impl<T> Index<usize> for Mat<T> {
	type Output = [T];
	fn index(&self, idx: usize) -> &Self::Output {
		let l = idx * self.col;
		&self.data[l..l + self.col]
	}
}

impl<T> IndexMut<usize> for Mat<T> {
	fn index_mut(&mut self, idx: usize) -> &mut Self::Output {
		let l = idx * self.col;
		&mut self.data[l..l + self.col]
	}
}

Disjoint Set Union

Disjoint set union (union-find) data structure, using path halving in find_root and union by size in union.

Example

fn main() {
    let mut uf = UnionFind::new(10);
    println!("{}", uf.find_root(2) == uf.find_root(6)); // false
    uf.union(2, 6);
    println!("{}", uf.find_root(2) == uf.find_root(6)); // true
}

Code

struct UnionFind {
    parent: Vec<usize>,
    size: Vec<usize>,
    num: usize,
}

impl UnionFind {
    fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            size: vec![1; n],
            num: n,
        }
    }

    fn find_root(&mut self, mut x: usize) -> usize {
        while self.parent[x] != x {
            self.parent[x] = self.parent[self.parent[x]];
            x = self.parent[x];
        }
        x
    }

    fn union(&mut self, u: usize, v: usize) {
        let u = self.find_root(u);
        let v = self.find_root(v);
        if u != v {
            self.num -= 1;
            if self.size[u] < self.size[v] {
                self.parent[u] = v;
                self.size[v] += self.size[u];
            } else {
                self.parent[v] = u;
                self.size[u] += self.size[v];
            }
        }
    }

    fn get_size(&mut self, x: usize) -> usize {
        let r = self.find_root(x);
        self.size[r]
    }
}

Weighted DSU

Disjoint set union where each vertex has its own "potential value". The potential value of a vertex \(i\) is denoted as \(w(i)\), and the difference of potential values between two vertices in the same group is always well defined.

fn union(&mut self, u: usize, v: usize, dw: i64) -> Option<bool> adds a "rule" stating that \(u\) and \(v\) are in the same group and \(w(u) - w(v) = dw\); if they already are, it only reports whether dw agrees with the existing rules. Based on the rules added so far, fn get_pot_diff(&mut self, u: usize, v: usize) -> Option<i64> returns \(w(u) - w(v)\) if \(u\) and \(v\) are in the same group, and None otherwise.

potdiff[i] in the code is defined as \(w(i) - w(p_i)\), where \(p_i\) is the parent of \(i\); it is \(0\) for a root. find_root_with_pdiff sums these values along the path to obtain \(w(i) - w(r_i)\), where \(r_i\) is the root of the group containing \(i\).
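As a small worked example (the calls are illustrative), after union(0, 1, 3) and union(1, 2, 4), the differences telescope, so get_pot_diff(0, 2) gives Some(7), and a later union(0, 2, 5) would return Some(false) because it contradicts the existing rules:

\[ w(0) - w(2) = \bigl(w(0) - w(1)\bigr) + \bigl(w(1) - w(2)\bigr) = 3 + 4 = 7 \]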

Code

struct WeightDSU {
	parent: Vec<usize>,
	size: Vec<usize>,
	num: usize,
	potdiff: Vec<i64>, // potdiff[x] == w(x) - w(parent[x])
}

impl WeightDSU {
	fn new(n: usize) -> Self {
		Self {
			parent: (0..n).collect(),
			size: vec![1; n],
			num: n,
			potdiff: vec![0; n],
		}
	}

	// Returns the root of x.
	fn find_root(&mut self, mut x: usize) -> usize {
		while self.parent[x] != x {
			let p = self.parent[x];
			self.potdiff[x] += self.potdiff[p];
			self.parent[x] = self.parent[p];
			x = self.parent[x];
		}
		x
	}

	// Returns the root of x, namely xr, with w(x) - w(xr).
	fn find_root_with_pdiff(&mut self, mut x: usize) -> (usize, i64) {
		let mut pd = 0;
		while self.parent[x] != x {
			let p = self.parent[x];
			self.potdiff[x] += self.potdiff[p];
			self.parent[x] = self.parent[p];
			pd += self.potdiff[x];
			x = self.parent[x];
		}
		(x, pd)
	}

	// Unions the groups of u and v so that w(u) - w(v) = dw. If u and v are already in the same group,
	// it returns Some(true) if w(u) - w(v) == dw, and Some(false) otherwise. If they weren't
	// in the same group, it returns None and unions the two groups following the given dw.
	fn union(&mut self, u: usize, v: usize, dw: i64) -> Option<bool> {
		let (ur, pu) = self.find_root_with_pdiff(u);
		let (vr, pv) = self.find_root_with_pdiff(v);
		let nw = dw - pu + pv;
		if ur != vr {
			self.num -= 1;
			if self.size[ur] < self.size[vr] {
				self.parent[ur] = vr;
				self.size[vr] += self.size[ur];
				self.potdiff[ur] = nw;
			} else {
				self.parent[vr] = ur;
				self.size[ur] += self.size[vr];
				self.potdiff[vr] = -nw;
			}
			None
		} else {
			Some(nw == 0)
		}
	}

	// Returns the size of a group x is in.
	fn get_size(&mut self, x: usize) -> usize {
		let r = self.find_root(x);
		self.size[r]
	}

	// Returns Some(w(u) - w(v)) if u and v are in the same group, otherwise None.
	fn get_pot_diff(&mut self, u: usize, v: usize) -> Option<i64> {
		let (ur, pu) = self.find_root_with_pdiff(u);
		let (vr, pv) = self.find_root_with_pdiff(v);
		if ur == vr {
			Some(pu - pv)
		} else {
			None
		}
	}
}

Segment Trees

Segment trees are a category of data structures which can handle range queries efficiently.

Segment Tree

Reference: AtCoder library https://atcoder.github.io/ac-library/production/document_en/index.html

A segment tree is a data structure for monoids \( (S, \cdot : S \times S \rightarrow S, e \in S) \). A monoid is an algebraic structure that satisfies the following conditions:

  • \(\cdot\) is associative. That is, \( (a \cdot b) \cdot c = a \cdot (b \cdot c) \) for all \( a, b, c \in S \).
  • There is an identity element \(e\) such that \( a \cdot e = e \cdot a = a \) for all \( a \in S \).

Given an array \(A\) of length \(n\) whose elements belong to the monoid \(S\) described above, a segment tree on it can process the following queries in \(O (\log{n})\) time:

  • Update an element
  • Calculate the product of the elements of an interval

assuming that calculating the product of two elements takes \(O(1)\) time.

Example

use segtree::*;

fn main() {
    // Product segment tree with size of 10 and all elements being 1
    let st = SegTree::new(10, std::iter::repeat(1), 1, |x, y| x * y % 1000000007);
    // Sum segment tree with initial values
    let st = SegTree::new(10, vec![1, 3, 2, 4, 3, 5, 4, 6, 5, 7], 0, |x, y| x + y);

    // Sum segment tree with size of 10 and all elements being 0
    let mut st = SegTree::new(10, None, 0, |x, y| x + y);

    st.update(2, |_| 3);
    println!("{}", st.get(2)); // 3

    let prev = st.get(4);
    st.update(4, |x| x + 2);
    let curr = st.get(4);
    println!("{prev} -> {curr}"); // 0 -> 2

    println!("{}", st.prod(..4)); // 3
    println!("{}", st.prod(4..)); // 2
    println!("{}", st.prod(..)); // 5
    println!("{}", st.prod(100..101)); // 0
    println!("{}", st.prod(4..2)); // 0

    for i in 1..=st.len() {
        print!("{} ", st.prod(1..i));
    }
    println!(); // 0 0 3 3 5 5 5 5 5 5
    println!("{}", st.partition_point(1, |x| x < 3)); // 3
    println!("{}", st.partition_point(1, |x| x < 5)); // 5
}

mod segtree {
    use std::ops::RangeBounds;

    /// A segment tree is a data structure for a monoid type `T`.
    ///
    /// Given the constraints documented on `SegTree::new`, a segment tree can process the following queries in O(T log N) time, assuming `op` runs in O(T) time.
    /// - Changing the value of an element
    /// - Calculating the product of elements in an interval, combined with `op`
    pub struct SegTree<T, O> {
        n: usize,
        data: Vec<T>,
        e: T,
        op: O,
        size: usize,
        log: u32,
    }

    impl<T: Copy, O: Fn(T, T) -> T> SegTree<T, O> {
        fn get_bounds(&self, range: impl RangeBounds<usize>) -> (usize, usize) {
            use std::ops::Bound::*;
            let n = self.len();
            let l = match range.start_bound() {
                Included(&v) => v,
                Excluded(&v) => v + 1,
                Unbounded => 0,
            };
            let r = match range.end_bound() {
                Included(&v) => (v + 1).min(n),
                Excluded(&v) => v.min(n),
                Unbounded => n,
            };
            if l > r {
                return (l, l);
            }
            (l, r)
        }

        fn upd(&mut self, k: usize) {
            self.data[k] = (self.op)(self.data[k * 2], self.data[(k * 2) + 1]);
        }

        /// Returns a new segment tree of size `n` built from `iter`.
        ///
        /// The meanings of parameters and some generic types are as follows.
        /// - `T` is a type of values in the array the segment tree represents.
        /// - `n` is a number of elements in the array.
        /// - `iter` is an iterator returning initial values of the array.
        ///   - If `iter.count() < n`, then the rest is filled with `e`.
        ///   - If `iter.count() > n`, the array is truncated down to the length of `n`.
        /// - `op: impl Fn(T, T) -> T` is a binary operator for `T`.
        /// - `e` is an identity for `op`.
        ///
        /// The following notations will be used from now on.
        /// - `op(a, b)` is denoted as `a*b`.
        ///
        /// Constraints of parameters are as follows.
        /// - `op` and `e` must make `T` a monoid. That is, `op` and `e` should be given so that `T` can satisfy the following conditions.
        ///   - `T` is associative under `op`. That is, `(a*b)*c == a*(b*c)` for all `[a, b, c]: [T; 3]`.
        ///   - `T` has `e` as an identity element under `op`. That is, `a*e == e*a == a` for all `a: T`.
        ///
        /// For example, a generic range sum segment tree with every value initialized with `0` and of length `n` can be constructed as follows.
        /// ```no_run
        /// let mut st = SegTree::new(n, None, 0i64, |x, y| x + y);
        /// ```
        pub fn new(n: usize, iter: impl IntoIterator<Item = T>, e: T, op: O) -> Self {
            let size = n.next_power_of_two();
            let log = size.trailing_zeros();

            let mut data = vec![e; size];
            data.extend(iter.into_iter().take(n));
            data.resize(2 * size, e);

            let mut st = Self { n, data, e, op, size, log };
            for i in (1..size).rev() {
                st.upd(i);
            }
            st
        }

        /// Returns the length of the array.
        pub fn len(&self) -> usize {
            self.n
        }

        /// Returns the `i`-th value of the array.
        pub fn get(&self, i: usize) -> T {
            self.data[i + self.size]
        }

        /// Assign `upd_to(self.get(i))` to the `i`-th element.
        pub fn update(&mut self, i: usize, upd_to: impl Fn(T) -> T) {
            let i = i + self.size;
            self.data[i] = upd_to(self.data[i]);
            for j in 1..=self.log {
                self.upd(i >> j);
            }
        }

        /// Returns the product of elements in `range`.
        pub fn prod(&self, range: impl RangeBounds<usize>) -> T {
            let (l, r) = self.get_bounds(range);

            // Check the shortcuts on the array indices, before shifting them to leaf positions.
            if (l, r) == (0, self.n) {
                return self.data[1];
            } else if l == r {
                return self.e;
            }

            let (mut l, mut r) = (l + self.size, r + self.size);

            let (mut sml, mut smr) = (self.e, self.e);
            while l < r {
                if l & 1 == 1 {
                    sml = (self.op)(sml, self.data[l]);
                    l += 1;
                }
                if r & 1 == 1 {
                    r -= 1;
                    smr = (self.op)(self.data[r], smr);
                }
                (l >>= 1, r >>= 1);
            }

            (self.op)(sml, smr)
        }

        /// For a function `pred` which has a nonnegative value `x`, such that `pred(self.prod(l..r))` is `false` if and only if `x <= r`, `self.partition_point(l, pred)` returns the value of such `x`.
        /// That is, this is the minimum value of `r` such that `pred(self.prod(l..r))` starts to be `false`.
        /// If `pred(self.e)` is `false`, then this function assumes that `pred(self.prod(l..r))` is always `false` for any `r` in range `l..=self.len()` and returns `l`.
        /// However, it's recommended to always make `pred(self.e)` `true` to avoid unnecessary edge cases.
        ///
        /// ## Constraints
        /// - `0 <= l <= self.len()`
        ///
        /// ## Examples
        /// `f(r) := pred(self.prod(l..r))`
        ///
        /// Given that `self.len() == 7`, calling `self.partition_point(0, pred)` returns the values marked below.
        /// ```text
        ///    r |     0     1     2     3     4     5     6     7     8
        ///
        /// f(r) |  true  true  true  true false false false false   N/A
        ///                             returns^
        ///
        /// f(r) | false false false false false false false false   N/A
        ///     returns^
        ///
        /// f(r) |  true  true  true  true  true  true  true  true   N/A
        ///                                                     returns^
        /// ```
        pub fn partition_point(&self, l: usize, pred: impl Fn(T) -> bool) -> usize {
            if !pred(self.e) {
                // `pred(self.prod(l..l))` is `false`
                // Thus l is returned.
                // This case is not covered in the original implementation as it simply requires pred(self.e) to be `true`
                return l;
            }

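            if l == self.n {
                // `pred(self.e)` has already been checked to be `true`, so `pred(self.prod(l..r))`
                // is `true` for every `r` in `l..=self.len()`. Following the convention above
                // (and the fall-through at the end of this function), `self.n + 1` is returned.
                return self.n + 1;
            }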

            let mut l = l + self.size;
            let mut sm = self.e;

            loop {
                l >>= l.trailing_zeros();
                if !pred((self.op)(sm, self.data[l])) {
                    while l < self.size {
                        l <<= 1;
                        let tmp = (self.op)(sm, self.data[l]);
                        if pred(tmp) {
                            sm = tmp;
                            l += 1;
                        }
                    }
                    return l + 1 - self.size;
                }
                sm = (self.op)(sm, self.data[l]);
                l += 1;
                if l & ((!l) + 1) == l {
                    break;
                }
            }
            self.n + 1
        }

        /// For a function `pred` which has a value `x` less than or equal to `r`, such that `pred(self.prod(l..r))` is `true` if and only if `x <= l`, `self.left_partition_point(r, pred)` returns the value of such `x`.
        /// That is, this is the minimum value of `l` such that `pred(self.prod(l..r))` starts to be `true`.
        /// If `pred(self.e)` is `false`, then this function assumes that `pred(self.prod(l..r))` is always `false` for any `l` in range `0..=r` and returns `r+1`.
        /// However, it's recommended to always make `pred(self.e)` `true` to avoid unnecessary edge cases.
        ///
        /// ## Constraints
        /// - `0 <= r <= self.len()`
        ///
        /// ## Examples
        /// `f(l) := pred(self.prod(l..r))`
        ///
        /// Calling `self.left_partition_point(7, pred)` returns the values marked below.
        /// ```text
        ///    l |     0     1     2     3     4     5     6     7     8
        ///
        /// f(l) | false false false false  true  true  true  true   N/A
        ///                             returns^
        ///
        /// f(l) |  true  true  true  true  true  true  true  true   N/A
        ///     returns^
        ///
        /// f(l) | false false false false false false false false   N/A
        ///                                                     returns^
        /// ```
        pub fn left_partition_point(&self, r: usize, pred: impl Fn(T) -> bool) -> usize {
            if !pred(self.e) {
                return r + 1;
            }

            if r == 0 {
                // `pred(self.e)` is always `true` at this point
                return 0;
            }

            let mut r = r + self.size;
            let mut sm = self.e;

            loop {
                r -= 1;
                while r > 1 && r & 1 == 1 {
                    r >>= 1;
                }
                if !pred((self.op)(self.data[r], sm)) {
                    while r < self.size {
                        r = (r << 1) + 1;
                        let tmp = (self.op)(self.data[r], sm);
                        if pred(tmp) {
                            sm = tmp;
                            r -= 1;
                        }
                    }
                    return r + 1 - self.size;
                }
                sm = (self.op)(self.data[r], sm);
                if r & ((!r) + 1) == r {
                    break;
                }
            }
            0
        }
    }
}
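The two search methods are easy to misread, so here is a hedged sketch of what they compute, written as a naive linear scan over a plain slice and assuming `pred` is monotone as the comments above require. The functions `naive_partition_point` and `naive_left_partition_point` are hypothetical names, not part of the snippet; they take O(n^2) time and are only meant as a reference, for example to cross-check the segment tree on small random inputs.

```rust
/// Hypothetical reference for `partition_point`: the minimum `r` in `l..=a.len()`
/// such that `pred(prod(a[l..r]))` is `false`, or `a.len() + 1` if there is none.
fn naive_partition_point<T: Copy>(
    a: &[T],
    l: usize,
    e: T,
    op: impl Fn(T, T) -> T,
    pred: impl Fn(T) -> bool,
) -> usize {
    for r in l..=a.len() {
        // Product of a[l..r], folded left to right starting from the identity `e`
        let p = a[l..r].iter().fold(e, |acc, &x| op(acc, x));
        if !pred(p) {
            return r;
        }
    }
    a.len() + 1
}

/// Hypothetical reference for `left_partition_point`: the minimum `l` in `0..=r`
/// such that `pred(prod(a[l..r]))` is `true`, or `r + 1` if there is none.
fn naive_left_partition_point<T: Copy>(
    a: &[T],
    r: usize,
    e: T,
    op: impl Fn(T, T) -> T,
    pred: impl Fn(T) -> bool,
) -> usize {
    for l in 0..=r {
        let p = a[l..r].iter().fold(e, |acc, &x| op(acc, x));
        if pred(p) {
            return l;
        }
    }
    r + 1
}
```

With `a = [0, 1, 2, 6, 7, 8, 6, 7, 8, 9]` and addition as `op`, `naive_partition_point(&a, 3, 0, add, |x| x < 40)` gives 9 and `naive_left_partition_point(&a, 7, 0, add, |x| x < 25)` gives 4.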

Code

mod segtree {
    use std::ops::RangeBounds;

    /// A segment tree is a data structure for a monoid type `T`.
    ///
    /// Given all the constraints written in the comment of `SegTree::new`, a segment tree can process the following queries in O(T log N) time, assuming `op` runs in O(T).
    /// - Changing the value of an element
    /// - Calculating the product of elements in an interval, combined with `op`
    pub struct SegTree<T, O> {
        n: usize,
        data: Vec<T>,
        e: T,
        op: O,
        size: usize,
        log: u32,
    }

    impl<T: Copy, O: Fn(T, T) -> T> SegTree<T, O> {
        fn get_bounds(&self, range: impl RangeBounds<usize>) -> (usize, usize) {
            use std::ops::Bound::*;
            let n = self.len();
            let l = match range.start_bound() {
                Included(&v) => v,
                Excluded(&v) => v + 1,
                Unbounded => 0,
            };
            let r = match range.end_bound() {
                Included(&v) => (v + 1).min(n),
                Excluded(&v) => v.min(n),
                Unbounded => n,
            };
            if l > r {
                return (l, l);
            }
            (l, r)
        }

        fn upd(&mut self, k: usize) {
            self.data[k] = (self.op)(self.data[k * 2], self.data[(k * 2) + 1]);
        }

        /// Returns a new segment tree of size `n` built from `iter`.
        ///
        /// The meanings of parameters and some generic types are as follows.
        /// - `T` is the type of values in the array the segment tree represents.
        /// - `n` is the number of elements in the array.
        /// - `iter` is an iterator returning initial values of the array.
        ///   - If `iter.count() < n`, then the rest is filled with `e`.
        ///   - If `iter.count() > n`, the array is truncated down to the length of `n`.
        /// - `op: impl Fn(T, T) -> T` is a binary operator for `T`.
        /// - `e` is an identity for `op`.
        ///
        /// The following notations will be used from now on.
        /// - `op(a, b)` is denoted as `a*b`.
        ///
        /// Constraints of parameters are as follows.
        /// - `op` and `e` must make `T` a monoid. That is, `op` and `e` should be given so that `T` can satisfy the following conditions.
        ///   - `T` is associative under `op`. That is, `(a*b)*c == a*(b*c)` for all `[a, b, c]: [T; 3]`.
        ///   - `T` has `e` as an identity element under `op`. That is, `a*e == e*a == a` for all `a: T`.
        ///
        /// For example, a generic range sum segment tree with every value initialized with `0` and of length `n` can be constructed as follows.
        /// ```no_run
        /// let mut st = SegTree::new(n, None, 0i64, |x, y| x + y);
        /// ```
        pub fn new(n: usize, iter: impl IntoIterator<Item = T>, e: T, op: O) -> Self {
            let size = n.next_power_of_two();
            let log = size.trailing_zeros();

            let mut data = vec![e; size];
            data.extend(iter.into_iter().take(n));
            data.resize(2 * size, e);

            let mut st = Self { n, data, e, op, size, log };
            for i in (1..size).rev() {
                st.upd(i);
            }
            st
        }

        /// Returns the length of the array.
        pub fn len(&self) -> usize {
            self.n
        }

        /// Returns the `i`-th value of the array.
        pub fn get(&self, i: usize) -> T {
            self.data[i + self.size]
        }

        /// Assign `upd_to(self.get(i))` to the `i`-th element.
        pub fn update(&mut self, i: usize, upd_to: impl Fn(T) -> T) {
            let i = i + self.size;
            self.data[i] = upd_to(self.data[i]);
            for j in 1..=self.log {
                self.upd(i >> j);
            }
        }

        /// Returns the product of elements in `range`.
        pub fn prod(&self, range: impl RangeBounds<usize>) -> T {
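            let (l, r) = self.get_bounds(range);

            // Check the fast paths before converting `l` and `r` to leaf indices;
            // the root holds the product of the whole array (padding leaves are `e`).
            if (l, r) == (0, self.n) {
                return self.data[1];
            } else if l == r {
                return self.e;
            }

            let (mut l, mut r) = (l + self.size, r + self.size);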

            let (mut sml, mut smr) = (self.e, self.e);
            while l < r {
                if l & 1 == 1 {
                    sml = (self.op)(sml, self.data[l]);
                    l += 1;
                }
                if r & 1 == 1 {
                    r -= 1;
                    smr = (self.op)(self.data[r], smr);
                }
                l >>= 1;
                r >>= 1;
            }

            (self.op)(sml, smr)
        }

        /// For a function `pred` for which there is a nonnegative value `x` such that `pred(self.prod(l..r))` is `false` if and only if `x <= r`, `self.partition_point(l, pred)` returns that `x`.
        /// That is, this is the minimum value of `r` at which `pred(self.prod(l..r))` becomes `false`.
        /// If `pred(self.e)` is `false`, then this function assumes that `pred(self.prod(l..r))` is `false` for every `r` in `l..=self.len()` and returns `l`.
        /// However, it's recommended to always choose `pred` so that `pred(self.e)` is `true`, which avoids this special case.
        ///
        /// ## Constraints
        /// - `0 <= l <= self.len()`
        ///
        /// ## Examples
        /// `f(r) := pred(self.prod(l..r))`
        ///
        /// Given that `self.len() == 7`, calling `self.partition_point(0, pred)` returns the values marked below.
        /// ```text
        ///    r |     0     1     2     3     4     5     6     7     8
        ///
        /// f(r) |  true  true  true  true false false false false   N/A
        ///                             returns^
        ///
        /// f(r) | false false false false false false false false   N/A
        ///     returns^
        ///
        /// f(r) |  true  true  true  true  true  true  true  true   N/A
        ///                                                     returns^
        /// ```
        pub fn partition_point(&self, l: usize, pred: impl Fn(T) -> bool) -> usize {
            if !pred(self.e) {
                // `pred(self.prod(l..l))` is `false`
                // Thus l is returned.
                // This case is not covered in the original implementation as it simply requires pred(self.e) to be `true`
                return l;
            }

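            if l == self.n {
                // `pred(self.e)` has already been checked to be `true`,
                // so `pred(self.prod(l..r))` is `true` for every `r` in `l..=self.n`.
                // As in the all-`true` case at the end of this function, `self.n + 1` is returned.
                return self.n + 1;
            }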

            let mut l = l + self.size;
            let mut sm = self.e;

            loop {
                l >>= l.trailing_zeros();
                if !pred((self.op)(sm, self.data[l])) {
                    while l < self.size {
                        l <<= 1;
                        let tmp = (self.op)(sm, self.data[l]);
                        if pred(tmp) {
                            sm = tmp;
                            l += 1;
                        }
                    }
                    return l + 1 - self.size;
                }
                sm = (self.op)(sm, self.data[l]);
                l += 1;
                if l & ((!l) + 1) == l {
                    break;
                }
            }
            self.n + 1
        }

        /// For a function `pred` for which there is a value `x <= r` such that `pred(self.prod(l..r))` is `true` if and only if `x <= l`, `self.left_partition_point(r, pred)` returns that `x`.
        /// That is, this is the minimum value of `l` for which `pred(self.prod(l..r))` is `true`.
        /// If `pred(self.e)` is `false`, then this function assumes that `pred(self.prod(l..r))` is `false` for every `l` in `0..=r` and returns `r + 1`.
        /// However, it's recommended to always choose `pred` so that `pred(self.e)` is `true`, which avoids this special case.
        ///
        /// ## Constraints
        /// - `0 <= r <= self.len()`
        ///
        /// ## Examples
        /// `f(l) := pred(self.prod(l..r))`
        ///
        /// Calling `self.left_partition_point(7, pred)` returns the values marked below.
        /// ```text
        ///    l |     0     1     2     3     4     5     6     7     8
        ///
        /// f(l) | false false false false  true  true  true  true   N/A
        ///                             returns^
        ///
        /// f(l) |  true  true  true  true  true  true  true  true   N/A
        ///     returns^
        ///
        /// f(l) | false false false false false false false false   N/A
        ///                                                     returns^
        /// ```
        pub fn left_partition_point(&self, r: usize, pred: impl Fn(T) -> bool) -> usize {
            if !pred(self.e) {
                return r + 1;
            }

            if r == 0 {
                // `pred(self.e)` is always `true` at this point
                return 0;
            }

            let mut r = r + self.size;
            let mut sm = self.e;

            loop {
                r -= 1;
                while r > 1 && r & 1 == 1 {
                    r >>= 1;
                }
                if !pred((self.op)(self.data[r], sm)) {
                    while r < self.size {
                        r = (r << 1) + 1;
                        let tmp = (self.op)(self.data[r], sm);
                        if pred(tmp) {
                            sm = tmp;
                            r -= 1;
                        }
                    }
                    return r + 1 - self.size;
                }
                sm = (self.op)(self.data[r], sm);
                if r & ((!r) + 1) == r {
                    break;
                }
            }
            0
        }
    }
}
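The monoid constraints of `SegTree::new` are a common source of wrong answers, for instance choosing `0` as the identity of `max` when values can be negative. Below is a hedged sketch of a brute-force spot check over a handful of sample values; `check_monoid` is a hypothetical helper, not part of the snippet, and passing it does not prove the laws, it only catches obvious mistakes.

```rust
/// Hypothetical helper: checks `a*e == e*a == a` and `(a*b)*c == a*(b*c)`
/// for every choice of `a`, `b`, `c` from `samples`.
fn check_monoid<T: Copy + PartialEq>(samples: &[T], e: T, op: impl Fn(T, T) -> T) -> bool {
    for &a in samples {
        // Identity law
        if op(a, e) != a || op(e, a) != a {
            return false;
        }
        for &b in samples {
            for &c in samples {
                // Associativity law
                if op(op(a, b), c) != op(a, op(b, c)) {
                    return false;
                }
            }
        }
    }
    true
}
```

For example, `check_monoid(&[-5i64, 0, 3, 7], 0, |x, y| x.max(y))` returns `false` because `max(-5, 0) != -5`, while `i64::MIN` passes as the identity of `max`.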

Last modified on 231007

Fenwick Tree

Given an integer array \(A\) of length \(n\), a Fenwick tree processes the following queries in \(O(\log{n})\) time:

  • Add a certain amount to an element
  • Calculate the sum of the elements of an interval

A Fenwick tree uses about half the memory of a segment tree (it stores one value per element), with the same \(O(\log{n})\) time complexity per query.
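A hedged sketch of the layout behind the memory figure, assuming 1-indexed positions `i` as in `add` and `inner_sum`: cell `data[i - 1]` stores the sum of the elements at 0-indexed positions `i - lowbit(i)..i`, where `lowbit(i) = i & (!i + 1)` is the lowest set bit of `i`. The helpers `lowbit`, `fenwick_layout`, and `prefix_sum` are hypothetical names for illustration; `fenwick_layout` builds the array straight from this definition rather than through `add`.

```rust
/// Lowest set bit of `i` (requires `i > 0`; `!i + 1` is the two's-complement negation of `i`).
fn lowbit(i: usize) -> usize {
    i & (!i + 1)
}

/// Hypothetical helper: the Fenwick array of `a`, built from the definition.
/// `data[i - 1]` (1-indexed `i`) is the sum of `a[i - lowbit(i)..i]`.
fn fenwick_layout(a: &[i64]) -> Vec<i64> {
    (1..=a.len())
        .map(|i| a[i - lowbit(i)..i].iter().sum::<i64>())
        .collect()
}

/// Sum of the first `r` elements, using the same loop as `inner_sum`:
/// each step adds one cell and clears the lowest set bit of `r`.
fn prefix_sum(data: &[i64], mut r: usize) -> i64 {
    let mut s = 0;
    while r > 0 {
        s += data[r - 1];
        r -= lowbit(r);
    }
    s
}
```

For `a = [3, 1, 4, 1, 5, 9, 2, 6]` the layout is `[3, 4, 4, 9, 5, 14, 2, 31]`, and `prefix_sum(&data, 7)` visits cells 7, 6 and 4 (1-indexed) to return 25.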

The element type of \(A\) must be a primitive signed integer type such as i32 or i64, or a floating-point type such as f64. Unsigned integer types like u64 do not work: the implementation creates its zero value through From<i8>, which unsigned types do not implement.
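A hedged illustration of the From<i8> requirement: the snippet writes its zero value as `0.into()`, which for a generic `T: From<i8>` converts the i8 value 0. The function `zero` is a hypothetical stand-in for that step.

```rust
/// Hypothetical stand-in for the snippet's `0.into()`: the zero of any `T: From<i8>`.
fn zero<T: From<i8>>() -> T {
    T::from(0i8)
}

// `zero::<i32>()`, `zero::<i64>()` and `zero::<f64>()` compile,
// but `zero::<u64>()` does not: `u64` has no `From<i8>` implementation,
// since a negative `i8` cannot be converted to `u64` losslessly.
```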

Example

fn main() {
    let mut fw: Fenwick<i32> = Fenwick::new(10);
    for i in 0..10 {
        print!("{} ", fw.get(i));
    }
    println!(); // 0 0 0 0 0 0 0 0 0 0
    fw.add(2, 10);
    fw.add(5, 100);
    fw.add(3, -1);
    for i in 0..10 {
        print!("{} ", fw.get(i));
    }
    println!(); // 0 0 10 -1 0 100 0 0 0 0
    println!("{}", fw.sum(3..8)); // 99
}

struct Fenwick<T> {
    n: usize,
    data: Vec<T>,
}

impl<T: Copy + From<i8> + std::ops::AddAssign + std::ops::Sub<Output = T>> Fenwick<T> {
    fn new(n: usize) -> Self {
        Self {
            n,
            data: vec![0.into(); n],
        }
    }

    fn add(&mut self, idx: usize, val: T) {
        let mut idx = idx + 1;
        while idx <= self.n {
            self.data[idx - 1] += val;
            idx += idx & (!idx + 1);
        }
    }

    fn get(&self, idx: usize) -> T {
        self.sum(idx..=idx)
    }

    fn sum(&self, range: impl std::ops::RangeBounds<usize>) -> T {
        use std::ops::Bound::*;
        let l = match range.start_bound() {
            Included(&v) => v,
            Excluded(&v) => v + 1,
            Unbounded => 0,
        };
        let r = match range.end_bound() {
            Included(&v) => v + 1,
            Excluded(&v) => v,
            Unbounded => self.n,
        };
        self.inner_sum(r) - self.inner_sum(l)
    }

    fn inner_sum(&self, mut r: usize) -> T {
        let mut s: T = 0.into();
        while r > 0 {
            s += self.data[r - 1];
            r -= r & (!r + 1);
        }
        s
    }
}

Code

struct Fenwick<T> {
	n: usize,
	data: Vec<T>,
}

impl<T: Copy + From<i8> + std::ops::AddAssign + std::ops::Sub<Output = T>> Fenwick<T> {
	fn new(n: usize) -> Self {
		Self { n, data: vec![0.into(); n] }
	}

	fn add(&mut self, idx: usize, val: T) {
		let mut idx = idx + 1;
		while idx <= self.n {
			self.data[idx - 1] += val;
			idx += idx & (!idx + 1);
		}
	}

	fn get(&self, idx: usize) -> T {
		self.sum(idx..=idx)
	}

	fn sum(&self, range: impl std::ops::RangeBounds<usize>) -> T {
		use std::ops::Bound::*;
		let l = match range.start_bound() {
			Included(&v) => v,
			Excluded(&v) => v + 1,
			Unbounded => 0,
		};
		let r = match range.end_bound() {
			Included(&v) => v + 1,
			Excluded(&v) => v,
			Unbounded => self.n,
		};
		self.inner_sum(r) - self.inner_sum(l)
	}

	fn inner_sum(&self, mut r: usize) -> T {
		let mut s: T = 0.into();
		while r > 0 {
			s += self.data[r - 1];
			r -= r & (!r + 1);
		}
		s
	}
}

Lazy Segment Tree

Reference: AtCoder library https://atcoder.github.io/ac-library/production/document_en/index.html

A lazy segment tree is a data structure for a pair of a monoid \( (T, \cdot : T \times T \rightarrow T, e \in T) \) and a set \(F\) of \(T \rightarrow T\) mappings that satisfy the following properties:

  • \(F\) contains the identity mapping \(Id\) such that \( Id(x) = x \) for all \(x\in T\).
  • \(F\) is closed under composition. That is, \( f \circ g \in F \) for all \( f, g \in F \).
  • \( f (x \cdot y) = f(x) \cdot f(y) \) holds for all \(f \in F \) and \( x, y \in T \).
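As a hedged, concrete instance of these properties, consider range addition with range sums. A plain sum does not work on its own: adding \(a\) to a segment increases its sum by \(a\) times the segment's length, so \( f (x \cdot y) = f(x) \cdot f(y) \) fails unless each element also carries its length. The sketch below stores (sum, length) pairs and spot-checks the identity mapping, that `compos` really composes, and the distributive property on sample values. The names `Seg`, `op`, `map`, `compos`, and `laws_hold` are chosen for illustration; they mirror the closures one would pass to `LazySeg::new`.

```rust
/// Monoid element: (sum of the segment, length of the segment).
type Seg = (i64, i64);

/// The monoid operation: concatenate two segments.
fn op((x, l): Seg, (y, m): Seg) -> Seg {
    (x + y, l + m)
}

/// The mapping "add `a` to every element": the sum grows by `a * length`.
fn map(a: i64, (x, l): Seg) -> Seg {
    (x + a * l, l)
}

/// Composition: adding `g` and then `f` is the same as adding `f + g`.
fn compos(f: i64, g: i64) -> i64 {
    f + g
}

/// Spot-checks identity, composition, and `f(x * y) == f(x) * f(y)` on samples (not a proof).
fn laws_hold(samples: &[Seg], maps: &[i64]) -> bool {
    let id = 0;
    samples.iter().all(|&x| map(id, x) == x)
        && maps.iter().all(|&f| {
            maps.iter().all(|&g| {
                samples.iter().all(|&x| {
                    map(compos(f, g), x) == map(f, map(g, x))
                        && samples.iter().all(|&y| map(f, op(x, y)) == op(map(f, x), map(f, y)))
                })
            })
        })
}
```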

Given an array \(A\) of length \(n\) consisting of elements of the monoid \(T\) described above, a lazy segment tree on it can process the following queries in \(O (\log{n})\) time:

  • Apply the mapping \( f \in F \) to all the elements of an interval
  • Calculate the product of the elements of an interval

assuming that computing the product of two elements, applying a mapping, and composing two mappings each take \(O(1)\) time.
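When testing a lazy segment tree, it can help to compare it against a naive version that performs the same two queries in \(O(n)\) time each. The following is a hedged sketch of such a stand-in; the struct `Naive` and its half-open `l..r` arguments are assumptions for illustration (the snippet itself accepts `RangeBounds`), and plain function pointers keep the generics short.

```rust
/// Hypothetical O(n)-per-query stand-in that stores the array explicitly.
struct Naive<T, F> {
    a: Vec<T>,
    e: T,
    op: fn(T, T) -> T,
    map: fn(F, T) -> T,
}

impl<T: Copy, F: Copy> Naive<T, F> {
    /// Replaces every `x` in `a[l..r]` with `map(f, x)`.
    fn apply_range(&mut self, l: usize, r: usize, f: F) {
        let map = self.map;
        for x in &mut self.a[l..r] {
            *x = map(f, *x);
        }
    }

    /// Returns `a[l] * ... * a[r - 1]`, or `e` when `l == r`.
    fn prod(&self, l: usize, r: usize) -> T {
        self.a[l..r].iter().fold(self.e, |acc, &x| (self.op)(acc, x))
    }
}
```

Starting from the (value, 1) pairs for 0..10 with range addition, `nv.prod(2, 8).0` is 27 before `nv.apply_range(3, 6, 3)` and 36 after it.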

Example

use lazyseg::*;

fn main() {
    // Generic range addition, range sum lazy segment tree
    let mut ls = LazySeg::new(
        10,
        (0..10).map(|i| (i, 1)),
        |(x, l), (y, m)| (x + y, l + m),
        (0i64, 0i64),
        |a, (x, l)| (x + a * l, l),
        |a, b| a + b,
        0i64,
    );

    println!("{}", ls.prod(2..8).0); // 27
    ls.apply_range(3..6, 3);
    println!("{}", ls.prod(2..8).0); // 36

    for r in 3..=10 {
        print!("{} ", ls.prod(3..r).0);
    }
    println!(); // 0 6 13 21 27 34 42 51
    println!("{}", ls.partition_point(3, |(x, _)| x < 40)); // 9

    for l in 0..=7 {
        print!("{} ", ls.prod(l..7).0);
    }
    println!(); // 30 30 29 27 21 14 6 0
    println!("{}", ls.left_partition_point(7, |(x, _)| x < 25)); // 4
}

mod lazyseg {
    use std::ops::RangeBounds;

    /// A lazy segment tree is a data structure for a monoid type `T` and a type `F` of mappings, each of which maps a `T` value to another `T` value.
    ///
    /// Given all the constraints written in the comments of `LazySeg::new`, a lazy segment tree can process the following queries in O(T log N) time, assuming `op`, `map`, and `compos` each run in O(T).
    /// - Applying the mapping `f: F` to all the elements in an interval
    /// - Calculating the product of elements in an interval, combined with `op`
    pub struct LazySeg<T, O, F, M, C> {
        n: usize,
        data: Vec<T>,
        lazy: Vec<F>,
        e: T,
        op: O,
        id: F,
        map: M,
        compos: C,
        size: usize,
        log: u32,
    }

    impl<T, O, F, M, C> LazySeg<T, O, F, M, C>
    where
        T: Copy,
        O: Fn(T, T) -> T,
        F: Copy,
        M: Fn(F, T) -> T,
        C: Fn(F, F) -> F,
    {
        fn get_bounds(&self, range: impl RangeBounds<usize>) -> (usize, usize) {
            use std::ops::Bound::*;
            let n = self.len();
            let l = match range.start_bound() {
                Included(&v) => v,
                Excluded(&v) => v + 1,
                Unbounded => 0,
            };
            let r = match range.end_bound() {
                Included(&v) => (v + 1).min(n),
                Excluded(&v) => v.min(n),
                Unbounded => n,
            };
            if l > r {
                return (l, l);
            }
            (l, r)
        }

        fn upd(&mut self, k: usize) {
            self.data[k] = (self.op)(self.data[k * 2], self.data[k * 2 + 1]);
        }

        fn all_apply(&mut self, k: usize, f: F) {
            self.data[k] = (self.map)(f, self.data[k]);
            if k < self.size {
                self.lazy[k] = (self.compos)(f, self.lazy[k]);
            }
        }

        fn push(&mut self, k: usize) {
            self.all_apply(k * 2, self.lazy[k]);
            self.all_apply(k * 2 + 1, self.lazy[k]);
            self.lazy[k] = self.id;
        }

        /// Returns a new lazy segment tree of size `n` built from `iter`.
        ///
        /// The meanings of parameters and some generic types are as follows.
        /// - `T` is the type of values in the array the lazy segment tree represents.
        /// - `F` is the type of mappings for the array.
        /// - `n` is the number of elements in the array.
        /// - `iter` is an iterator returning initial values of the array.
        ///   - If `iter.count() < n`, then the rest is filled with `e`.
        ///   - If `iter.count() > n`, the array is truncated down to the length of `n`.
        /// - `op: impl Fn(T, T) -> T` is a binary operator for `T`.
        /// - `e` is an identity for `op`.
        /// - `map: impl Fn(F, T) -> T` defines how to map `T` to another `T` based on the `F` value.
        /// - `compos: impl Fn(F, F) -> F` defines how to compose two `F`'s.
        /// - `id` defines an identity for `compos`.
        ///
        /// The following notations will be used from now on.
        /// - `op(a, b)` is denoted as `a*b`.
        /// - `map(f, a)` is denoted as `f.a`.
        /// - `map(g, map(f, a))` is denoted as `g.f.a`.
        ///
        /// Constraints of parameters are as follows.
        /// - `op` and `e` must make `T` a monoid. That is, `op` and `e` should be given so that `T` can satisfy the following conditions.
        ///   - `T` is associative under `op`. That is, `(a*b)*c == a*(b*c)` for all `[a, b, c]: [T; 3]`.
        ///   - `T` has `e` as an identity element under `op`. That is, `a*e == e*a == a` for all `a: T`.
        /// - `map`, `compos`, and `id` must satisfy the following conditions.
        ///   - `compos` should be properly defined. That is, if `compos(g, f) == h`, then `g.f.a == h.a` must hold.
        ///   - `id` must be a proper identity for `F` under `compos`. That is, `f.id.a == id.f.a == f.a` for all `a: T` and `f: F`.
        ///   - IMPORTANT: `map` must satisfy `f.(x*y) == f.x * f.y`.
        ///
        /// For example, a generic range addition range sum lazy segment tree with every value initialized with `0` and of length `n` can be constructed as follows.
        /// ```no_run
        /// let mut ls = LazySeg::new(
        ///     n,
        ///     (0..n).map(|_| (0, 1)),
        ///     |(x, l), (y, m)| (x + y, l + m),
        ///     (0i64, 0i64),
        ///     |a, (x, l)| (x + a * l, l),
        ///     |a, b| a + b,
        ///     0i64,
        /// );
        /// ```
        /// A so-called "ax+b" lazy segment tree starting with an array of `vec![0; n]` can be constructed as follows.
        /// ```no_run
        /// let mut ls = LazySeg::new(
        ///     n,
        ///     (0..n).map(|_| (0, 1)),
        ///     |(x, l), (y, m)| (x + y, l + m),
        ///     (0i64, 0i64),
        ///     |(a, b), (x, l)| (a * x + b * l, l),
        ///     |(a, b), (c, d)| (a * c, a * d + b),
        ///     (1i64, 0i64),
        /// );
        /// ```
        pub fn new(n: usize, iter: impl IntoIterator<Item = T>, op: O, e: T, map: M, compos: C, id: F) -> Self {
            let size = n.next_power_of_two();
            let log = size.trailing_zeros();

            let mut data = vec![e; size];
            data.extend(iter.into_iter().take(n));
            data.resize(2 * size, e);

            let mut ls = Self {
                n,
                data,
                lazy: vec![id; size],
                e,
                op,
                id,
                map,
                compos,
                size,
                log,
            };
            for i in (1..size).rev() {
                ls.upd(i);
            }
            ls
        }

        /// Returns the length of the array.
        pub fn len(&self) -> usize {
            self.n
        }

        /// Returns the `i`-th value of the array.
        pub fn get(&mut self, i: usize) -> T {
            let i = i + self.size;
            for j in (1..=self.log).rev() {
                self.push(i >> j);
            }
            self.data[i]
        }

        /// Assign `upd_to(self.get(i))` to the `i`-th element.
        pub fn update(&mut self, i: usize, upd_to: impl Fn(T) -> T) {
            let i = i + self.size;
            for j in (1..=self.log).rev() {
                self.push(i >> j);
            }
            self.data[i] = upd_to(self.data[i]);
            for j in 1..=self.log {
                self.upd(i >> j);
            }
        }

        /// Returns the product of elements in `range`.
        pub fn prod(&mut self, range: impl RangeBounds<usize>) -> T {
            let (l, r) = self.get_bounds(range);

            if l == 0 && r == self.n {
                // The root holds the product of the whole array (padding leaves are `e`)
                return self.data[1];
            } else if l == r {
                return self.e;
            }

            let (mut l, mut r) = (l + self.size, r + self.size);

            for i in (1..=self.log).rev() {
                if ((l >> i) << i) != l {
                    self.push(l >> i);
                }
                if ((r >> i) << i) != r {
                    self.push((r - 1) >> i);
                }
            }

            let (mut sml, mut smr) = (self.e, self.e);
            while l < r {
                if l & 1 == 1 {
                    sml = (self.op)(sml, self.data[l]);
                    l += 1;
                }
                if r & 1 == 1 {
                    r -= 1;
                    smr = (self.op)(self.data[r], smr);
                }
                l >>= 1;
                r >>= 1;
            }
            (self.op)(sml, smr)
        }

        /// Changes every element `x` in `range` of the array to `map.x`.
        pub fn apply_range(&mut self, range: impl RangeBounds<usize>, map: F) {
            let (l, r) = self.get_bounds(range);
            if l >= r {
                return;
            }

            let (mut l, mut r) = (l + self.size, r + self.size);

            for i in (1..=self.log).rev() {
                if ((l >> i) << i) != l {
                    self.push(l >> i);
                }
                if ((r >> i) << i) != r {
                    self.push((r - 1) >> i);
                }
            }

            let (l2, r2) = (l, r);
            while l < r {
                if l & 1 == 1 {
                    self.all_apply(l, map);
                    l += 1;
                }
                if r & 1 == 1 {
                    r -= 1;
                    self.all_apply(r, map);
                }
                l >>= 1;
                r >>= 1;
            }
            l = l2;
            r = r2;

            for i in 1..=self.log {
                if ((l >> i) << i) != l {
                    self.upd(l >> i);
                }
                if ((r >> i) << i) != r {
                    self.upd((r - 1) >> i);
                }
            }
        }

        /// For a function `pred` for which there is a nonnegative value `x` such that `pred(self.prod(l..r))` is `false` if and only if `x <= r`, `self.partition_point(l, pred)` returns that `x`.
        /// That is, this is the minimum value of `r` at which `pred(self.prod(l..r))` becomes `false`.
        /// If `pred(self.e)` is `false`, then this function assumes that `pred(self.prod(l..r))` is `false` for every `r` in `l..=self.len()` and returns `l`.
        /// However, it's recommended to always choose `pred` so that `pred(self.e)` is `true`, which avoids this special case.
        ///
        /// ## Constraints
        /// - `0 <= l <= self.len()`
        ///
        /// ## Examples
        /// `f(r) := pred(self.prod(l..r))`
        ///
        /// Given that `self.len() == 7`, calling `self.partition_point(0, pred)` returns the values marked below.
        /// ```text
        ///    r |     0     1     2     3     4     5     6     7     8
        ///
        /// f(r) |  true  true  true  true false false false false   N/A
        ///                             returns^
        ///
        /// f(r) | false false false false false false false false   N/A
        ///     returns^
        ///
        /// f(r) |  true  true  true  true  true  true  true  true   N/A
        ///                                                     returns^
        /// ```
        pub fn partition_point(&mut self, l: usize, pred: impl Fn(T) -> bool) -> usize {
            if !pred(self.e) {
                return l;
            }

            if l == self.n {
                // `pred(self.e)` has already been checked to be `true`, so every `r` in `l..=self.n` gives `true`.
                return self.n + 1;
            }

            let mut l = l + self.size;
            for i in (1..=self.log).rev() {
                self.push(l >> i);
            }
            let mut sm = self.e;

            loop {
                l >>= l.trailing_zeros();
                if !pred((self.op)(sm, self.data[l])) {
                    while l < self.size {
                        self.push(l);
                        l <<= 1;
                        let tmp = (self.op)(sm, self.data[l]);
                        if pred(tmp) {
                            sm = tmp;
                            l += 1;
                        }
                    }
                    return l + 1 - self.size;
                }
                sm = (self.op)(sm, self.data[l]);
                l += 1;
                if l & ((!l) + 1) == l {
                    break;
                }
            }
            self.n + 1
        }

        /// For a function `pred` for which there is a value `x <= r` such that `pred(self.prod(l..r))` is `true` if and only if `x <= l`, `self.left_partition_point(r, pred)` returns that `x`.
        /// That is, this is the minimum value of `l` for which `pred(self.prod(l..r))` is `true`.
        /// If `pred(self.e)` is `false`, then this function assumes that `pred(self.prod(l..r))` is `false` for every `l` in `0..=r` and returns `r + 1`.
        /// However, it's recommended to always choose `pred` so that `pred(self.e)` is `true`, which avoids this special case.
        ///
        /// ## Constraints
        /// - `0 <= r <= self.len()`
        ///
        /// ## Examples
        /// `f(l) := pred(self.prod(l..r))`
        ///
        /// Calling `self.left_partition_point(7, pred)` returns the values marked below.
        /// ```text
        ///    l |     0     1     2     3     4     5     6     7     8
        ///
        /// f(l) | false false false false  true  true  true  true   N/A
        ///                             returns^
        ///
        /// f(l) |  true  true  true  true  true  true  true  true   N/A
        ///     returns^
        ///
        /// f(l) | false false false false false false false false   N/A
        ///                                                     returns^
        /// ```
        pub fn left_partition_point(&mut self, r: usize, pred: impl Fn(T) -> bool) -> usize {
            if !pred(self.e) {
                return r + 1;
            }

            if r == 0 {
                // `pred(self.e)` has already been checked that it's `true`.
                return 0;
            }

            let mut r = r + self.size;
            for i in (1..=self.log).rev() {
                self.push((r - 1) >> i);
            }

            let mut sm = self.e;
            loop {
                r -= 1;
                while r > 1 && r & 1 == 1 {
                    r >>= 1;
                }

                if !pred((self.op)(self.data[r], sm)) {
                    while r < self.size {
                        self.push(r);
                        r = (r << 1) + 1;
                        let tmp = (self.op)(self.data[r], sm);
                        if pred(tmp) {
                            sm = tmp;
                            r -= 1;
                        }
                    }
                    return r + 1 - self.size;
                }
                sm = (self.op)(self.data[r], sm);
                if r & ((!r) + 1) == r {
                    break;
                }
            }
            0
        }
    }
}
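For the "ax+b" construction described in the comments of `LazySeg::new`, the mapping has to account for segment length: replacing each of the \(l\) elements \(v\) of a segment by \(a v + b\) turns its sum \(x\) into \(a x + b l\). Below is a hedged sketch that spot-checks this mapping and the composition rule on (sum, length) pairs; the type aliases, function names, and `ID` constant are illustrative, and the functions correspond to the `map`, `compos`, and `id` arguments of `LazySeg::new`.

```rust
/// Monoid element: (sum of the segment, length of the segment).
type Seg = (i64, i64);
/// Mapping `(a, b)`: replace every element `v` with `a * v + b`.
type Affine = (i64, i64);

fn op((x, l): Seg, (y, m): Seg) -> Seg {
    (x + y, l + m)
}

/// Each of the `l` elements becomes `a * v + b`, so the sum becomes `a * x + b * l`.
fn map((a, b): Affine, (x, l): Seg) -> Seg {
    (a * x + b * l, l)
}

/// `compos(g, f)` applies `f` first, then `g`: `a * (c * v + d) + b`.
fn compos((a, b): Affine, (c, d): Affine) -> Affine {
    (a * c, a * d + b)
}

/// The identity mapping `v -> 1 * v + 0`.
const ID: Affine = (1, 0);
```

For the segment `[1, 2, 3]`, stored as `(6, 3)`, the mapping `(2, 5)` gives `(27, 3)`, the sum of `[7, 9, 11]`. Leaving out the `b * l` term, for instance, would break the condition `f.(x*y) == f.x * f.y`.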

Code

mod lazyseg {
    use std::ops::RangeBounds;

    /// A lazy segment tree is a data structure for a monoid type `T` and a type `F` of mappings, each of which maps a `T` value to another `T` value.
    ///
    /// Given all the constraints written in the comments of `LazySeg::new`, a lazy segment tree can process the following queries in O(T log N) time, assuming `op`, `map`, and `compos` each run in O(T).
    /// - Applying the mapping `f: F` to all the elements in an interval
    /// - Calculating the product of elements in an interval, combined with `op`
    pub struct LazySeg<T, O, F, M, C> {
        n: usize,
        data: Vec<T>,
        lazy: Vec<F>,
        e: T,
        op: O,
        id: F,
        map: M,
        compos: C,
        size: usize,
        log: u32,
    }

    impl<T, O, F, M, C> LazySeg<T, O, F, M, C>
    where
        T: Copy,
        O: Fn(T, T) -> T,
        F: Copy,
        M: Fn(F, T) -> T,
        C: Fn(F, F) -> F,
    {
        fn get_bounds(&self, range: impl RangeBounds<usize>) -> (usize, usize) {
            use std::ops::Bound::*;
            let n = self.len();
            let l = match range.start_bound() {
                Included(&v) => v,
                Excluded(&v) => v + 1,
                Unbounded => 0,
            };
            let r = match range.end_bound() {
                Included(&v) => (v + 1).min(n),
                Excluded(&v) => v.min(n),
                Unbounded => n,
            };
            if l > r {
                return (l, l);
            }
            (l, r)
        }

        fn upd(&mut self, k: usize) {
            self.data[k] = (self.op)(self.data[k * 2], self.data[k * 2 + 1]);
        }

        fn all_apply(&mut self, k: usize, f: F) {
            self.data[k] = (self.map)(f, self.data[k]);
            if k < self.size {
                self.lazy[k] = (self.compos)(f, self.lazy[k]);
            }
        }

        fn push(&mut self, k: usize) {
            self.all_apply(k * 2, self.lazy[k]);
            self.all_apply(k * 2 + 1, self.lazy[k]);
            self.lazy[k] = self.id;
        }

        /// Returns a new lazy segment tree of size `n` built from `iter`.
        ///
        /// The meanings of parameters and some generic types are as follows.
        /// - `T` is the type of values in the array that the lazy segment tree represents.
        /// - `F` is the type of mappings applied to the array.
        /// - `n` is the number of elements in the array.
        /// - `iter` is an iterator returning initial values of the array.
        ///   - If `iter.count() < n`, then the rest is filled with `e`.
        ///   - If `iter.count() > n`, the array is truncated down to the length of `n`.
        /// - `op: impl Fn(T, T) -> T` is a binary operator for `T`.
        /// - `e` is an identity for `op`.
        /// - `map: impl Fn(F, T) -> T` defines how to map `T` to another `T` based on the `F` value.
        /// - `compos: impl Fn(F, F) -> F` defines how to compose two `F`'s.
        /// - `id` defines an identity for `compos`.
        ///
        /// The following notations will be used from now on.
        /// - `op(a, b)` is denoted as `a*b`.
        /// - `map(f, a)` is denoted as `f.a`.
        /// - `map(g, map(f, a))` is denoted as `g.f.a`.
        ///
        /// Constraints of parameters are as follows.
        /// - `op` and `e` must make `T` a monoid. That is, `op` and `e` should be given so that `T` can satisfy the following conditions.
        ///   - `T` is associative under `op`. That is, `(a*b)*c == a*(b*c)` for all `[a, b, c]: [T; 3]`.
        ///   - `T` has `e` as an identity element under `op`. That is, `a*e == e*a == a` for all `a: T`.
        /// - `map`, `compos`, and `id` must satisfy the following conditions.
        ///   - `compos` should be properly defined. That is, if `compos(g, f) == h`, then `g.f.a == h.a` must hold.
        ///   - `id` must be a proper identity for `F` under `compos`. That is, `f.id.a == id.f.a == f.a` for all `a: T` and `f: F`.
        ///   - IMPORTANT: `map` must satisfy `f.(x*y) == f.x * f.y`.
        ///
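        /// For example, a range-addition, range-sum lazy segment tree of length `n` with every value initialized to `0` can be constructed as follows.
        /// Each element is stored as a `(sum, length)` pair, so adding `a` to a segment increases its sum by `a * length`.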
        /// ```no_run
        /// let mut ls = LazySeg::new(
        ///     n,
        ///     (0..n).map(|_| (0, 1)),
        ///     |(x, l), (y, m)| (x + y, l + m),
        ///     (0i64, 0i64),
        ///     |a, (x, l)| (x + a * l, l),
        ///     |a, b| a + b,
        ///     0i64,
        /// );
        /// ```
        /// An affine ("ax+b") lazy segment tree starting from `vec![0; n]`, where applying `(a, b)` to a range replaces every element `v` in it with `a*v + b`, can be constructed as follows.
        /// ```no_run
        /// let mut ls = LazySeg::new(
        ///     n,
        ///     (0..n).map(|_| (0, 1)),
        ///     |(x, l), (y, m)| (x + y, l + m),
        ///     (0i64, 0i64),
        ///     |(a, b), (x, l)| (a * x + b * l, l),
        ///     |(a, b), (c, d)| (a * c, a * d + b),
        ///     (1i64, 0i64),
        /// );
        /// ```
        pub fn new(n: usize, iter: impl IntoIterator<Item = T>, op: O, e: T, map: M, compos: C, id: F) -> Self {
            let size = n.next_power_of_two();
            let log = size.trailing_zeros();

            let mut data = vec![e; size];
            data.extend(iter.into_iter().take(n));
            data.resize(2 * size, e);

            let mut ls = Self {
                n,
                data,
                lazy: vec![id; size],
                e,
                op,
                id,
                map,
                compos,
                size,
                log,
            };
            for i in (1..size).rev() {
                ls.upd(i);
            }
            ls
        }

        /// Returns the length of the array.
        pub fn len(&self) -> usize {
            self.n
        }

        /// Returns the `i`-th value of the array.
        pub fn get(&mut self, i: usize) -> T {
            let i = i + self.size;
            for j in (1..=self.log).rev() {
                self.push(i >> j);
            }
            self.data[i]
        }

        /// Assigns `upd_to(self.get(i))` to the `i`-th element.
        pub fn update(&mut self, i: usize, upd_to: impl Fn(T) -> T) {
            let i = i + self.size;
            for j in (1..=self.log).rev() {
                self.push(i >> j);
            }
            self.data[i] = upd_to(self.data[i]);
            for j in 1..=self.log {
                self.upd(i >> j);
            }
        }

        /// Returns the product of elements in `range`.
        pub fn prod(&mut self, range: impl RangeBounds<usize>) -> T {
            let (l, r) = self.get_bounds(range);

            if l == 0 && r == self.size {
                let ret = (self.op)(self.e, self.data[1]);
                return ret;
            } else if l == r {
                return self.e;
            }

            let (mut l, mut r) = (l + self.size, r + self.size);

            for i in (1..=self.log).rev() {
                if ((l >> i) << i) != l {
                    self.push(l >> i);
                }
                if ((r >> i) << i) != r {
                    self.push((r - 1) >> i);
                }
            }

            let (mut sml, mut smr) = (self.e, self.e);
            while l < r {
                if l & 1 == 1 {
                    sml = (self.op)(sml, self.data[l]);
                    l += 1;
                }
                if r & 1 == 1 {
                    r -= 1;
                    smr = (self.op)(self.data[r], smr);
                }
                l >>= 1;
                r >>= 1;
            }
            (self.op)(sml, smr)
        }

        /// Changes every element `x` in `range` of the array to `map.x`.
        pub fn apply_range(&mut self, range: impl RangeBounds<usize>, map: F) {
            let (l, r) = self.get_bounds(range);
            if l >= r {
                return;
            }

            let (mut l, mut r) = (l + self.size, r + self.size);

            for i in (1..=self.log).rev() {
                if ((l >> i) << i) != l {
                    self.push(l >> i);
                }
                if ((r >> i) << i) != r {
                    self.push((r - 1) >> i);
                }
            }

            let (l2, r2) = (l, r);
            while l < r {
                if l & 1 == 1 {
                    self.all_apply(l, map);
                    l += 1;
                }
                if r & 1 == 1 {
                    r -= 1;
                    self.all_apply(r, map);
                }
                l >>= 1;
                r >>= 1;
            }
            l = l2;
            r = r2;

            for i in 1..=self.log {
                if ((l >> i) << i) != l {
                    self.upd(l >> i);
                }
                if ((r >> i) << i) != r {
                    self.upd((r - 1) >> i);
                }
            }
        }

        /// For a function `pred` which has a nonnegative value `x`, such that `pred(self.prod(l..r))` is `false` if and only if `x <= r`, `self.partition_point(l, pred)` returns the value of such `x`.
        /// That is, this is the minimum value of `r` such that `pred(self.prod(l..r))` starts to be `false`.
        /// If `pred(self.e)` is `false`, then this function assumes that `pred(self.prod(l..r))` is always `false` for any `r` in range `l..=self.len()` and returns `l`.
        /// However, it's recommended to always make `pred(self.e)` return `true` to avoid unnecessary special cases.
        ///
        /// ## Constraints
        /// - `0 <= l <= self.len()`
        ///
        /// ## Examples
        /// `f(r) := pred(self.prod(l..r))`
        ///
        /// Given that `self.len() == 7`, calling `self.partition_point(0, pred)` returns the values shown below.
        /// ```text
        ///    r |     0     1     2     3     4     5     6     7     8
        ///
        /// f(r) |  true  true  true  true false false false false   N/A
        ///                             returns^
        ///
        /// f(r) | false false false false false false false false   N/A
        ///     returns^
        ///
        /// f(r) |  true  true  true  true  true  true  true  true   N/A
        ///                                                     returns^
        /// ```
        pub fn partition_point(&mut self, l: usize, pred: impl Fn(T) -> bool) -> usize {
            if !pred(self.e) {
                return l;
            }

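            if l == self.n {
                // `pred(self.e)` has already been checked to be `true`, so `pred(self.prod(l..l))` holds
                // and, as in the all-`true` case, the answer is `self.len() + 1`.
                return self.n + 1;
            }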

            let mut l = l + self.size;
            for i in (1..=self.log).rev() {
                self.push(l >> i);
            }
            let mut sm = self.e;

            loop {
                l >>= l.trailing_zeros();
                if !pred((self.op)(sm, self.data[l])) {
                    while l < self.size {
                        self.push(l);
                        l <<= 1;
                        let tmp = (self.op)(sm, self.data[l]);
                        if pred(tmp) {
                            sm = tmp;
                            l += 1;
                        }
                    }
                    return l + 1 - self.size;
                }
                sm = (self.op)(sm, self.data[l]);
                l += 1;
                if l & ((!l) + 1) == l {
                    break;
                }
            }
            self.n + 1
        }

        /// For a function `pred` which has a value `x` less than or equal to `r`, such that `pred(self.prod(l..r))` is `true` if and only if `x <= l`, `self.left_partition_point(r, pred)` returns the value of such `x`.
        /// That is, this is the minimum value of `l` such that `pred(self.prod(l..r))` starts to be `true`.
        /// If `pred(self.e)` is `false`, then this function assumes that `pred(self.prod(l..r))` is always `false` for any `l` in range `0..=r` and returns `r+1`.
        /// However, it's recommended to always make `pred(self.e)` return `true` to avoid unnecessary special cases.
        ///
        /// ## Constraints
        /// - `0 <= r <= self.len()`
        ///
        /// ## Examples
        /// `f(l) := pred(self.prod(l..r))`
        ///
        /// Calling `self.left_partition_point(7, pred)` returns the values shown below.
        /// ```text
        ///    l |     0     1     2     3     4     5     6     7     8
        ///
        /// f(l) | false false false false  true  true  true  true   N/A
        ///                             returns^
        ///
        /// f(l) |  true  true  true  true  true  true  true  true   N/A
        ///     returns^
        ///
        /// f(l) | false false false false false false false false   N/A
        ///                                                     returns^
        /// ```
        pub fn left_partition_point(&mut self, r: usize, pred: impl Fn(T) -> bool) -> usize {
            if !pred(self.e) {
                return r + 1;
            }

            if r == 0 {
                // `pred(self.e)` has already been checked that it's `true`.
                return 0;
            }

            let mut r = r + self.size;
            for i in (1..=self.log).rev() {
                self.push((r - 1) >> i);
            }

            let mut sm = self.e;
            loop {
                r -= 1;
                while r > 1 && r & 1 == 1 {
                    r >>= 1;
                }

                if !pred((self.op)(self.data[r], sm)) {
                    while r < self.size {
                        self.push(r);
                        r = (r << 1) + 1;
                        let tmp = (self.op)(self.data[r], sm);
                        if pred(tmp) {
                            sm = tmp;
                            r -= 1;
                        }
                    }
                    return r + 1 - self.size;
                }
                sm = (self.op)(self.data[r], sm);
                if r & ((!r) + 1) == r {
                    break;
                }
            }
            0
        }
    }
}
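Before plugging a new monoid and map into LazySeg, it can help to brute-force the constraints listed in the comments of `LazySeg::new` on a few small sample values, and to keep naive O(N) versions of `partition_point` and `left_partition_point` around as oracles for stress tests. The sketch below is a standalone illustration, not part of the snippet: `check_laws`, `naive_partition_point`, and `naive_left_partition_point` are hypothetical names, and the naive versions assume `pred` is monotone as the doc comments require. Passing `check_laws` on a handful of samples does not prove the laws in general; it only catches obvious mistakes.

```rust
/// Hypothetical helper: brute-forces the `LazySeg::new` constraints on sample values.
fn check_laws<T, F>(
    ts: &[T],
    fs: &[F],
    op: impl Fn(T, T) -> T,
    e: T,
    map: impl Fn(F, T) -> T,
    compos: impl Fn(F, F) -> F,
    id: F,
) -> bool
where
    T: Copy + PartialEq,
    F: Copy,
{
    for &a in ts {
        // `e` is a two-sided identity for `op`, and `id` maps every value to itself.
        if op(a, e) != a || op(e, a) != a || map(id, a) != a {
            return false;
        }
        for &b in ts {
            for &c in ts {
                // `op` is associative.
                if op(op(a, b), c) != op(a, op(b, c)) {
                    return false;
                }
            }
            for &f in fs {
                // `map` distributes over `op`: f.(a*b) == f.a * f.b
                if map(f, op(a, b)) != op(map(f, a), map(f, b)) {
                    return false;
                }
            }
        }
        for &f in fs {
            // `id` is an identity for `compos`.
            if map(compos(f, id), a) != map(f, a) || map(compos(id, f), a) != map(f, a) {
                return false;
            }
            for &g in fs {
                // `compos(g, f)` must act like applying `f` first, then `g`.
                if map(compos(g, f), a) != map(g, map(f, a)) {
                    return false;
                }
            }
        }
    }
    true
}

/// Hypothetical O(N) oracle for `partition_point`: the smallest `r` in `l..=a.len()`
/// with `!pred(prod(l..r))`, or `a.len() + 1` if there is none.
fn naive_partition_point<T: Copy>(a: &[T], l: usize, op: impl Fn(T, T) -> T, e: T, pred: impl Fn(T) -> bool) -> usize {
    let mut acc = e; // prod(l..l)
    for r in l..=a.len() {
        if !pred(acc) {
            return r;
        }
        if r < a.len() {
            acc = op(acc, a[r]); // now prod(l..r+1)
        }
    }
    a.len() + 1
}

/// Hypothetical O(N) oracle for `left_partition_point`: the smallest `l` in `0..=r`
/// with `pred(prod(l..r))`, or `r + 1` if `pred(e)` is already `false`.
fn naive_left_partition_point<T: Copy>(a: &[T], r: usize, op: impl Fn(T, T) -> T, e: T, pred: impl Fn(T) -> bool) -> usize {
    if !pred(e) {
        return r + 1;
    }
    let (mut acc, mut l) = (e, r); // acc == prod(l..r)
    while l > 0 {
        let next = op(a[l - 1], acc); // prod(l-1..r)
        if !pred(next) {
            return l;
        }
        acc = next;
        l -= 1;
    }
    0
}
```

For instance, summing seven 1s with `pred = |s| s < 4` makes `naive_partition_point(&a, 0, ...)` return 4, matching the first row of the `partition_point` table, and the affine closures from the `LazySeg::new` example pass `check_laws` on small `(sum, length)` samples.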

Last modified on 231007.

Rope

A rope acts like a list, but inserting a value at an arbitrary position takes amortized \( O(\log{N}) \) time. The trade-off is that accessing a value also takes amortized \( O(\log{N}) \) time, rather than \( O(1) \) as with a Vec. Internally it is a splay tree whose nodes store their subtree sizes. Building a rope from an iterator takes \( O(N) \) time.

When accessing elements through an immutable borrow, that is, through rope.get(idx) or immutable indexing such as let v = rope[3];, no splaying happens, so in the worst case a single access can take \(O(N)\). Use get_mut() when performance matters.
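To see where the \(O(N)\) worst case comes from, consider the shape that collect() builds. In the code below, `FromIterator` calls `push_ontop_root`, which makes each new element the root and hangs the previous root as its left child. The result is one long left spine, so reaching index 0 without splaying follows \(N-1\) links. The sketch models only that shape with plain indices; `build_left_spine` and `links_to` are hypothetical names, not part of the snippet.

```rust
/// Hypothetical model of the tree built by `Rope::from_iter`: node `v` holds the
/// `v`-th element, `left[v]` is its left child, and no node has a right child.
fn build_left_spine(n: usize) -> (Vec<Option<usize>>, Option<usize>) {
    let mut left = Vec::with_capacity(n);
    let mut root = None;
    for v in 0..n {
        // Like `push_ontop_root`: the old root becomes the new node's left child.
        left.push(root);
        root = Some(v);
    }
    (left, root)
}

/// Counts the child links followed from the root to reach position `idx` without
/// splaying. In this shape node `v` has exactly `v` nodes in its left subtree,
/// so position `idx` is simply node `idx`.
fn links_to(left: &[Option<usize>], root: Option<usize>, idx: usize) -> usize {
    let mut cur = root.expect("empty tree");
    let mut steps = 0;
    while cur != idx {
        cur = left[cur].expect("idx out of range");
        steps += 1;
    }
    steps
}
```

With 10 elements, the last one sits at the root (0 links) while index 0 needs 9 links; a splaying access would move the visited node to the root and roughly halve the depth of the path it walked.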

Example

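fn main() {
    use rope::Rope;

    // [0, 1, ..., 9]
    let mut arr: Rope<i32> = (0..10).collect();
    println!("{:?}", arr);

    // Move the elements 1..5 to the end: [0, 5, 6, 7, 8, 9, 1, 2, 3, 4]
    let out = arr.take_range(1..5).unwrap();
    arr.merge_right(out);
    println!("{:?}", arr);

    // Repeatedly insert into the middle
    for i in 11..100000 {
        let n = arr.len() / 2;
        arr.insert(n, i);
    }
    println!("{}", arr[50000]);

    // Shrink back to 10 elements by removing just right of the middle
    for _ in 0..arr.len() - 10 {
        let n = arr.len() / 2;
        arr.remove(n + 1);
    }
    println!("{:?}", arr);
}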
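When stress-testing the rope, a slow but obviously correct model built on Vec makes a convenient oracle. The sketch below is hypothetical (`VecRope` is not part of the snippet), takes half-open `l..r` bounds instead of `RangeBounds`, and mirrors only the operations the example uses. Every operation is \(O(N)\), so it is meant for small randomized tests only. Replaying the first half of the example, `take_range(1..5)` followed by `merge_right` moves `[1, 2, 3, 4]` to the end.

```rust
/// Hypothetical Vec-backed model of the Rope operations used in the example.
struct VecRope<T> {
    v: Vec<T>,
}

impl<T> VecRope<T> {
    fn len(&self) -> usize {
        self.v.len()
    }
    fn insert(&mut self, idx: usize, data: T) {
        // Vec::insert panics if idx > len.
        self.v.insert(idx, data);
    }
    fn remove(&mut self, idx: usize) -> Option<T> {
        // Like Rope::remove, an out-of-range index yields None instead of panicking.
        if idx < self.v.len() {
            Some(self.v.remove(idx))
        } else {
            None
        }
    }
    /// Removes and returns the half-open range `l..r`, or None if it is invalid.
    fn take_range(&mut self, l: usize, r: usize) -> Option<VecRope<T>> {
        if l > r || r > self.v.len() {
            return None;
        }
        Some(VecRope { v: self.v.drain(l..r).collect() })
    }
    /// Appends all of `rhs` to the end of `self`.
    fn merge_right(&mut self, mut rhs: VecRope<T>) {
        self.v.append(&mut rhs.v);
    }
}
```

Feeding the same random sequence of operations to both a `Rope` and a `VecRope` and comparing their contents after each step is a cheap way to catch pointer bugs in the splay code.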
mod rope {
    use std::{
        cmp::Ordering,
        fmt::{Debug, Display},
        ops::{Bound::*, Index, IndexMut, RangeBounds},
        ptr::{self, NonNull},
        iter::FromIterator,
    };
    pub struct Node<T> {
        data: T,
        subt: usize,
        l: Link<T>,
        r: Link<T>,
        p: Link<T>,
    }
    type Link<T> = Option<NonNull<Node<T>>>;
    impl<T> Node<T> {
        fn new(data: T) -> Self {
            Node {
                data,
                subt: 1,
                l: None,
                r: None,
                p: None,
            }
        }
        fn left_size(&self) -> usize {
            unsafe { self.l.map_or(0, |l| (*l.as_ptr()).subt) }
        }
        fn right_size(&self) -> usize {
            unsafe { self.r.map_or(0, |r| (*r.as_ptr()).subt) }
        }
        fn upd_subtree(&mut self) {
            self.subt = 1 + self.left_size() + self.right_size();
        }
        // Returns Some((whether x is the left child of its parent, parent)), or None if x has no parent.
        unsafe fn is_left_child(x: NonNull<Self>) -> Option<(bool, NonNull<Self>)> {
            if let Some(p) = (*x.as_ptr()).p {
                if (*p.as_ptr())
                    .l
                    .map_or(false, |pl| ptr::eq(x.as_ptr(), pl.as_ptr()))
                {
                    Some((true, p))
                } else {
                    Some((false, p))
                }
            } else {
                None
            }
        }
    }
    pub struct Rope<T> {
        root: Link<T>,
        size: usize,
    }
    impl<T> Default for Rope<T> {
        fn default() -> Self {
            Self {
                root: None,
                size: 0,
            }
        }
    }
    impl<T> Rope<T> {
        pub fn new() -> Self {
            Self::default()
        }
        pub fn len(&self) -> usize {
            self.size
        }
        pub fn insert(&mut self, idx: usize, data: T) {
            debug_assert!(idx <= self.size);
            unsafe {
                let new_node = NonNull::new_unchecked(Box::into_raw(Box::new(Node::new(data))));
                if let Some(r) = self.root {
                    let idx = self.kth_ptr(idx);
                    if let Some(idx) = idx {
                        // `idx` now points to the node currently at position idx;
                        // new_node must be placed right before it in in-order.
                        if let Some(l) = (*idx.as_ptr()).l {
                            // Find the rightmost node in l's subtree (idx's in-order predecessor)
                            let mut p = l;
                            while let Some(r) = (*p.as_ptr()).r {
                                p = r;
                            }
                            // Attach new_node to the right of p
                            (*new_node.as_ptr()).p = Some(p);
                            (*p.as_ptr()).r = Some(new_node);
                        } else {
                            // Attach it right away
                            let p = idx;
                            (*new_node.as_ptr()).p = Some(p);
                            (*p.as_ptr()).l = Some(new_node);
                        }
                    } else {
                        // idx == self.size
                        // new_node goes to the rightmost of the tree
                        let mut p = r;
                        while let Some(r) = (*p.as_ptr()).r {
                            p = r;
                        }
                        // Attach new_node to the right of p
                        (*new_node.as_ptr()).p = Some(p);
                        (*p.as_ptr()).r = Some(new_node);
                    }
                    let mut c = new_node;
                    while let Some(p) = (*c.as_ptr()).p {
                        c = p;
                        (*c.as_ptr()).upd_subtree();
                    }
                } else {
                    self.root = Some(new_node);
                }
                self.splay(new_node);
                self.size += 1;
            }
        }
        pub fn remove(&mut self, idx: usize) -> Option<T> {
            if idx >= self.size {
                return None;
            }
            let data: T = unsafe {
                if let Some(mut rt) = self.kth_ptr(idx) {
                    rt = self.remove_helper(rt);
                    if let Some(rp) = (*rt.as_ptr()).p {
                        self.splay(rp);
                    }
                    let retr = Box::from_raw(rt.as_ptr());
                    retr.data
                } else {
                    unreachable!()
                }
            };
            self.size -= 1;
            Some(data)
        }
        pub fn push_front(&mut self, data: T) {
            self.insert(0, data);
        }
        pub fn push_back(&mut self, data: T) {
            self.insert(self.size, data);
        }
        pub fn pop_front(&mut self) -> Option<T> {
            self.remove(0)
        }
        pub fn pop_back(&mut self) -> Option<T> {
            // Avoid underflowing `self.size - 1` on an empty rope
            let last = self.size.checked_sub(1)?;
            self.remove(last)
        }
        /// Splits the rope, leaving self[..right_start] in self and returning self[right_start..].
        /// Returns None if right_start > self.len().
        pub fn take_right(&mut self, right_start: usize) -> Option<Self> {
            let rhs = unsafe {
                if right_start == 0 {
                    let rhs = Self {
                        root: self.root,
                        size: self.size,
                    };
                    self.root = None;
                    self.size = 0;
                    rhs
                } else {
                    let root = self.kth_ptr(right_start - 1)?;
                    self.splay(root);
                    if let Some(r) = (*root.as_ptr()).r {
                        (*root.as_ptr()).r = None;
                        (*r.as_ptr()).p = None;
                        (*root.as_ptr()).upd_subtree();
                        self.size = (*root.as_ptr()).subt;
                        Self {
                            root: Some(r),
                            size: (*r.as_ptr()).subt,
                        }
                    } else {
                        Self {
                            root: None,
                            size: 0,
                        }
                    }
                }
            };
            Some(rhs)
        }
        /// Splits the rope into self[..at] and self[at..].
        /// Returns None if at > self.len(); the rope is consumed either way.
        pub fn split_at(mut self, at: usize) -> Option<(Self, Self)> {
            let rhs = self.take_right(at)?;
            Some((self, rhs))
        }
        /// Removes the elements in `range` from the rope and returns them as a new rope.
        /// Returns None if the range starts after it ends or extends past the end of the rope.
        pub fn take_range(&mut self, range: impl RangeBounds<usize>) -> Option<Self> {
            let l = match range.start_bound() {
                Included(&l) => l,
                Excluded(&l) => l + 1,
                Unbounded => 0,
            };
            let r = match range.end_bound() {
                Included(&r) => r + 1,
                Excluded(&r) => r,
                Unbounded => self.size,
            };
            if l > r || l > self.size || r > self.size {
                return None;
            }
            // The range has been validated, so the `?` operators below never return early
            let c = self.take_right(r)?;
            let b = self.take_right(l)?;
            self.merge_right(c);
            Some(b)
        }
        pub fn merge_right(&mut self, mut rhs: Self) {
            if self.len() == 0 {
                self.root = rhs.root;
                self.size = rhs.size;
            } else {
                unsafe {
                    let rmost = self.kth_ptr(self.size - 1).unwrap();
                    self.splay(rmost);
                    (*rmost.as_ptr()).r = rhs.root;
                    if let Some(rhs_root) = rhs.root {
                        (*rhs_root.as_ptr()).p = Some(rmost);
                    }
                    (*rmost.as_ptr()).upd_subtree();
                    self.size = (*rmost.as_ptr()).subt;
                }
            }
            rhs.root = None;
            rhs.size = 0;
        }
        pub fn merge_left(&mut self, mut lhs: Self) {
            if self.len() == 0 {
                self.root = lhs.root;
                self.size = lhs.size;
            } else {
                unsafe {
                    let lmost = self.kth_ptr(0).unwrap();
                    self.splay(lmost);
                    (*lmost.as_ptr()).l = lhs.root;
                    if let Some(lhs_root) = lhs.root {
                        (*lhs_root.as_ptr()).p = Some(lmost);
                    }
                    (*lmost.as_ptr()).upd_subtree();
                    self.size = (*lmost.as_ptr()).subt;
                }
            }
            lhs.root = None;
            lhs.size = 0;
        }
    }
    impl<T: Debug> Debug for Rope<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "[")?;
            let mut cnt: usize = 0;
            unsafe {
                let mut stack: Vec<*mut Node<T>> = Vec::new();
                let mut curr = self.root;
                loop {
                    while let Some(x) = curr {
                        stack.push(x.as_ptr());
                        curr = (*x.as_ptr()).l;
                    }
                    if let Some(x) = stack.pop() {
                        if cnt == 0 {
                            write!(f, "{:?}", (*x).data)?;
                        } else {
                            write!(f, ", {:?}", (*x).data)?;
                        }
                        cnt += 1;
                        curr = (*x).r;
                    } else {
                        break;
                    }
                }
            }
            write!(f, "]")
        }
    }
    impl<T: Display> Display for Rope<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            unsafe {
                let mut stack: Vec<*mut Node<T>> = Vec::new();
                let mut curr = self.root;
                loop {
                    while let Some(x) = curr {
                        stack.push(x.as_ptr());
                        curr = (*x.as_ptr()).l;
                    }
                    if let Some(x) = stack.pop() {
                        write!(f, "{}", (*x).data)?;
                        curr = (*x).r;
                    } else {
                        break;
                    }
                }
            }
            Ok(())
        }
    }
    impl<T> Drop for Rope<T> {
        fn drop(&mut self) {
            if let Some(root) = self.root {
                unsafe {
                    let mut st: Vec<*mut Node<T>> = Vec::new();
                    st.push(root.as_ptr());
                    while let Some(t) = st.pop() {
                        let v = Box::from_raw(t);
                        if let Some(l) = v.l {
                            st.push(l.as_ptr());
                        }
                        if let Some(r) = v.r {
                            st.push(r.as_ptr());
                        }
                        // `v` is dropped here; its child links are raw pointers, so only this node is freed
                    }
                }
            }
        }
    }
    impl<T> Index<usize> for Rope<T> {
        type Output = T;
        fn index(&self, idx: usize) -> &Self::Output {
            unsafe {
                let p = self.kth_ptr(idx);
                &(*p.unwrap().as_ptr()).data
            }
        }
    }
    impl<T> IndexMut<usize> for Rope<T> {
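        fn index_mut(&mut self, idx: usize) -> &mut Self::Output {
            unsafe {
                let p = self.kth_ptr(idx).unwrap();
                // Splay the accessed node so that mutable indexing keeps the amortized O(log N) bound
                self.splay(p);
                &mut (*p.as_ptr()).data
            }
        }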
    }
    impl<T> FromIterator<T> for Rope<T> {
        fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
            let mut arr = Self::new();
            for v in iter {
                unsafe { arr.push_ontop_root(v) };
            }
            arr
        }
    }
    //------------------------
    // Helper implementations
    //------------------------
    impl<T> Rope<T> {
        /// Adds `data` as the new root of the rope, making the original root
        /// the left child of the new root.
        unsafe fn push_ontop_root(&mut self, data: T) {
            let new_node = NonNull::new_unchecked(Box::into_raw(Box::new(Node::new(data))));
            if let Some(root) = self.root {
                (*root.as_ptr()).p = Some(new_node);
                (*new_node.as_ptr()).l = Some(root);
            }
            self.root = Some(new_node);
            (*new_node.as_ptr()).upd_subtree();
            self.size += 1;
        }
        /// Does nothing and returns false if x has no parent.
        /// Otherwise, rotates x above its parent and returns true.
        unsafe fn rotate(&mut self, x: NonNull<Node<T>>) -> bool {
            if let Some((is_x_left, p)) = Node::is_left_child(x) {
                // Check if p is root
                if let Some(root) = self.root {
                    if ptr::eq(root.as_ptr(), p.as_ptr()) {
                        self.root = Some(x);
                    }
                }
                // Connect x to xpp. If pp is None, do nothing.
                (*x.as_ptr()).p = (*p.as_ptr()).p;
                if let Some((is_p_left, pp)) = Node::is_left_child(p) {
                    if is_p_left {
                        (*pp.as_ptr()).l = Some(x);
                    } else {
                        (*pp.as_ptr()).r = Some(x);
                    }
                }
                if is_x_left {
                    let b = (*x.as_ptr()).r;
                    (*x.as_ptr()).r = Some(p);
                    (*p.as_ptr()).p = Some(x);
                    (*p.as_ptr()).l = b;
                    if let Some(b) = b {
                        (*b.as_ptr()).p = Some(p);
                    }
                } else {
                    let b = (*x.as_ptr()).l;
                    (*x.as_ptr()).l = Some(p);
                    (*p.as_ptr()).p = Some(x);
                    (*p.as_ptr()).r = b;
                    if let Some(b) = b {
                        (*b.as_ptr()).p = Some(p);
                    }
                }
                (*p.as_ptr()).upd_subtree();
                (*x.as_ptr()).upd_subtree();
                true
            } else {
                false
            }
        }
        fn splay(&mut self, x: NonNull<Node<T>>) {
            unsafe {
                while let Some(root) = self.root {
                    if ptr::eq(x.as_ptr(), root.as_ptr()) {
                        break;
                    }
                    if let Some((is_x_left, p)) = Node::is_left_child(x) {
                        if ptr::eq(root.as_ptr(), p.as_ptr()) {
                            // If p is root, rotate x once
                            self.rotate(x);
                        } else {
                            // Panics if pp doesn't exist, which happens only when p is root
                            let (is_p_left, _pp) = Node::is_left_child(p).unwrap();
                            if is_x_left == is_p_left {
                                self.rotate(p);
                                self.rotate(x);
                            } else {
                                self.rotate(x);
                                self.rotate(x);
                            }
                        }
                    } else {
                        // x has no parent, which should logically never happen
                        unreachable!()
                    }
                }
            }
        }
        unsafe fn kth_ptr(&self, idx: usize) -> Link<T> {
            if self.size <= idx {
                return None;
            }
            if let Some(r) = self.root {
                let mut rem = idx;
                let mut p = r;
                loop {
                    let lsize = (*p.as_ptr()).left_size();
                    match rem.cmp(&lsize) {
                        Ordering::Less => {
                            p = (*p.as_ptr()).l?;
                        }
                        Ordering::Equal => {
                            break;
                        }
                        Ordering::Greater => {
                            rem -= lsize + 1;
                            p = (*p.as_ptr()).r?;
                        }
                    }
                }
                Some(p)
            } else {
                None
            }
        }
        unsafe fn remove_helper(&mut self, x: NonNull<Node<T>>) -> NonNull<Node<T>> {
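            // Detach the node that will actually be freed and return it: x itself if it has at most
            // one child, otherwise x's in-order predecessor after swapping its data into x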
            match ((*x.as_ptr()).l, ((*x.as_ptr()).r)) {
                (None, None) => {
                    // Reset root if the node itself is root
                    if let Some(root) = self.root {
                        if ptr::eq(root.as_ptr(), x.as_ptr()) {
                            self.root = None;
                        }
                    }
                    // Detach itself from parent
                    if let Some((is_x_left, p)) = Node::is_left_child(x) {
                        if is_x_left {
                            (*p.as_ptr()).l = None;
                        } else {
                            (*p.as_ptr()).r = None;
                        }
                        // Update subtree size
                        let mut p = p;
                        (*p.as_ptr()).upd_subtree();
                        while let Some(pp) = (*p.as_ptr()).p {
                            p = pp;
                            (*p.as_ptr()).upd_subtree();
                        }
                    }
                    x
                }
                (Some(l), None) => {
                    // Reset root if the node itself is a root
                    if let Some(root) = self.root {
                        if ptr::eq(root.as_ptr(), x.as_ptr()) {
                            self.root = Some(l);
                        }
                    }
                    (*l.as_ptr()).p = (*x.as_ptr()).p;
                    if let Some((is_rt_left, p)) = Node::is_left_child(x) {
                        if is_rt_left {
                            (*p.as_ptr()).l = Some(l);
                        } else {
                            (*p.as_ptr()).r = Some(l);
                        }
                    }
                    let mut p = l;
                    while let Some(pp) = (*p.as_ptr()).p {
                        p = pp;
                        (*p.as_ptr()).upd_subtree();
                    }
                    x
                }
                (None, Some(r)) => {
                    // Reset root if the node itself is a root
                    if let Some(root) = self.root {
                        if ptr::eq(root.as_ptr(), x.as_ptr()) {
                            self.root = Some(r);
                        }
                    }
                    (*r.as_ptr()).p = (*x.as_ptr()).p;
                    if let Some((is_rt_left, p)) = Node::is_left_child(x) {
                        if is_rt_left {
                            (*p.as_ptr()).l = Some(r);
                        } else {
                            (*p.as_ptr()).r = Some(r);
                        }
                    }
                    let mut p = r;
                    while let Some(pp) = (*p.as_ptr()).p {
                        p = pp;
                        (*p.as_ptr()).upd_subtree();
                    }
                    x
                }
                (Some(l), Some(_)) => {
                    let mut sw = l;
                    while let Some(sr) = (*sw.as_ptr()).r {
                        sw = sr;
                    }
                    std::mem::swap(&mut (*x.as_ptr()).data, &mut (*sw.as_ptr()).data);
                    sw = self.remove_helper(sw);
                    sw
                }
            }
        }
    }
    //-----------
    // Iterators
    //-----------
    impl<T> Rope<T> {
        pub fn iter(&self) -> Iter<T> {
            Iter::new(self)
        }
        pub fn iter_mut(&mut self) -> IterMut<T> {
            IterMut::new(self)
        }
    }
    pub struct Iter<'a, T> {
        rope: &'a Rope<T>,
        stack: Vec<NonNull<Node<T>>>,
        curr: Link<T>,
    }
    impl<'a, T> Iter<'a, T> {
        fn new(rope: &'a Rope<T>) -> Self {
            let root = rope.root;
            Self {
                rope,
                stack: Vec::new(),
                curr: root,
            }
        }
    }
    impl<'a, T> IntoIterator for &'a Rope<T> {
        type Item = &'a T;
        type IntoIter = Iter<'a, T>;
        fn into_iter(self) -> Self::IntoIter {
            Self::IntoIter::new(self)
        }
    }
    impl<'a, T> Iterator for Iter<'a, T> {
        type Item = &'a T;
        fn next(&mut self) -> Option<Self::Item> {
            unsafe {
                while let Some(x) = self.curr {
                    self.stack.push(x);
                    self.curr = (*x.as_ptr()).l;
                }
                if let Some(x) = self.stack.pop() {
                    self.curr = (*x.as_ptr()).r;
                    Some(&x.as_ref().data)
                } else {
                    None
                }
            }
        }
        fn size_hint(&self) -> (usize, Option<usize>) {
            // The remaining count is not tracked, so only the upper bound is exact-safe
            (0, Some(self.rope.len()))
        }
    }
    pub struct IterMut<'a, T> {
        rope: &'a mut Rope<T>,
        stack: Vec<NonNull<Node<T>>>,
        curr: Link<T>,
    }
    impl<'a, T> IterMut<'a, T> {
        fn new(rope: &'a mut Rope<T>) -> Self {
            let root = rope.root;
            Self {
                rope,
                stack: Vec::new(),
                curr: root,
            }
        }
    }
    impl<'a, T> IntoIterator for &'a mut Rope<T> {
        type Item = &'a mut T;
        type IntoIter = IterMut<'a, T>;
        fn into_iter(self) -> Self::IntoIter {
            Self::IntoIter::new(self)
        }
    }
    impl<'a, T> Iterator for IterMut<'a, T> {
        type Item = &'a mut T;
        fn next(&mut self) -> Option<Self::Item> {
            unsafe {
                while let Some(x) = self.curr {
                    self.stack.push(x);
                    self.curr = (*x.as_ptr()).l;
                }
                if let Some(mut x) = self.stack.pop() {
                    self.curr = (*x.as_ptr()).r;
                    Some(&mut x.as_mut().data)
                } else {
                    None
                }
            }
        }
        fn size_hint(&self) -> (usize, Option<usize>) {
            // The remaining count is not tracked, so only the upper bound is exact-safe
            (0, Some(self.rope.len()))
        }
    }
}

Code

mod rope {
    use std::{
        cmp::Ordering,
        fmt::{Debug, Display},
        ops::{Bound::*, Index, IndexMut, RangeBounds},
        ptr::{self, NonNull},
    };

    pub struct Rope<T> {
        root: Link<T>,
        size: usize,
    }

    impl<T> Default for Rope<T> {
        fn default() -> Self {
            Self {
                root: None,
                size: 0,
            }
        }
    }

    impl<T> Rope<T> {
        pub fn new() -> Self {
            Self::default()
        }

        pub fn len(&self) -> usize {
            self.size
        }

        pub fn clear(&mut self) {
            let drop_tree = Self {
                root: self.root,
                size: self.size,
            };
            drop(drop_tree);
            self.root = None;
            self.size = 0;
        }

        pub fn insert(&mut self, idx: usize, data: T) {
            debug_assert!(idx <= self.size);
            unsafe {
                let new_node = NonNull::new_unchecked(Box::into_raw(Box::new(Node::new(data))));

                if let Some(r) = self.root {
                    let idx = self.kth_ptr(idx);
                    if let Some(idx) = idx {
                        // `idx` now points to the node currently at that position.
                        // new_node must be placed immediately before it in in-order sequence.
                        if let Some(l) = (*idx.as_ptr()).l {
                            // Attach at the right of rightmost node from l
                            let mut p = l;
                            while let Some(r) = (*p.as_ptr()).r {
                                p = r;
                            }
                            // Attach new_node to the right of p
                            (*new_node.as_ptr()).p = Some(p);
                            (*p.as_ptr()).r = Some(new_node);
                        } else {
                            // Attach it right away
                            let p = idx;
                            (*new_node.as_ptr()).p = Some(p);
                            (*p.as_ptr()).l = Some(new_node);
                        }
                    } else {
                        // idx == self.size
                        // new_node goes to the rightmost of the tree
                        let mut p = r;
                        while let Some(r) = (*p.as_ptr()).r {
                            p = r;
                        }
                        // Attach new_node to the right of p
                        (*new_node.as_ptr()).p = Some(p);
                        (*p.as_ptr()).r = Some(new_node);
                    }

                    let mut c = new_node;
                    while let Some(p) = (*c.as_ptr()).p {
                        c = p;
                        (*c.as_ptr()).upd_subtree();
                    }
                } else {
                    self.root = Some(new_node);
                }

                self.splay(new_node);
                self.size += 1;
            }
        }

        pub fn remove(&mut self, idx: usize) -> Option<T> {
            if idx >= self.size {
                return None;
            }

            let data: T = unsafe {
                if let Some(mut rt) = self.kth_ptr(idx) {
                    rt = self.remove_helper(rt);
                    if let Some(rp) = (*rt.as_ptr()).p {
                        self.splay(rp);
                    }
                    let retr = Box::from_raw(rt.as_ptr());
                    retr.data
                } else {
                    unreachable!()
                }
            };

            self.size -= 1;
            Some(data)
        }

        pub fn push_front(&mut self, data: T) {
            self.insert(0, data);
        }

        pub fn push_back(&mut self, data: T) {
            self.insert(self.size, data);
        }

        pub fn pop_front(&mut self) -> Option<T> {
            self.remove(0)
        }

        pub fn pop_back(&mut self) -> Option<T> {
            // Guard against `self.size - 1` underflowing on an empty rope
            let last = self.size.checked_sub(1)?;
            self.remove(last)
        }

        /// Splits the rope, leaving self[..right_start] and returning self[right_start..].
        /// Returns None if right_start > self.len().
        pub fn take_right(&mut self, right_start: usize) -> Option<Self> {
            let rhs = unsafe {
                if right_start == 0 {
                    let rhs = Self {
                        root: self.root,
                        size: self.size,
                    };
                    self.root = None;
                    self.size = 0;
                    rhs
                } else {
                    let root = self.kth_ptr(right_start - 1)?;
                    self.splay(root);
                    if let Some(r) = (*root.as_ptr()).r {
                        (*root.as_ptr()).r = None;
                        (*r.as_ptr()).p = None;
                        (*root.as_ptr()).upd_subtree();
                        self.size = (*root.as_ptr()).subt;
                        Self {
                            root: Some(r),
                            size: (*r.as_ptr()).subt,
                        }
                    } else {
                        Self {
                            root: None,
                            size: 0,
                        }
                    }
                }
            };
            Some(rhs)
        }

        /// Splits the rope and returns self[..at] and self[at..].
        /// Returns None if at > self.len(); self is consumed (and dropped) in that case.
        pub fn split_at(mut self, at: usize) -> Option<(Self, Self)> {
            let rhs = self.take_right(at)?;
            Some((self, rhs))
        }

        /// Takes out the range from the rope.
        /// Returns None if the index is invalid.
        pub fn take_range(&mut self, range: impl RangeBounds<usize>) -> Option<Self> {
            let l = match range.start_bound() {
                Included(&l) => l,
                Excluded(&l) => l + 1,
                Unbounded => 0,
            };
            let r = match range.end_bound() {
                Included(&r) => r + 1,
                Excluded(&r) => r,
                Unbounded => self.size,
            };

            if l > r || l > self.size || r > self.size {
                return None;
            }
            // Now the operations below never end early
            let c = self.take_right(r)?;
            let b = self.take_right(l)?;
            self.merge_right(c);
            Some(b)
        }

        pub fn merge_right(&mut self, mut rhs: Self) {
            if self.len() == 0 {
                self.root = rhs.root;
                self.size = rhs.size;
            } else {
                unsafe {
                    let rmost = self.kth_ptr(self.size - 1).unwrap();
                    self.splay(rmost);
                    (*rmost.as_ptr()).r = rhs.root;
                    if let Some(rhs_root) = rhs.root {
                        (*rhs_root.as_ptr()).p = Some(rmost);
                    }
                    (*rmost.as_ptr()).upd_subtree();
                    self.size = (*rmost.as_ptr()).subt;
                }
            }
            rhs.root = None;
            rhs.size = 0;
        }

        pub fn merge_left(&mut self, mut lhs: Self) {
            if self.len() == 0 {
                self.root = lhs.root;
                self.size = lhs.size;
            } else {
                unsafe {
                    let lmost = self.kth_ptr(0).unwrap();
                    self.splay(lmost);
                    (*lmost.as_ptr()).l = lhs.root;
                    if let Some(lhs_root) = lhs.root {
                        (*lhs_root.as_ptr()).p = Some(lmost);
                    }
                    (*lmost.as_ptr()).upd_subtree();
                    self.size = (*lmost.as_ptr()).subt;
                }
            }
            lhs.root = None;
            lhs.size = 0;
        }

        /// Inserts `rope` into self at index `at`.
        /// After the operation, rope[0] becomes self[at].
        /// Returns false if the specified index is invalid, true otherwise.
        pub fn insert_rope(&mut self, rope: Self, at: usize) -> bool {
            let rhs = self.take_right(at);
            if let Some(rhs) = rhs {
                self.merge_right(rope);
                self.merge_right(rhs);
                true
            } else {
                false
            }
        }
    }

    impl<T: Clone> Clone for Rope<T> {
        fn clone(&self) -> Self {
            self.iter().cloned().collect()
        }
    }

    impl<T: Debug> Debug for Rope<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "[")?;
            let mut cnt: usize = 0;
            unsafe {
                let mut stack: Vec<*mut Node<T>> = Vec::new();
                let mut curr = self.root;
                loop {
                    while let Some(x) = curr {
                        stack.push(x.as_ptr());
                        curr = (*x.as_ptr()).l;
                    }
                    if let Some(x) = stack.pop() {
                        if cnt == 0 {
                            write!(f, "{:?}", (*x).data)?;
                        } else {
                            write!(f, ", {:?}", (*x).data)?;
                        }
                        cnt += 1;
                        curr = (*x).r;
                    } else {
                        break;
                    }
                }
            }
            write!(f, "]")
        }
    }

    impl<T: Display> Display for Rope<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            unsafe {
                let mut stack: Vec<*mut Node<T>> = Vec::new();
                let mut curr = self.root;
                loop {
                    while let Some(x) = curr {
                        stack.push(x.as_ptr());
                        curr = (*x.as_ptr()).l;
                    }
                    if let Some(x) = stack.pop() {
                        write!(f, "{}", (*x).data)?;
                        curr = (*x).r;
                    } else {
                        break;
                    }
                }
            }
            Ok(())
        }
    }

    impl<T> Drop for Rope<T> {
        fn drop(&mut self) {
            if let Some(root) = self.root {
                unsafe {
                    let mut st: Vec<*mut Node<T>> = Vec::new();
                    st.push(root.as_ptr());
                    while let Some(t) = st.pop() {
                        let v = Box::from_raw(t);
                        if let Some(l) = v.l {
                            st.push(l.as_ptr());
                        }
                        if let Some(r) = v.r {
                            st.push(r.as_ptr());
                        }
                        drop(v);
                    }
                }
            }
        }
    }

    impl<T> Index<usize> for Rope<T> {
        type Output = T;
        fn index(&self, idx: usize) -> &Self::Output {
            unsafe {
                let p = self.kth_ptr(idx).unwrap();
                &(*p.as_ptr()).data
            }
        }
    }

    impl<T> IndexMut<usize> for Rope<T> {
        fn index_mut(&mut self, idx: usize) -> &mut Self::Output {
            unsafe {
                let p = self.kth_ptr(idx).unwrap();
                self.splay(p);
                &mut (*p.as_ptr()).data
            }
        }
    }

    impl<T> FromIterator<T> for Rope<T> {
        fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
            let mut arr = Self::new();
            for v in iter {
                unsafe { arr.push_ontop_root(v) };
            }
            arr
        }
    }

    impl<T> Rope<T> {
        pub fn iter(&self) -> Iter<T> {
            Iter::new(self)
        }

        pub fn iter_mut(&mut self) -> IterMut<T> {
            IterMut::new(self)
        }
    }

    pub struct Iter<'a, T> {
        rope: &'a Rope<T>,
        stack: Vec<NonNull<Node<T>>>,
        curr: Link<T>,
    }

    impl<'a, T> Iter<'a, T> {
        fn new(rope: &'a Rope<T>) -> Self {
            let root = rope.root;
            Self {
                rope,
                stack: Vec::new(),
                curr: root,
            }
        }
    }

    impl<'a, T> IntoIterator for &'a Rope<T> {
        type Item = &'a T;
        type IntoIter = Iter<'a, T>;
        fn into_iter(self) -> Self::IntoIter {
            Self::IntoIter::new(self)
        }
    }

    impl<'a, T> Iterator for Iter<'a, T> {
        type Item = &'a T;
        fn next(&mut self) -> Option<Self::Item> {
            unsafe {
                while let Some(x) = self.curr {
                    self.stack.push(x);
                    self.curr = (*x.as_ptr()).l;
                }
                if let Some(x) = self.stack.pop() {
                    self.curr = (*x.as_ptr()).r;
                    Some(&x.as_ref().data)
                } else {
                    None
                }
            }
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            // The remaining count is not tracked, so only the upper bound is exact-safe
            (0, Some(self.rope.len()))
        }
    }

    pub struct IterMut<'a, T> {
        rope: &'a mut Rope<T>,
        stack: Vec<NonNull<Node<T>>>,
        curr: Link<T>,
    }

    impl<'a, T> IterMut<'a, T> {
        fn new(rope: &'a mut Rope<T>) -> Self {
            let root = rope.root;
            Self {
                rope,
                stack: Vec::new(),
                curr: root,
            }
        }
    }

    impl<'a, T> IntoIterator for &'a mut Rope<T> {
        type Item = &'a mut T;
        type IntoIter = IterMut<'a, T>;
        fn into_iter(self) -> Self::IntoIter {
            Self::IntoIter::new(self)
        }
    }

    impl<'a, T> Iterator for IterMut<'a, T> {
        type Item = &'a mut T;
        fn next(&mut self) -> Option<Self::Item> {
            unsafe {
                while let Some(x) = self.curr {
                    self.stack.push(x);
                    self.curr = (*x.as_ptr()).l;
                }
                if let Some(mut x) = self.stack.pop() {
                    self.curr = (*x.as_ptr()).r;
                    Some(&mut x.as_mut().data)
                } else {
                    None
                }
            }
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            // The remaining count is not tracked, so only the upper bound is exact-safe
            (0, Some(self.rope.len()))
        }
    }

    //------------------------
    // Helper implementations
    //------------------------

    struct Node<T> {
        data: T,
        subt: usize,
        l: Link<T>,
        r: Link<T>,
        p: Link<T>,
    }

    type Link<T> = Option<NonNull<Node<T>>>;

    impl<T> Node<T> {
        fn new(data: T) -> Self {
            Node {
                data,
                subt: 1,
                l: None,
                r: None,
                p: None,
            }
        }
        fn left_size(&self) -> usize {
            unsafe { self.l.map_or(0, |l| (*l.as_ptr()).subt) }
        }
        fn right_size(&self) -> usize {
            unsafe { self.r.map_or(0, |r| (*r.as_ptr()).subt) }
        }
        fn upd_subtree(&mut self) {
            self.subt = 1 + self.left_size() + self.right_size();
        }

        // Option<(is_left, parent)>
        unsafe fn is_left_child(x: NonNull<Self>) -> Option<(bool, NonNull<Self>)> {
            if let Some(p) = (*x.as_ptr()).p {
                if (*p.as_ptr())
                    .l
                    .map_or(false, |pl| ptr::eq(x.as_ptr(), pl.as_ptr()))
                {
                    Some((true, p))
                } else {
                    Some((false, p))
                }
            } else {
                None
            }
        }
    }

    impl<T> Rope<T> {
        /// Adds data as the new root of the rope, putting the original root
        /// as the left child of the new root.
        unsafe fn push_ontop_root(&mut self, data: T) {
            let new_node = NonNull::new_unchecked(Box::into_raw(Box::new(Node::new(data))));
            if let Some(root) = self.root {
                (*root.as_ptr()).p = Some(new_node);
                (*new_node.as_ptr()).l = Some(root);
            }
            self.root = Some(new_node);
            (*new_node.as_ptr()).upd_subtree();
            self.size += 1;
        }

        /// Returns false and does nothing if x has no parent.
        /// Otherwise rotates x above its parent and returns true.
        unsafe fn rotate(&mut self, x: NonNull<Node<T>>) -> bool {
            if let Some((is_x_left, p)) = Node::is_left_child(x) {
                // Check if p is root
                if let Some(root) = self.root {
                    if ptr::eq(root.as_ptr(), p.as_ptr()) {
                        self.root = Some(x);
                    }
                }

                // Connect x to pp, the parent of p. If p has no parent, x is left without one.
                (*x.as_ptr()).p = (*p.as_ptr()).p;
                if let Some((is_p_left, pp)) = Node::is_left_child(p) {
                    if is_p_left {
                        (*pp.as_ptr()).l = Some(x);
                    } else {
                        (*pp.as_ptr()).r = Some(x);
                    }
                }

                if is_x_left {
                    let b = (*x.as_ptr()).r;
                    (*x.as_ptr()).r = Some(p);
                    (*p.as_ptr()).p = Some(x);
                    (*p.as_ptr()).l = b;
                    if let Some(b) = b {
                        (*b.as_ptr()).p = Some(p);
                    }
                } else {
                    let b = (*x.as_ptr()).l;
                    (*x.as_ptr()).l = Some(p);
                    (*p.as_ptr()).p = Some(x);
                    (*p.as_ptr()).r = b;
                    if let Some(b) = b {
                        (*b.as_ptr()).p = Some(p);
                    }
                }

                (*p.as_ptr()).upd_subtree();
                (*x.as_ptr()).upd_subtree();
                true
            } else {
                false
            }
        }

        fn splay(&mut self, x: NonNull<Node<T>>) {
            unsafe {
                while let Some(root) = self.root {
                    if ptr::eq(x.as_ptr(), root.as_ptr()) {
                        break;
                    }

                    if let Some((is_x_left, p)) = Node::is_left_child(x) {
                        if ptr::eq(root.as_ptr(), p.as_ptr()) {
                            // If p is root, rotate x once
                            self.rotate(x);
                        } else {
                            // Panics if pp doesn't exist, which happens only when p is root
                            let (is_p_left, _pp) = Node::is_left_child(p).unwrap();
                            if is_x_left == is_p_left {
                                self.rotate(p);
                                self.rotate(x);
                            } else {
                                self.rotate(x);
                                self.rotate(x);
                            }
                        }
                    } else {
                        // x has no parent, which should logically never happen
                        unreachable!()
                    }
                }
            }
        }

        unsafe fn kth_ptr(&self, idx: usize) -> Link<T> {
            if self.size <= idx {
                return None;
            }
            if let Some(r) = self.root {
                let mut rem = idx;
                let mut p = r;
                loop {
                    let lsize = (*p.as_ptr()).left_size();
                    match rem.cmp(&lsize) {
                        Ordering::Less => {
                            p = (*p.as_ptr()).l?;
                        }
                        Ordering::Equal => {
                            break;
                        }
                        Ordering::Greater => {
                            rem -= lsize + 1;
                            p = (*p.as_ptr()).r?;
                        }
                    }
                }
                Some(p)
            } else {
                None
            }
        }

        unsafe fn remove_helper(&mut self, x: NonNull<Node<T>>) -> NonNull<Node<T>> {
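            // Detaches the node that is physically removed from the tree and returns it.
            // If x has two children, x swaps data with its in-order predecessor, which is removed instead.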
            match ((*x.as_ptr()).l, ((*x.as_ptr()).r)) {
                (None, None) => {
                    // Reset root if the node itself is root
                    if let Some(root) = self.root {
                        if ptr::eq(root.as_ptr(), x.as_ptr()) {
                            self.root = None;
                        }
                    }
                    // Detach itself from parent
                    if let Some((is_x_left, p)) = Node::is_left_child(x) {
                        if is_x_left {
                            (*p.as_ptr()).l = None;
                        } else {
                            (*p.as_ptr()).r = None;
                        }
                        // Update subtree size
                        let mut p = p;
                        (*p.as_ptr()).upd_subtree();
                        while let Some(pp) = (*p.as_ptr()).p {
                            p = pp;
                            (*p.as_ptr()).upd_subtree();
                        }
                    }
                    x
                }
                (Some(l), None) => {
                    // Reset root if the node itself is a root
                    if let Some(root) = self.root {
                        if ptr::eq(root.as_ptr(), x.as_ptr()) {
                            self.root = Some(l);
                        }
                    }

                    (*l.as_ptr()).p = (*x.as_ptr()).p;
                    if let Some((is_rt_left, p)) = Node::is_left_child(x) {
                        if is_rt_left {
                            (*p.as_ptr()).l = Some(l);
                        } else {
                            (*p.as_ptr()).r = Some(l);
                        }
                    }

                    let mut p = l;
                    while let Some(pp) = (*p.as_ptr()).p {
                        p = pp;
                        (*p.as_ptr()).upd_subtree();
                    }
                    x
                }
                (None, Some(r)) => {
                    // Reset root if the node itself is a root
                    if let Some(root) = self.root {
                        if ptr::eq(root.as_ptr(), x.as_ptr()) {
                            self.root = Some(r);
                        }
                    }

                    (*r.as_ptr()).p = (*x.as_ptr()).p;
                    if let Some((is_rt_left, p)) = Node::is_left_child(x) {
                        if is_rt_left {
                            (*p.as_ptr()).l = Some(r);
                        } else {
                            (*p.as_ptr()).r = Some(r);
                        }
                    }

                    let mut p = r;
                    while let Some(pp) = (*p.as_ptr()).p {
                        p = pp;
                        (*p.as_ptr()).upd_subtree();
                    }
                    x
                }
                (Some(l), Some(_)) => {
                    let mut sw = l;
                    while let Some(sr) = (*sw.as_ptr()).r {
                        sw = sr;
                    }
                    std::mem::swap(&mut (*x.as_ptr()).data, &mut (*sw.as_ptr()).data);
                    sw = self.remove_helper(sw);
                    sw
                }
            }
        }
    }
}
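To make the index conventions of take_right, take_range and insert_rope concrete, here is a minimal brute-force model built on Vec<T>. It is a sketch that assumes it will only be used as a reference when stress-testing Rope against random operations; the free functions are hypothetical and merely reuse the Rope method names. insert_rope follows the same take-right-then-merge sequence as the Rope version.

```rust
use std::ops::{Bound, RangeBounds};

// Hypothetical Vec-based reference model of the Rope API (not part of the snippet).

/// Like Rope::take_right: keeps v[..at] and returns v[at..], or None if at > len.
fn take_right<T>(v: &mut Vec<T>, at: usize) -> Option<Vec<T>> {
    if at > v.len() {
        return None;
    }
    Some(v.split_off(at))
}

/// Like Rope::take_range: removes and returns the range, or None if it is invalid.
fn take_range<T>(v: &mut Vec<T>, range: impl RangeBounds<usize>) -> Option<Vec<T>> {
    let l = match range.start_bound() {
        Bound::Included(&l) => l,
        Bound::Excluded(&l) => l + 1,
        Bound::Unbounded => 0,
    };
    let r = match range.end_bound() {
        Bound::Included(&r) => r + 1,
        Bound::Excluded(&r) => r,
        Bound::Unbounded => v.len(),
    };
    if l > r || r > v.len() {
        return None;
    }
    Some(v.drain(l..r).collect())
}

/// Like Rope::insert_rope: afterwards rope[0] sits at v[at]. False if at > len.
fn insert_rope<T>(v: &mut Vec<T>, rope: Vec<T>, at: usize) -> bool {
    if at > v.len() {
        return false;
    }
    // Same shape as the Rope version: split, append the new part, append the tail
    let tail = v.split_off(at);
    v.extend(rope);
    v.extend(tail);
    true
}

fn main() {
    let mut v: Vec<i32> = (0..10).collect();
    let mid = take_range(&mut v, 3..6).unwrap();
    assert_eq!(mid, vec![3, 4, 5]);
    assert_eq!(v, vec![0, 1, 2, 6, 7, 8, 9]);
    // Re-insert the removed part so that mid[0] lands at index 1
    assert!(insert_rope(&mut v, mid, 1));
    assert_eq!(v, vec![0, 3, 4, 5, 1, 2, 6, 7, 8, 9]);
    assert_eq!(take_right(&mut v, 8), Some(vec![8, 9]));
    // An index past the end is rejected, as in Rope::take_right
    assert_eq!(take_right(&mut v, 100), None);
}
```

Running the same random sequence of operations on a Rope and on this model, then comparing rope.iter() with the Vec, is one simple way to check the splay-tree code.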

Bitset

BitSet is equivalent to a fixed-size array of booleans, with each boolean stored as a single bit of a u64.

To help auto-vectorization, the u64 words are grouped into [u64; 4] blocks so that each u64 in a block can act as one "SIMD lane".
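The bit_get, bit_set and bit_reset implementations locate a bit in three steps: first the [u64; 4] block, then the u64 lane inside it, then the bit inside that lane. The sketch below reproduces that arithmetic with hypothetical free functions (locate, get, set) over a plain slice, assuming the same constants as the module (64-bit elements, 4 lanes, 256 bits per block).

```rust
// Same layout constants as the bitset module (assumed unchanged)
const ELEM_BIT: usize = 64; // bits per u64
const ELEM_LEN: usize = 4; // u64 lanes per block
const BITS_PER_WORD: usize = ELEM_BIT * ELEM_LEN; // 256 bits per [u64; 4]

/// Hypothetical helper: returns (block, lane, bit) for a bit index.
fn locate(idx: usize) -> (usize, usize, usize) {
    let block = idx / BITS_PER_WORD;
    let lane = (idx % BITS_PER_WORD) / ELEM_BIT;
    let bit = idx % ELEM_BIT;
    (block, lane, bit)
}

fn get(bs: &[[u64; ELEM_LEN]], idx: usize) -> bool {
    let (block, lane, bit) = locate(idx);
    bs[block][lane] & (1u64 << bit) != 0
}

fn set(bs: &mut [[u64; ELEM_LEN]], idx: usize) {
    let (block, lane, bit) = locate(idx);
    bs[block][lane] |= 1u64 << bit;
}

fn main() {
    assert_eq!(locate(300), (1, 0, 44));
    assert_eq!(locate(1000), (3, 3, 40));
    let mut bs = [[0u64; ELEM_LEN]; 4]; // 4 blocks = 1024 bits
    set(&mut bs, 1000);
    assert!(get(&bs, 1000));
    assert!(!get(&bs, 999));
    assert_eq!(bs[3][3], 1u64 << 40);
}
```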

As this snippet is purely for PS and CP, it omits many checks that would normally be necessary, such as verifying that two bitsets passed to a function have the same length. For any other purpose, I highly recommend the bitset_core crate, which heavily inspired this snippet.
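If you do need the length check that the snippet skips, a minimal sketch of a checked, lane-wise OR over the same [[u64; 4]] layout could look like the following. checked_or is a hypothetical name and not part of the module; whether the inner lane loop is actually vectorized depends on the compiler and target.

```rust
/// Hypothetical checked OR: lhs |= rhs, block by block and lane by lane.
/// Panics if the two bitsets have a different number of [u64; 4] blocks.
fn checked_or(lhs: &mut [[u64; 4]], rhs: &[[u64; 4]]) {
    assert_eq!(lhs.len(), rhs.len(), "bitset lengths differ");
    for (a, b) in lhs.iter_mut().zip(rhs.iter()) {
        // Fixed-size inner loop over the 4 lanes of one block
        for (x, y) in a.iter_mut().zip(b.iter()) {
            *x |= *y;
        }
    }
}

fn main() {
    let mut a = [[0b0101u64, 0, 0, 0]; 2];
    let b = [[0b0011u64, 0, 0, 1]; 2];
    checked_or(&mut a, &b);
    assert_eq!(a[1], [0b0111, 0, 0, 1]);
}
```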

Usage

A slice of [u64; 4] implements the BitSetOps trait and is therefore treated as a bitset; arrays of [u64; 4] can call the same methods through unsized coercion. The number of booleans packed into the bitset can be found with fn bit_len(&self) -> usize.

Refer to the APIs for further information.
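Because every [u64; 4] block holds 256 bits, bit_len is always rounded up to a multiple of 256. The example sizes its array as (MAX_VAL + 256) / 256 blocks so that indices 0..=MAX_VAL fit, and then subtracts the padding bits past MAX_VAL from the final count, since bit_init(true) sets them as well. A small sketch of that arithmetic, using hypothetical free functions in place of the BitSetOps methods:

```rust
// Hypothetical free-function stand-ins for bit_len and bit_count_ones.
const BITS_PER_WORD: usize = 64 * 4; // 256 bits per [u64; 4] block

/// Blocks needed so that bit indices 0..=n are all valid.
fn blocks_for(n: usize) -> usize {
    (n + BITS_PER_WORD) / BITS_PER_WORD
}

/// Total number of bits, always a multiple of 256.
fn bit_len(bs: &[[u64; 4]]) -> usize {
    bs.len() * BITS_PER_WORD
}

/// Number of set bits across all blocks and lanes.
fn count_ones(bs: &[[u64; 4]]) -> usize {
    bs.iter().flatten().map(|w| w.count_ones() as usize).sum()
}

fn main() {
    let n = 1_000_000;
    // Every bit set, like bit_init(true)
    let bs = vec![[u64::MAX; 4]; blocks_for(n)];
    assert_eq!(bs.len(), 3907);
    assert_eq!(bit_len(&bs), 1_000_192);
    // Bits n+1.. are padding but are still set by the fill above
    let padding = bit_len(&bs) - (n + 1);
    assert_eq!(padding, 191);
    assert_eq!(count_ones(&bs) - padding, n + 1);
}
```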

Example

use bitset::*;

fn main() {
    const MAX_VAL: usize = 1000000;

    let mut is_prime = [[0u64; 4]; (MAX_VAL + 256) / 256];
    println!("{}", is_prime.bit_len()); // 1000192

    is_prime.bit_init(true);
    is_prime.bit_reset(0);
    is_prime.bit_reset(1);

    for i in (2..=MAX_VAL).take_while(|&i| i * i <= MAX_VAL) {
        if is_prime.bit_get(i) {
            for j in (i * i..=MAX_VAL).step_by(i) {
                is_prime.bit_reset(j);
            }
        }
    }

    println!(
        "{}",
        is_prime.bit_count_ones() - (is_prime.bit_len() - (MAX_VAL + 1))
    ); // 78498
}

mod bitset {
    /* Copyright (c) 2020 Casper <CasualX@users.noreply.github.com>
     * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
     * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
     * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
     */

    //! This module is purely for PS and CP. It therefore skips safety checks such as verifying
    //! that self.len() and rhs.len() are equal, and it may panic if a shift overflows the whole
    //! bitset.

    // DO NOT CHANGE THESE VALUES
    // The bitset is not fully generalized to other element types or lane counts.
    type ElemTy = u64;
    const ELEM_BIT: usize = ElemTy::BITS as usize;
    const ELEM_LEN: usize = 4;
    const BITS_PER_WORD: usize = ELEM_BIT * ELEM_LEN;

    pub type BitSet = [[ElemTy; ELEM_LEN]];

    pub trait BitSetOps {
        fn bit_len(&self) -> usize;
        fn bit_init(&mut self, val: bool) -> &mut Self;

        fn bit_get(&self, idx: usize) -> bool;
        fn bit_set(&mut self, idx: usize) -> &mut Self;
        fn bit_reset(&mut self, idx: usize) -> &mut Self;
        fn bit_flip(&mut self, idx: usize) -> &mut Self;
        fn bit_manip(&mut self, idx: usize, val: bool) -> &mut Self;

        fn bit_all(&self) -> bool;
        fn bit_any(&self) -> bool;
        #[inline]
        fn bit_none(&self) -> bool {
            !self.bit_any()
        }

        fn bit_eq(&self, rhs: &Self) -> bool;
        fn bit_disjoint(&self, rhs: &Self) -> bool;
        fn bit_subset(&self, rhs: &Self) -> bool;
        #[inline]
        fn bit_superset(&self, rhs: &Self) -> bool {
            rhs.bit_subset(self)
        }

        fn bit_or(&mut self, rhs: &Self) -> &mut Self;
        fn bit_and(&mut self, rhs: &Self) -> &mut Self;
        fn bit_nand(&mut self, rhs: &Self) -> &mut Self;
        fn bit_xor(&mut self, rhs: &Self) -> &mut Self;
        fn bit_not(&mut self) -> &mut Self;
        fn bit_mask(&mut self, rhs: &Self, mask: &Self) -> &mut Self;

        fn bit_shr(&mut self, by: usize) -> &mut Self;
        fn bit_shl(&mut self, by: usize) -> &mut Self;

        fn bit_count_ones(&self) -> usize;
        #[inline]
        fn bit_count_zeros(&self) -> usize {
            self.bit_len() - self.bit_count_ones()
        }

        #[inline]
        fn bit_fmt(&self) -> &BitFmt<Self> {
            unsafe { &*(self as *const _ as *const _) }
        }
    }

    impl BitSetOps for BitSet {
        #[inline]
        fn bit_len(&self) -> usize {
            self.len() * BITS_PER_WORD
        }

        #[inline]
        fn bit_init(&mut self, val: bool) -> &mut Self {
            let val = [ElemTy::wrapping_add(!(val as ElemTy), 1); ELEM_LEN];
            for i in 0..self.len() {
                self[i] = val;
            }
            self
        }

        #[inline]
        fn bit_get(&self, idx: usize) -> bool {
            let block = idx / BITS_PER_WORD;
            let lane = (idx % BITS_PER_WORD) / ELEM_BIT;
            let bit = idx % ELEM_BIT;
            self[block][lane] & (1 << bit) != 0
        }

        #[inline]
        fn bit_set(&mut self, idx: usize) -> &mut Self {
            let block = idx / BITS_PER_WORD;
            let lane = (idx % BITS_PER_WORD) / ELEM_BIT;
            let bit = idx % ELEM_BIT;
            self[block][lane] |= 1 << bit;
            self
        }

        #[inline]
        fn bit_reset(&mut self, idx: usize) -> &mut Self {
            let block = idx / BITS_PER_WORD;
            let lane = (idx % BITS_PER_WORD) / ELEM_BIT;
            let bit = idx % ELEM_BIT;
            self[block][lane] &= !(1 << bit);
            self
        }

        #[inline]
        fn bit_flip(&mut self, idx: usize) -> &mut Self {
            let block = idx / BITS_PER_WORD;
            let lane = (idx % BITS_PER_WORD) / ELEM_BIT;
            let bit = idx % ELEM_BIT;
            self[block][lane] ^= 1 << bit;
            self
        }

        #[inline]
        fn bit_manip(&mut self, idx: usize, val: bool) -> &mut Self {
            let block = idx / BITS_PER_WORD;
            let lane = (idx % BITS_PER_WORD) / ELEM_BIT;
            let bit = idx % ELEM_BIT;
            let mask = 1 << bit;
            self[block][lane] =
                (self[block][lane] & !mask) | (ElemTy::wrapping_add(!(val as ElemTy), 1) & mask);
            self
        }

        #[inline]
        fn bit_all(&self) -> bool {
            self.iter()
                .all(|block| block[0] == !0 && block[1] == !0 && block[2] == !0 && block[3] == !0)
        }

        #[inline]
        fn bit_any(&self) -> bool {
            // true as soon as any lane of any block is nonzero
            self.iter()
                .any(|block| block[0] != 0 || block[1] != 0 || block[2] != 0 || block[3] != 0)
        }

        #[inline]
        fn bit_eq(&self, rhs: &Self) -> bool {
            self.iter().zip(rhs.iter()).all(|(&lblk, &rblk)| {
                lblk[0] == rblk[0] && lblk[1] == rblk[1] && lblk[2] == rblk[2] && lblk[3] == rblk[3]
            })
        }

        #[inline]
        fn bit_disjoint(&self, rhs: &Self) -> bool {
            self.iter().zip(rhs.iter()).all(|(&lblk, &rblk)| {
                lblk[0] & rblk[0] == 0
                    && lblk[1] & rblk[1] == 0
                    && lblk[2] & rblk[2] == 0
                    && lblk[3] & rblk[3] == 0
            })
        }

        #[inline]
        fn bit_subset(&self, rhs: &Self) -> bool {
            self.iter().zip(rhs.iter()).all(|(&lblk, &rblk)| {
                rblk[0] == rblk[0] | lblk[0]
                    && rblk[1] == rblk[1] | lblk[1]
                    && rblk[2] == rblk[2] | lblk[2]
                    && rblk[3] == rblk[3] | lblk[3]
            })
        }

        #[inline]
        fn bit_or(&mut self, rhs: &Self) -> &mut Self {
            for i in 0..self.len() {
                self[i][0] |= rhs[i][0];
                self[i][1] |= rhs[i][1];
                self[i][2] |= rhs[i][2];
                self[i][3] |= rhs[i][3];
            }
            self
        }

        #[inline]
        fn bit_and(&mut self, rhs: &Self) -> &mut Self {
            for i in 0..self.len() {
                self[i][0] &= rhs[i][0];
                self[i][1] &= rhs[i][1];
                self[i][2] &= rhs[i][2];
                self[i][3] &= rhs[i][3];
            }
            self
        }

        #[inline]
        fn bit_nand(&mut self, rhs: &Self) -> &mut Self {
            for i in 0..self.len() {
                self[i][0] &= !rhs[i][0];
                self[i][1] &= !rhs[i][1];
                self[i][2] &= !rhs[i][2];
                self[i][3] &= !rhs[i][3];
            }
            self
        }

        #[inline]
        fn bit_xor(&mut self, rhs: &Self) -> &mut Self {
            for i in 0..self.len() {
                self[i][0] ^= rhs[i][0];
                self[i][1] ^= rhs[i][1];
                self[i][2] ^= rhs[i][2];
                self[i][3] ^= rhs[i][3];
            }
            self
        }

        #[inline]
        fn bit_not(&mut self) -> &mut Self {
            for i in 0..self.len() {
                self[i][0] = !self[i][0];
                self[i][1] = !self[i][1];
                self[i][2] = !self[i][2];
                self[i][3] = !self[i][3];
            }
            self
        }

        #[inline]
        fn bit_mask(&mut self, rhs: &Self, mask: &Self) -> &mut Self {
            for i in 0..self.len() {
                self[i][0] = self[i][0] & !mask[i][0] | rhs[i][0] & mask[i][0];
                self[i][1] = self[i][1] & !mask[i][1] | rhs[i][1] & mask[i][1];
                self[i][2] = self[i][2] & !mask[i][2] | rhs[i][2] & mask[i][2];
                self[i][3] = self[i][3] & !mask[i][3] | rhs[i][3] & mask[i][3];
            }
            self
        }

        #[inline]
        fn bit_shl(&mut self, by: usize) -> &mut Self {
            let elem_move = by / ELEM_BIT;
            let bit_move = by % ELEM_BIT;

            let (_, slice, _): (_, &mut [ElemTy], _) = unsafe { self.align_to_mut() };
            slice.copy_within(..slice.len() - elem_move, elem_move);
            slice[..elem_move].fill(0);

            if bit_move != 0 {
                let mut carry: ElemTy = 0;
                let mut tmp: [ElemTy; ELEM_LEN] = [0; ELEM_LEN];
                for i in 0..self.len() {
                    tmp[0] = self[i][0] >> (ELEM_BIT - bit_move);
                    tmp[1] = self[i][1] >> (ELEM_BIT - bit_move);
                    tmp[2] = self[i][2] >> (ELEM_BIT - bit_move);
                    tmp[3] = self[i][3] >> (ELEM_BIT - bit_move);
                    self[i][0] <<= bit_move;
                    self[i][1] <<= bit_move;
                    self[i][2] <<= bit_move;
                    self[i][3] <<= bit_move;
                    let tmpc = tmp[ELEM_LEN - 1];
                    tmp.copy_within(..ELEM_LEN - 1, 1);
                    tmp[0] = carry;
                    self[i][0] |= tmp[0];
                    self[i][1] |= tmp[1];
                    self[i][2] |= tmp[2];
                    self[i][3] |= tmp[3];
                    carry = tmpc;
                }
            }

            self
        }

        #[inline]
        fn bit_shr(&mut self, by: usize) -> &mut Self {
            let elem_move = by / ELEM_BIT;
            let bit_move = by % ELEM_BIT;

            let (_, slice, _): (_, &mut [ElemTy], _) = unsafe { self.align_to_mut() };
            slice.copy_within(elem_move.., 0);
            let sl = slice.len();
            slice[sl - elem_move..].fill(0);

            if bit_move != 0 {
                let mut carry: ElemTy = 0;
                let mut tmp: [ElemTy; ELEM_LEN] = [0; ELEM_LEN];
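                // Walk from the last block down so that the low bits of block i + 1
                // (saved in `carry`) land in the top lane of block i.
                for i in (0..self.len()).rev() {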
                    tmp[0] = self[i][0] << (ELEM_BIT - bit_move);
                    tmp[1] = self[i][1] << (ELEM_BIT - bit_move);
                    tmp[2] = self[i][2] << (ELEM_BIT - bit_move);
                    tmp[3] = self[i][3] << (ELEM_BIT - bit_move);
                    self[i][0] >>= bit_move;
                    self[i][1] >>= bit_move;
                    self[i][2] >>= bit_move;
                    self[i][3] >>= bit_move;
                    let tmpc = tmp[0];
                    tmp.copy_within(1.., 0);
                    tmp[ELEM_LEN - 1] = carry;
                    carry = tmpc;
                    self[i][0] |= tmp[0];
                    self[i][1] |= tmp[1];
                    self[i][2] |= tmp[2];
                    self[i][3] |= tmp[3];
                }
            }

            self
        }

        #[inline]
        fn bit_count_ones(&self) -> usize {
            self.iter()
                .map(|chunk| {
                    chunk[0].count_ones() as usize
                        + chunk[1].count_ones() as usize
                        + chunk[2].count_ones() as usize
                        + chunk[3].count_ones() as usize
                })
                .sum()
        }
    }

    mod fmt {
        use super::BitSetOps as BitSet;
        use std::fmt;

        #[repr(transparent)]
        pub struct BitFmt<T: ?Sized>(T);

        fn bitstring<T: ?Sized + BitSet>(this: &T, f: &mut fmt::Formatter) -> fmt::Result {
            const ALPHABET: [u8; 2] = [b'0', b'1'];
            let mut buf = [0u8; 9];
            let mut first = true;
            buf[0] = b'_';
            let mut i = 0;
            while i < this.bit_len() {
                buf[1] = ALPHABET[this.bit_get(i + 0) as usize];
                buf[2] = ALPHABET[this.bit_get(i + 1) as usize];
                buf[3] = ALPHABET[this.bit_get(i + 2) as usize];
                buf[4] = ALPHABET[this.bit_get(i + 3) as usize];
                buf[5] = ALPHABET[this.bit_get(i + 4) as usize];
                buf[6] = ALPHABET[this.bit_get(i + 5) as usize];
                buf[7] = ALPHABET[this.bit_get(i + 6) as usize];
                buf[8] = ALPHABET[this.bit_get(i + 7) as usize];
                let s = unsafe { &*((&buf[first as usize..]) as *const _ as *const str) };
                f.write_str(s)?;
                i += 8;
                first = false;
            }
            Ok(())
        }

        impl<T: ?Sized + BitSet> fmt::Display for BitFmt<T> {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                bitstring(&self.0, f)
            }
        }
    }
    pub use self::fmt::BitFmt;
}

Code

mod bitset {
    /* Copyright (c) 2020 Casper <CasualX@users.noreply.github.com>
     * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
     * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
     * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
     */

    //! This module is purely for PS and CP. Thus it skips safety checks such as verifying that
    //! self.len() and rhs.len() are equal, and it may panic if a shift amount overflows the whole
    //! bitset.

    // DO NOT CHANGE THESE VALUES
    // The full generalization for bitset is not done.
    type ElemTy = u64;
    const ELEM_BIT: usize = ElemTy::BITS as usize;
    const ELEM_LEN: usize = 4;
    const BITS_PER_WORD: usize = ELEM_BIT * ELEM_LEN;

    pub type BitSet = [[ElemTy; ELEM_LEN]];

    pub trait BitSetOps {
        fn bit_len(&self) -> usize;
        fn bit_init(&mut self, val: bool) -> &mut Self;

        fn bit_get(&self, idx: usize) -> bool;
        fn bit_set(&mut self, idx: usize) -> &mut Self;
        fn bit_reset(&mut self, idx: usize) -> &mut Self;
        fn bit_flip(&mut self, idx: usize) -> &mut Self;
        fn bit_manip(&mut self, idx: usize, val: bool) -> &mut Self;

        fn bit_all(&self) -> bool;
        fn bit_any(&self) -> bool;
        #[inline]
        fn bit_none(&self) -> bool {
            !self.bit_any()
        }

        fn bit_eq(&self, rhs: &Self) -> bool;
        fn bit_disjoint(&self, rhs: &Self) -> bool;
        fn bit_subset(&self, rhs: &Self) -> bool;
        #[inline]
        fn bit_superset(&self, rhs: &Self) -> bool {
            rhs.bit_subset(self)
        }

        fn bit_or(&mut self, rhs: &Self) -> &mut Self;
        fn bit_and(&mut self, rhs: &Self) -> &mut Self;
        fn bit_nand(&mut self, rhs: &Self) -> &mut Self;
        fn bit_xor(&mut self, rhs: &Self) -> &mut Self;
        fn bit_not(&mut self) -> &mut Self;
        fn bit_mask(&mut self, rhs: &Self, mask: &Self) -> &mut Self;

        fn bit_shr(&mut self, by: usize) -> &mut Self;
        fn bit_shl(&mut self, by: usize) -> &mut Self;

        fn bit_count_ones(&self) -> usize;
        #[inline]
        fn bit_count_zeros(&self) -> usize {
            self.bit_len() - self.bit_count_ones()
        }

        #[inline]
        fn bit_fmt(&self) -> &BitFmt<Self> {
            unsafe { &*(self as *const _ as *const _) }
        }
    }

    impl BitSetOps for BitSet {
        #[inline]
        fn bit_len(&self) -> usize {
            self.len() * BITS_PER_WORD
        }

        #[inline]
        fn bit_init(&mut self, val: bool) -> &mut Self {
            // !(val as u64) + 1 (wrapping) is all ones when val is true and 0 when val is false
            let val = [ElemTy::wrapping_add(!(val as ElemTy), 1); ELEM_LEN];
            for i in 0..self.len() {
                self[i] = val;
            }
            self
        }

        #[inline]
        fn bit_get(&self, idx: usize) -> bool {
            let block = idx / BITS_PER_WORD;
            let lane = (idx % BITS_PER_WORD) / ELEM_BIT;
            let bit = idx % ELEM_BIT;
            self[block][lane] & (1 << bit) != 0
        }

        #[inline]
        fn bit_set(&mut self, idx: usize) -> &mut Self {
            let block = idx / BITS_PER_WORD;
            let lane = (idx % BITS_PER_WORD) / ELEM_BIT;
            let bit = idx % ELEM_BIT;
            self[block][lane] |= 1 << bit;
            self
        }

        #[inline]
        fn bit_reset(&mut self, idx: usize) -> &mut Self {
            let block = idx / BITS_PER_WORD;
            let lane = (idx % BITS_PER_WORD) / ELEM_BIT;
            let bit = idx % ELEM_BIT;
            self[block][lane] &= !(1 << bit);
            self
        }

        #[inline]
        fn bit_flip(&mut self, idx: usize) -> &mut Self {
            let block = idx / BITS_PER_WORD;
            let lane = (idx % BITS_PER_WORD) / ELEM_BIT;
            let bit = idx % ELEM_BIT;
            self[block][lane] ^= 1 << bit;
            self
        }

        #[inline]
        fn bit_manip(&mut self, idx: usize, val: bool) -> &mut Self {
            let block = idx / BITS_PER_WORD;
            let lane = (idx % BITS_PER_WORD) / ELEM_BIT;
            let bit = idx % ELEM_BIT;
            let mask = 1 << bit;
            self[block][lane] =
                (self[block][lane] & !mask) | (ElemTy::wrapping_add(!(val as ElemTy), 1) & mask);
            self
        }

        #[inline]
        fn bit_all(&self) -> bool {
            self.iter()
                .all(|block| block[0] == !0 && block[1] == !0 && block[2] == !0 && block[3] == !0)
        }

        #[inline]
        fn bit_any(&self) -> bool {
            // true as soon as any lane of any block is nonzero
            self.iter()
                .any(|block| block[0] != 0 || block[1] != 0 || block[2] != 0 || block[3] != 0)
        }

        #[inline]
        fn bit_eq(&self, rhs: &Self) -> bool {
            self.iter().zip(rhs.iter()).all(|(&lblk, &rblk)| {
                lblk[0] == rblk[0] && lblk[1] == rblk[1] && lblk[2] == rblk[2] && lblk[3] == rblk[3]
            })
        }

        #[inline]
        fn bit_disjoint(&self, rhs: &Self) -> bool {
            self.iter().zip(rhs.iter()).all(|(&lblk, &rblk)| {
                lblk[0] & rblk[0] == 0
                    && lblk[1] & rblk[1] == 0
                    && lblk[2] & rblk[2] == 0
                    && lblk[3] & rblk[3] == 0
            })
        }

        /// Returns if self is a subset of rhs
        #[inline]
        fn bit_subset(&self, rhs: &Self) -> bool {
            self.iter().zip(rhs.iter()).all(|(&lblk, &rblk)| {
                rblk[0] == rblk[0] | lblk[0]
                    && rblk[1] == rblk[1] | lblk[1]
                    && rblk[2] == rblk[2] | lblk[2]
                    && rblk[3] == rblk[3] | lblk[3]
            })
        }

        #[inline]
        fn bit_or(&mut self, rhs: &Self) -> &mut Self {
            for i in 0..self.len() {
                self[i][0] |= rhs[i][0];
                self[i][1] |= rhs[i][1];
                self[i][2] |= rhs[i][2];
                self[i][3] |= rhs[i][3];
            }
            self
        }

        #[inline]
        fn bit_and(&mut self, rhs: &Self) -> &mut Self {
            for i in 0..self.len() {
                self[i][0] &= rhs[i][0];
                self[i][1] &= rhs[i][1];
                self[i][2] &= rhs[i][2];
                self[i][3] &= rhs[i][3];
            }
            self
        }

        #[inline]
        fn bit_nand(&mut self, rhs: &Self) -> &mut Self {
            for i in 0..self.len() {
                self[i][0] &= !rhs[i][0];
                self[i][1] &= !rhs[i][1];
                self[i][2] &= !rhs[i][2];
                self[i][3] &= !rhs[i][3];
            }
            self
        }

        #[inline]
        fn bit_xor(&mut self, rhs: &Self) -> &mut Self {
            for i in 0..self.len() {
                self[i][0] ^= rhs[i][0];
                self[i][1] ^= rhs[i][1];
                self[i][2] ^= rhs[i][2];
                self[i][3] ^= rhs[i][3];
            }
            self
        }

        #[inline]
        fn bit_not(&mut self) -> &mut Self {
            for i in 0..self.len() {
                self[i][0] = !self[i][0];
                self[i][1] = !self[i][1];
                self[i][2] = !self[i][2];
                self[i][3] = !self[i][3];
            }
            self
        }

        #[inline]
        fn bit_mask(&mut self, rhs: &Self, mask: &Self) -> &mut Self {
            for i in 0..self.len() {
                self[i][0] = self[i][0] & !mask[i][0] | rhs[i][0] & mask[i][0];
                self[i][1] = self[i][1] & !mask[i][1] | rhs[i][1] & mask[i][1];
                self[i][2] = self[i][2] & !mask[i][2] | rhs[i][2] & mask[i][2];
                self[i][3] = self[i][3] & !mask[i][3] | rhs[i][3] & mask[i][3];
            }
            self
        }

        #[inline]
        fn bit_shl(&mut self, by: usize) -> &mut Self {
            let elem_move = by / ELEM_BIT;
            let bit_move = by % ELEM_BIT;

            let (_, slice, _): (_, &mut [ElemTy], _) = unsafe { self.align_to_mut() };
            slice.copy_within(..slice.len() - elem_move, elem_move);
            slice[..elem_move].fill(0);

            if bit_move != 0 {
                let mut carry: ElemTy = 0;
                let mut tmp: [ElemTy; ELEM_LEN] = [0; ELEM_LEN];
                for i in 0..self.len() {
                    tmp[0] = self[i][0] >> (ELEM_BIT - bit_move);
                    tmp[1] = self[i][1] >> (ELEM_BIT - bit_move);
                    tmp[2] = self[i][2] >> (ELEM_BIT - bit_move);
                    tmp[3] = self[i][3] >> (ELEM_BIT - bit_move);
                    self[i][0] <<= bit_move;
                    self[i][1] <<= bit_move;
                    self[i][2] <<= bit_move;
                    self[i][3] <<= bit_move;
                    let tmpc = tmp[ELEM_LEN - 1];
                    tmp.copy_within(..ELEM_LEN - 1, 1);
                    tmp[0] = carry;
                    self[i][0] |= tmp[0];
                    self[i][1] |= tmp[1];
                    self[i][2] |= tmp[2];
                    self[i][3] |= tmp[3];
                    carry = tmpc;
                }
            }

            self
        }

        #[inline]
        fn bit_shr(&mut self, by: usize) -> &mut Self {
            let elem_move = by / ELEM_BIT;
            let bit_move = by % ELEM_BIT;

            let (_, slice, _): (_, &mut [ElemTy], _) = unsafe { self.align_to_mut() };
            slice.copy_within(elem_move.., 0);
            let sl = slice.len();
            slice[sl - elem_move..].fill(0);

            if bit_move != 0 {
                let mut carry: ElemTy = 0;
                let mut tmp: [ElemTy; ELEM_LEN] = [0; ELEM_LEN];
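                // Walk from the last block down so that the low bits of block i + 1
                // (saved in `carry`) land in the top lane of block i.
                for i in (0..self.len()).rev() {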
                    tmp[0] = self[i][0] << (ELEM_BIT - bit_move);
                    tmp[1] = self[i][1] << (ELEM_BIT - bit_move);
                    tmp[2] = self[i][2] << (ELEM_BIT - bit_move);
                    tmp[3] = self[i][3] << (ELEM_BIT - bit_move);
                    self[i][0] >>= bit_move;
                    self[i][1] >>= bit_move;
                    self[i][2] >>= bit_move;
                    self[i][3] >>= bit_move;
                    let tmpc = tmp[0];
                    tmp.copy_within(1.., 0);
                    tmp[ELEM_LEN - 1] = carry;
                    carry = tmpc;
                    self[i][0] |= tmp[0];
                    self[i][1] |= tmp[1];
                    self[i][2] |= tmp[2];
                    self[i][3] |= tmp[3];
                }
            }

            self
        }

        #[inline]
        fn bit_count_ones(&self) -> usize {
            self.iter()
                .map(|chunk| {
                    chunk[0].count_ones() as usize
                        + chunk[1].count_ones() as usize
                        + chunk[2].count_ones() as usize
                        + chunk[3].count_ones() as usize
                })
                .sum()
        }
    }

    mod fmt {
        use super::BitSetOps as BitSet;
        use std::fmt;

        #[repr(transparent)]
        pub struct BitFmt<T: ?Sized>(T);

        fn bitstring<T: ?Sized + BitSet>(this: &T, f: &mut fmt::Formatter) -> fmt::Result {
            const ALPHABET: [u8; 2] = [b'0', b'1'];
            let mut buf = [0u8; 9];
            let mut first = true;
            buf[0] = b'_';
            let mut i = 0;
            while i < this.bit_len() {
                buf[1] = ALPHABET[this.bit_get(i + 0) as usize];
                buf[2] = ALPHABET[this.bit_get(i + 1) as usize];
                buf[3] = ALPHABET[this.bit_get(i + 2) as usize];
                buf[4] = ALPHABET[this.bit_get(i + 3) as usize];
                buf[5] = ALPHABET[this.bit_get(i + 4) as usize];
                buf[6] = ALPHABET[this.bit_get(i + 5) as usize];
                buf[7] = ALPHABET[this.bit_get(i + 6) as usize];
                buf[8] = ALPHABET[this.bit_get(i + 7) as usize];
                let s = unsafe { &*((&buf[first as usize..]) as *const _ as *const str) };
                f.write_str(s)?;
                i += 8;
                first = false;
            }
            Ok(())
        }

        impl<T: ?Sized + BitSet> fmt::Display for BitFmt<T> {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                bitstring(&self.0, f)
            }
        }
    }
    pub use self::fmt::BitFmt;
}

SIMD, Auto-vectorization and rustc Optimization Level

To get the full benefit of aggressive SIMD optimization, the code should be compiled with opt-level 3. At opt-level 2 many loops are still vectorized, but noticeably fewer than at opt-level 3.

Since most online judges compile Rust code with opt-level 2, getting the full power of SIMD there requires embedding precompiled machine code into the Rust source. Doing this by hand is virtually impossible in actual PS/CP, so using a third-party tool such as basm-rs is highly recommended.

APIs

The behavior of APIs that take multiple bitsets as arguments is unspecified when the bitsets have different lengths.

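  • pub type BitSet

    BitSet is defined as [[u64; 4]], a slice of 256-bit blocks, so the number of bits is always a multiple of 256. Because this is a dynamically sized type, you cannot declare a local variable of type BitSet directly. Instead, declare an array or a Vec of [u64; 4] blocks, as in the example below; with BitSetOps in scope, its methods can then be called on that variable directly.

    let mut bitset = vec![[0u64; 4]; 4]; // 4 blocks = 1024 bits
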
  • fn bit_len(&self) -> usize

    Returns the number of boolean values included in the set.

  • fn bit_init(&mut self, val: bool) -> &mut Self

    Initializes every boolean value of self as val, and returns &mut self back.

  • fn bit_get(&self, idx: usize) -> bool

    Returns the idxth boolean value of self.

  • fn bit_set(&mut self, idx: usize) -> &mut Self

    Sets the idxth boolean value to true, and returns &mut self back.

  • fn bit_reset(&mut self, idx: usize) -> &mut Self

    Sets the idxth boolean value to false, and returns &mut self back.

  • fn bit_flip(&mut self, idx: usize) -> &mut Self

    Flips the idxth boolean value, and returns &mut self back.

  • fn bit_manip(&mut self, idx: usize, val: bool) -> &mut Self

    Sets the idxth boolean value to val, and returns &mut self back.

  • fn bit_all(&self) -> bool

    Returns true if every boolean value of self is true. Otherwise, returns false.

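  • fn bit_any(&self) -> bool

    Returns true if at least one boolean value of self is true. Otherwise, returns false.

  • fn bit_none(&self) -> bool

    Returns true if every boolean value of self is false. Otherwise, returns false.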

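  • fn bit_eq(&self, rhs: &Self) -> bool

    Returns true if self and rhs hold exactly the same boolean values. Otherwise, returns false.

  • fn bit_disjoint(&self, rhs: &Self) -> bool

    Returns true if no bit is set in both self and rhs, i.e. self & rhs is empty. Otherwise, returns false.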

  • fn bit_subset(&self, rhs: &Self) -> bool

    Returns true if self is a subset of rhs. Otherwise, returns false.

  • fn bit_superset(&self, rhs: &Self) -> bool

    Returns true if self is a superset of rhs. Otherwise, returns false.

  • fn bit_or(&mut self, rhs: &Self) -> &mut Self

    Sets self as self | rhs, and returns &mut self back.

  • fn bit_and(&mut self, rhs: &Self) -> &mut Self

    Sets self as self & rhs, and returns &mut self back.

  • fn bit_nand(&mut self, rhs: &Self) -> &mut Self

    Sets self as self & !rhs, and returns &mut self back.

  • fn bit_xor(&mut self, rhs: &Self) -> &mut Self

    Sets self as self ^ rhs, and returns &mut self back.

  • fn bit_not(&mut self) -> &mut Self

    Inverts every bit of self, and returns &mut self back.

  • fn bit_mask(&mut self, rhs: &Self, mask: &Self) -> &mut Self

    Sets self as (self & !mask) | (rhs & mask), and returns &mut self back.

  • fn bit_shr(&mut self, by: usize) -> &mut Self

    Shifts self right by by bits, toward lower indices, and returns &mut self back. The vacated bits are filled with 0, and the bits shifted out below index 0 are discarded.

  • fn bit_shl(&mut self, by: usize) -> &mut Self

    Shifts self left by by bits, toward higher indices, and returns &mut self back. The vacated bits are filled with 0, and the bits shifted out past the last index are discarded.

  • fn bit_count_ones(&self) -> usize

    Returns the number of boolean values that are true.

  • fn bit_count_zeros(&self) -> usize

    Returns the number of boolean values that are false.

  • fn bit_fmt(&self) -> &BitFmt<Self>

    Used for printing out the bitset.

    println!("{}", bitset.bit_fmt());
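To make the bit layout (and hence the shift directions above) concrete, here is a standalone sketch of the index arithmetic that bit_get and bit_set use: block = idx / 256, lane = (idx % 256) / 64, bit = idx % 64. The helpers locate, get and set are hypothetical names for illustration only and are not part of the module.

```rust
// Hypothetical standalone sketch: mirrors the index arithmetic of the bitset
// module (u64 lanes, 4 lanes per block) without depending on it.
const ELEM_BIT: usize = 64; // bits per u64 lane
const BITS_PER_WORD: usize = ELEM_BIT * 4; // bits per [u64; 4] block

/// Maps a bit index to (block, lane, bit) the same way bit_get/bit_set do.
fn locate(idx: usize) -> (usize, usize, usize) {
    (idx / BITS_PER_WORD, (idx % BITS_PER_WORD) / ELEM_BIT, idx % ELEM_BIT)
}

/// Reads one bit using that layout.
fn get(bs: &[[u64; 4]], idx: usize) -> bool {
    let (block, lane, bit) = locate(idx);
    (bs[block][lane] >> bit) & 1 == 1
}

/// Sets one bit using that layout.
fn set(bs: &mut [[u64; 4]], idx: usize) {
    let (block, lane, bit) = locate(idx);
    bs[block][lane] |= 1u64 << bit;
}

fn main() {
    let mut bs = vec![[0u64; 4]; 2]; // 2 blocks = 512 bits
    set(&mut bs, 300);
    // 300 = 256 + 44, so the bit lives in block 1, lane 0, bit 44.
    assert_eq!(locate(300), (1, 0, 44));
    assert!(get(&bs, 300));

    // A u64 left shift moves bit 44 to bit 45, i.e. index 300 becomes 301:
    // "left" in this layout means toward higher indices.
    bs[1][0] <<= 1;
    assert!(get(&bs, 301) && !get(&bs, 300));
    println!("ok");
}
```

Because bit 0 of each lane is the lowest index, a left shift of the underlying integers moves bits toward higher indices, and the cross-lane carries in bit_shl and bit_shr follow from that.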

License

This code is licensed under the MIT license, with the copyright held by Casper (CasualX@users.noreply.github.com).

Dijkstra

dijkstra(graph: &[Vec<(usize, T)>], src: usize) takes an adjacency list graph, where graph[u] holds a (v, w) pair for every edge from u to v with weight w of a numeric type T, and the id src of the source node. It returns dist: Vec<Option<T>>, where dist[dst] is the length of a shortest path from src to dst if one exists, or None if dst is unreachable from src. As with any Dijkstra implementation, edge weights must be non-negative.

Example

#[allow(unused)]
use std::{cmp::*, collections::*, iter, mem::*, num::*, ops::*};

fn main() {
    let mut graph: Vec<Vec<(usize, i64)>> = vec![vec![]; 4];
    for (u, v, w) in [(0, 1, 5), (1, 2, 5), (1, 2, 15)] {
        graph[u].push((v, w));
        graph[v].push((u, w));
    }
    println!("{:?}", dijkstra(&graph, 0)); // [Some(0), Some(5), Some(10), None]
}

trait HasNz {
    type NzType;
    fn into_nz(self) -> Option<Self::NzType>;
    fn retrieve(nz: Self::NzType) -> Self;
}

macro_rules! impl_hasnz {
    ($($t:ty, $n:ty);*) => { $(
        impl HasNz for $t {
            type NzType = $n;
            fn into_nz(self) -> Option<$n> { <$n>::new(self) }
            fn retrieve(nz: $n) -> Self { nz.get() }
        }
    )* };
}

impl_hasnz!(i8, NonZeroI8; i16, NonZeroI16; i32, NonZeroI32; i64, NonZeroI64; i128, NonZeroI128; isize, NonZeroIsize);
impl_hasnz!(u8, NonZeroU8; u16, NonZeroU16; u32, NonZeroU32; u64, NonZeroU64; u128, NonZeroU128; usize, NonZeroUsize);

fn dijkstra<T>(graph: &[Vec<(usize, T)>], src: usize) -> Vec<Option<T>>
where
    T: Copy + From<u8> + Add<Output = T> + Sub<Output = T> + Eq + Ord + HasNz,
    <T as HasNz>::NzType: Copy,
{
    let mut dist: Vec<Option<T::NzType>> = vec![None; graph.len()];
    let mut heap: BinaryHeap<(Reverse<T>, usize)> = BinaryHeap::new();
    heap.push((Reverse(1.into()), src));

    while let Some((Reverse(curr_cost), curr)) = heap.pop() {
        if dist[curr].map_or(false, |x| T::retrieve(x) < curr_cost) {
            continue;
        }
        dist[curr] = curr_cost.into_nz();

        for &(next, weight) in graph[curr].iter() {
            let next_cost = curr_cost + weight;
            if dist[next].map_or(true, |x| T::retrieve(x) > next_cost) {
                dist[next] = next_cost.into_nz();
                heap.push((Reverse(next_cost), next));
            }
        }
    }

    dist.iter().map(|x| x.map(|x| T::retrieve(x) - 1.into())).collect()
}

Code

trait HasNz {
    type NzType;
    fn into_nz(self) -> Option<Self::NzType>;
    fn retrieve(nz: Self::NzType) -> Self;
}

macro_rules! impl_hasnz {
    ($($t:ty, $n:ty);*) => { $(
        impl HasNz for $t {
            type NzType = $n;
            fn into_nz(self) -> Option<$n> { <$n>::new(self) }
            fn retrieve(nz: $n) -> Self { nz.get() }
        }
    )* };
}

impl_hasnz!(i8, NonZeroI8; i16, NonZeroI16; i32, NonZeroI32; i64, NonZeroI64; i128, NonZeroI128; isize, NonZeroIsize);
impl_hasnz!(u8, NonZeroU8; u16, NonZeroU16; u32, NonZeroU32; u64, NonZeroU64; u128, NonZeroU128; usize, NonZeroUsize);

fn dijkstra<T>(graph: &[Vec<(usize, T)>], src: usize) -> Vec<Option<T>>
where
    T: Copy + From<u8> + Add<Output = T> + Sub<Output = T> + Eq + Ord + HasNz,
    <T as HasNz>::NzType: Copy,
{
    let mut dist: Vec<Option<T::NzType>> = vec![None; graph.len()];
    let mut heap: BinaryHeap<(Reverse<T>, usize)> = BinaryHeap::new();
    heap.push((Reverse(1.into()), src));

    while let Some((Reverse(curr_cost), curr)) = heap.pop() {
        if dist[curr].map_or(false, |x| T::retrieve(x) < curr_cost) {
            continue;
        }
        dist[curr] = curr_cost.into_nz();

        for &(next, weight) in graph[curr].iter() {
            let next_cost = curr_cost + weight;
            if dist[next].map_or(true, |x| T::retrieve(x) > next_cost) {
                dist[next] = next_cost.into_nz();
                heap.push((Reverse(next_cost), next));
            }
        }
    }

    dist.iter().map(|x| x.map(|x| T::retrieve(x) - 1.into())).collect()
}
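The version above stores every distance offset by one: it seeds the source with 1.into() and subtracts 1 at the end. Since every cost is shifted by the same constant, the comparisons inside the loop are unaffected, and a stored value is never zero, so dist can hold Option<NonZero*>, which is guaranteed to be the same size as the plain integer. Presumably this is done to make dist smaller and more cache-friendly. A minimal standalone sketch of that encoding for u64 follows; encode and decode are hypothetical helper names, not part of the snippet.

```rust
use std::mem::size_of;
use std::num::NonZeroU64;

/// Stores distance d as d + 1, so the stored value is never zero.
fn encode(d: u64) -> Option<NonZeroU64> {
    NonZeroU64::new(d + 1)
}

/// Undoes the offset; None (unreachable) stays None.
fn decode(slot: Option<NonZeroU64>) -> Option<u64> {
    slot.map(|nz| nz.get() - 1)
}

fn main() {
    // The zero niche lets None reuse the value 0, so no extra tag is needed.
    assert_eq!(size_of::<Option<NonZeroU64>>(), size_of::<u64>());
    // A plain Option<u64> needs extra room for its discriminant.
    assert!(size_of::<Option<u64>>() > size_of::<u64>());

    let dist = vec![encode(0), encode(5), None];
    let decoded: Vec<Option<u64>> = dist.into_iter().map(decode).collect();
    assert_eq!(decoded, vec![Some(0), Some(5), None]);
    println!("{:?}", decoded); // [Some(0), Some(5), None]
}
```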

Compact Version

This version is much more compact, but its performance is slightly worse than the one above. It appears to be roughly 10% slower, although no rigorous benchmark has been run to confirm this.

fn dijkstra<T: Copy + From<u8> + Add<Output = T> + Eq + Ord>(graph: &[Vec<(usize, T)>], src: usize) -> Vec<Option<T>> {
    let mut dist: Vec<Option<T>> = vec![None; graph.len()];
    let mut heap: BinaryHeap<(Reverse<T>, usize)> = BinaryHeap::new();
    heap.push((Reverse(0.into()), src));

    while let Some((Reverse(curr_cost), curr)) = heap.pop() {
        if dist[curr].map_or(false, |x| x < curr_cost) {
            continue;
        }
        dist[curr] = Some(curr_cost);

        for &(next, weight) in graph[curr].iter() {
            let next_cost = curr_cost + weight;
            if dist[next].map_or(true, |x| x > next_cost) {
                dist[next] = Some(next_cost);
                heap.push((Reverse(next_cost), next));
            }
        }
    }
    dist
}

Dial

Dial's algorithm is an alternative to Dijkstra's algorithm that can be used when the maximum edge cost is small. Instead of a heap, it keeps a queue of buckets (vectors), one per distance, so that vertices are processed in order of their distances. In general its performance is quite similar to that of Dijkstra's algorithm, so use whichever you prefer.

The usage is exactly the same as for the Dijkstra snippet, except for the function name. Note that the cost type must also implement Into<u32>, so only u8, u16 and u32 costs are accepted.
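To make the bucket idea concrete, here is a minimal, non-generic sketch. The name dial_sketch, the u32 edge weights and the u64 distances are assumptions for illustration; this is not the snippet's dial. Instead of offsetting costs by 1 for NonZero storage and rotating a VecDeque, it reuses max_w + 1 buckets circularly: while the current distance is d, every tentative distance lies in [d, d + max_w], so distinct pending distances never share a bucket.

```rust
/// Minimal Dial's algorithm sketch (hypothetical, not the snippet's generic `dial`).
/// `graph[u]` lists `(v, w)` edges with small non-negative integer weights `w`.
/// Returns the shortest distance from `src` to each vertex, or `None` if unreachable.
fn dial_sketch(graph: &[Vec<(usize, u32)>], src: usize) -> Vec<Option<u64>> {
    // Largest edge weight: a tentative distance never exceeds the current one by more.
    let max_w = graph.iter().flatten().map(|&(_, w)| w as usize).max().unwrap_or(0);
    // `max_w + 1` buckets reused circularly: bucket `d % nb` holds vertices at distance `d`.
    let nb = max_w + 1;
    let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); nb];
    let mut dist: Vec<Option<u64>> = vec![None; graph.len()];
    dist[src] = Some(0);
    buckets[0].push(src);
    let mut pending = 1usize; // total number of entries across all buckets
    let mut d: u64 = 0;
    while pending > 0 {
        let b = (d % nb as u64) as usize;
        // Zero-weight edges may push into this same bucket, so drain it completely.
        while let Some(u) = buckets[b].pop() {
            pending -= 1;
            if dist[u] != Some(d) {
                continue; // stale entry: `u` was already settled with a smaller distance
            }
            for &(v, w) in &graph[u] {
                let nd = d + w as u64;
                if dist[v].map_or(true, |x| x > nd) {
                    dist[v] = Some(nd);
                    buckets[(nd % nb as u64) as usize].push(v);
                    pending += 1;
                }
            }
        }
        d += 1;
    }
    dist
}
```

Stale entries (vertices whose distance improved after they were pushed) are skipped when popped, in the same way the heap-based Dijkstra skips outdated heap entries.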

Code

trait HasNz {
    type NzType;
    fn into_nz(self) -> Option<Self::NzType>;
    fn retrieve(nz: Self::NzType) -> Self;
}

macro_rules! impl_hasnz {
    ($($t:ty, $n:ty);*) => { $(
        impl HasNz for $t {
            type NzType = $n;
            fn into_nz(self) -> Option<$n> { <$n>::new(self) }
            fn retrieve(nz: $n) -> Self { nz.get() }
        }
    )* };
}

impl_hasnz!(i8, NonZeroI8; i16, NonZeroI16; i32, NonZeroI32; i64, NonZeroI64; i128, NonZeroI128; isize, NonZeroIsize);
impl_hasnz!(u8, NonZeroU8; u16, NonZeroU16; u32, NonZeroU32; u64, NonZeroU64; u128, NonZeroU128; usize, NonZeroUsize);

fn dial<T>(graph: &[Vec<(usize, T)>], src: usize) -> Vec<Option<T>>
where
    T: Copy + From<u8> + Into<u32> + Add<Output = T> + Sub<Output = T> + Eq + Ord + HasNz,
    <T as HasNz>::NzType: Copy,
{
    let max_cost: u32 = graph.iter().map(|list| list.iter().map(|&(_, v)| v.into())).flatten().max().unwrap_or(0);

    let mut dist: Vec<Option<T::NzType>> = vec![None; graph.len()];
    dist[src] = {
        let one: T = 1.into();
        one.into_nz()
    };

    let mut qcnt: usize = 1;
    // Enough buckets for every edge weight up to `max_cost`; the processed bucket is recycled to the back.
    let mut queue: std::collections::VecDeque<Vec<u32>> = (0..max_cost as usize + 2).map(|_| vec![]).collect();
    queue[0].push(src as u32);

    let mut curr_cost: T = 0.into();
    while qcnt != 0 {
        curr_cost = curr_cost + 1.into();
        let mut hand = queue.pop_front().unwrap();

        while let Some(curr) = hand.pop() {
            qcnt -= 1;
            if dist[curr as usize].map_or(false, |x| T::retrieve(x) < curr_cost) {
                continue;
            }

            for &(next, weight) in graph[curr as usize].iter() {
                let next_cost = curr_cost + weight;

                if dist[next].map_or(true, |x| T::retrieve(x) > next_cost) {
                    dist[next] = next_cost.into_nz();

                    qcnt += 1;
                    if weight == 0.into() {
                        hand.push(next as u32);
                    } else {
                        queue[weight.into() as usize - 1].push(next as u32);
                    }
                }
            }
        }

        queue.push_back(hand);
    }

    dist.iter().map(|x| x.map(|x| T::retrieve(x) - 1.into())).collect()
}

Strongly Connected Components

Finds the strongly connected components (SCCs) of a directed graph using Tarjan's algorithm.

Example

use scc::*;

fn main() {
    let mut graph = vec![vec![]; 5];
    for (u, v) in [(0, 2), (3, 0), (2, 3), (0, 1), (1, 4)] {
        graph[u].push(v);
    }

    let scc_list = find_scc(&graph);
    println!("{:?}", scc_list); // [[3, 2, 0], [1], [4]]
    let scc_id = gen_scc_ids(&graph, &scc_list);
    println!("{:?}", scc_id); // [0, 1, 0, 0, 2]
    let scc_graph = gen_scc_graph(&graph, &scc_list, &scc_id);
    println!("{:?}", scc_graph); // [[1], [2], []]
}

Code

mod scc {
    struct SccStack {
        stack: Vec<u32>,
        check: Vec<u64>,
    }

    impl SccStack {
        fn new(cap: usize) -> Self {
            Self {
                stack: Vec::with_capacity(cap),
                check: vec![0; (cap + 63) / 64],
            }
        }
        fn push(&mut self, n: usize) {
            self.stack.push(n as u32);
            self.check[n / 64] |= 1 << (n % 64);
        }
        fn pop(&mut self) -> Option<usize> {
            let tmp = self.stack.pop()? as usize;
            self.check[tmp / 64] &= !(1 << (tmp % 64));
            Some(tmp)
        }
        fn contains(&self, n: usize) -> bool {
            self.check[n / 64] & (1 << (n % 64)) != 0
        }
    }

    struct DfsPack {
        gid: usize,
        id: Vec<usize>,
        low: Vec<usize>,
        st: SccStack,
    }

    fn dfs(n: usize, graph: &[Vec<usize>], curr: usize, p: &mut DfsPack, list: &mut Vec<Vec<usize>>) {
        p.st.push(curr);
        p.id[curr] = p.gid;
        p.low[curr] = p.gid;
        p.gid += 1;

        for &next in graph[curr].iter() {
            if p.id[next] == n {
                dfs(n, graph, next, p, list);
            }
        }
        for &next in graph[curr].iter() {
            if p.st.contains(next) {
                p.low[curr] = p.low[curr].min(p.low[next]);
            }
        }

        if p.id[curr] == p.low[curr] {
            let mut newlist = vec![];
            while let Some(popped) = p.st.pop() {
                if popped == curr {
                    break;
                }
                newlist.push(popped);
            }
            newlist.push(curr);
            list.push(newlist);
        }
    }

    /// Returns a list of SCCs of `graph`.
    /// Each inner vector lists the vertices of one SCC.
    /// The SCCs are returned in topological order.
    ///
    /// The implementation uses Tarjan's SCC algorithm.
    pub fn find_scc(graph: &[Vec<usize>]) -> Vec<Vec<usize>> {
        let n = graph.len();
        let mut list = vec![];

        let mut p = DfsPack {
            gid: 0,
            id: vec![n; n],
            low: vec![usize::MAX; n],
            st: SccStack::new(n),
        };
        for x in 0..n {
            if p.id[x] != n {
                continue;
            }
            dfs(n, graph, x, &mut p, &mut list);
        }
        list.reverse();
        list
    }

    /// Returns, for each vertex, the index of the SCC it belongs to.
    /// `scc_list` has to be generated in advance with `find_scc`.
    pub fn gen_scc_ids(graph: &[Vec<usize>], scc_list: &[Vec<usize>]) -> Vec<usize> {
        let mut ids = vec![0; graph.len()];
        for (i, l) in scc_list.iter().enumerate() {
            for &v in l {
                ids[v] = i;
            }
        }
        ids
    }

    /// Returns the graph of SCCs (the condensation), which has one vertex per SCC.
    /// `scc_list` and `scc_ids` have to be generated in advance with `find_scc` and `gen_scc_ids`.
    /// Parallel edges between the same pair of SCCs are not removed.
    pub fn gen_scc_graph(graph: &[Vec<usize>], scc_list: &[Vec<usize>], scc_ids: &[usize]) -> Vec<Vec<usize>> {
        let mut ret = vec![vec![]; scc_list.len()];
        for u in 0..graph.len() {
            let a = scc_ids[u];
            for &v in graph[u].iter() {
                let b = scc_ids[v];
                if a < b {
                    ret[a].push(b);
                }
            }
        }
        ret
    }
}

Last modified on 231008.

2-SAT

A 2-SAT instance with \(N\) variables and \(M\) clauses can be solved in \(O(N + M)\) time. The SCC snippet must be included in the code along with this snippet.
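The solver works on the implication graph: a clause \(a \lor b\) becomes the two edges \(\lnot a \Rightarrow b\) and \(\lnot b \Rightarrow a\). Reading add_clause, vertex `2 * i + 1` stands for x_i being true and `2 * i` for x_i being false. The sketch below isolates just that edge construction; `lit` and `implication_edges` are hypothetical helpers, not part of the snippet.

```rust
/// Hypothetical helper: vertex index of the literal `x_i == val`, using the same encoding
/// as `TwoSat::add_clause` (`2 * i + 1` means `x_i` is true, `2 * i` means it is false).
fn lit(i: usize, val: bool) -> usize {
    2 * i + val as usize
}

/// Hypothetical helper: the two implication edges for the clause `(x_i == f) | (x_j == g)`.
/// If the first literal fails, the second must hold, and vice versa.
fn implication_edges((i, f): (usize, bool), (j, g): (usize, bool)) -> [(usize, usize); 2] {
    [(lit(i, !f), lit(j, g)), (lit(j, !g), lit(i, f))]
}
```

For the clause ~x0 | x1 this gives the edges (1, 3) and (2, 0), i.e. x0 implies x1 and ~x1 implies ~x0, which are the edges add_clause pushes. solve then sets x_i to true exactly when the SCC of vertex `2 * i + 1` comes later in topological order than that of vertex `2 * i`, and returns None when both land in the same SCC.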

Example

use twosat::*;

fn main() {
    // (not 0 or 1) and (not 1 or 2) and (0 or 2) and (2 or 1)
    let mut ts = TwoSat::new(3);
    for (a, b) in [((0, false), (1, true)), ((1, false), (2, true)), ((0, true), (2, true)), ((2, true), (1, true))] {
        ts.add_clause(a, b);
    }
    println!("{:?}", ts.solve()); // Some([false, true, true])

    // (0 or 0) and (not 0 or not 0)
    let mut ts = TwoSat::new(1);
    for (a, b) in [((0, true), (0, true)), ((0, false), (0, false))] {
        ts.add_clause(a, b);
    }
    println!("{:?}", ts.solve()); // None
}

Code

mod twosat {
    use super::scc::*;

    /// 2-SAT solver.
    pub struct TwoSat {
        n: usize,
        graph: Vec<Vec<usize>>,
    }

    impl TwoSat {
        /// Creates a new instance of 2-SAT solver.
        pub fn new(n: usize) -> Self {
            Self { n, graph: vec![vec![]; n << 1] }
        }

        /// Adds the clause `(x_i == f) | (x_j == g)`.
        /// For example, `self.add_clause((0, false), (1, true))` adds the clause `~x0 | x1` to the solver.
        pub fn add_clause(&mut self, (i, f): (usize, bool), (j, g): (usize, bool)) {
            let judge = |x: bool, a: usize, b: usize| if x { a } else { b };
            self.graph[i * 2 + judge(f, 0, 1)].push(j * 2 + judge(g, 1, 0));
            self.graph[j * 2 + judge(g, 0, 1)].push(i * 2 + judge(f, 1, 0));
        }

        /// Returns a satisfying assignment if one exists, in time linear in the number of variables and clauses.
        /// Returns `None` if the problem is unsolvable.
        pub fn solve(&self) -> Option<Vec<bool>> {
            let mut ans = vec![false; self.n];
            let scc_list = find_scc(&self.graph);
            let ids = gen_scc_ids(&self.graph, &scc_list);
            for i in 0..self.n {
                if ids[i * 2] == ids[i * 2 + 1] {
                    return None;
                }
                ans[i] = ids[i * 2] < ids[i * 2 + 1];
            }
            Some(ans)
        }
    }
}

Last modified on 231008.

Heavy-Light Decomposition

Code

/// Reference: https://codeforces.com/blog/entry/53170
mod hld {
	pub struct Hld {
		pub root: usize,
		/// `size[i]`: The size of the subtree rooted at `i`.
		pub size: Vec<usize>,
		/// `dep[i]`: The depth of `i` (the root has depth 0).
		pub dep: Vec<usize>,
		/// `par[i]`: The parent of `i` (`graph.len()` for the root).
		pub par: Vec<usize>,
		/// `top[i]`: The topmost vertex of the heavy chain containing `i`.
		pub top: Vec<usize>,
		/// `ein[i]`: The DFS entry order of `i`.
		pub ein: Vec<usize>,
		/// `eout[i]`: The DFS exit order of `i`; the subtree of `i` occupies `ein[i]..eout[i]`.
		pub eout: Vec<usize>,
		/// The inverse of `ein`: `rin[ein[i]] == i`.
		pub rin: Vec<usize>,
	}

	impl Hld {
		pub fn new(root: usize, graph: &mut [Vec<usize>]) -> Self {
			fn dfs1(u: usize, p: usize, graph: &mut [Vec<usize>], h: &mut Hld) {
				h.size[u] = 1;
				h.par[u] = p;
				for i in 0..graph[u].len() {
					let v = graph[u][i];
					if v != p {
						h.dep[v] = h.dep[u] + 1;
						dfs1(v, u, graph, h);
						h.size[u] += h.size[v];
						// Keep the heavy child at index 0; the parent may occupy that slot initially.
						if graph[u][0] == p || h.size[v] > h.size[graph[u][0]] {
							graph[u].swap(i, 0);
						}
					}
				}
			}

			fn dfs2(u: usize, p: usize, cnt: &mut usize, graph: &[Vec<usize>], h: &mut Hld) {
				h.ein[u] = *cnt;
				h.rin[*cnt] = u;
				*cnt += 1;
				for &v in graph[u].iter().filter(|&&v| v != p) {
					h.top[v] = if v == graph[u][0] { h.top[u] } else { v };
					dfs2(v, u, cnt, graph, h);
				}
				h.eout[u] = *cnt;
			}

			let n = graph.len();
			let [size, dep, par, top, ein, eout, rin] = [0; 7].map(|_| vec![0; n]);
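			let mut ret = Hld { root, size, dep, par, top, ein, eout, rin };
			// `dfs2` only assigns `top` to non-root vertices, so the root must head its own chain here.
			ret.top[root] = root;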

			dfs1(root, n, graph, &mut ret);
			dfs2(root, n, &mut 0, graph, &mut ret);
			ret
		}

		/// Returns the ranges that make up the path between `u` and `v`.
		/// Each range is a half-open interval `[l, r)` over positions in `self.ein` order.
		/// The ranges are for path queries on vertices. For queries on edges, increase the left end of the last range by 1, which excludes the LCA.
		pub fn chains(&self, u: usize, v: usize) -> impl Iterator<Item = (usize, usize)> + '_ {
			let (mut a, mut b, mut k) = (u, v, true);
			std::iter::from_fn(move || {
				k.then(|| {
					if self.top[a] != self.top[b] {
						if self.dep[self.top[a]] < self.dep[self.top[b]] {
							std::mem::swap(&mut a, &mut b);
						}
						let st = self.top[a];
						let ret = (self.ein[st], self.ein[a] + 1);
						a = self.par[st];
						ret
					} else {
						if self.dep[a] > self.dep[b] {
							std::mem::swap(&mut a, &mut b);
						}
						k = false;
						(self.ein[a], self.ein[b] + 1)
					}
				})
			})
		}

		pub fn lca(&self, u: usize, v: usize) -> usize { self.chains(u, v).last().map(|(a, _)| self.rin[a]).unwrap() }

		pub fn lca_alt_root(&self, r: usize, u: usize, v: usize) -> usize {
			if r == self.root {
				return self.lca(u, v);
			}
			let (uv, ru, rv) = (self.lca(u, v), self.lca(r, u), self.lca(r, v));
			if rv == uv {
				ru
			} else if ru == uv {
				rv
			} else {
				uv
			}
		}
	}
}
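chains yields half-open ranges of positions in ein order, so a static path-sum query reduces to summing those ranges on a prefix-sum array laid out in ein order. The sketch below assumes the ranges have already been collected from chains in the order it yields them (so the range containing the LCA comes last). `path_sum` and `prefix_in_ein_order` are hypothetical helpers, and storing each edge's weight at its child vertex for edge queries is an assumption that matches the doc comment's advice to truncate the last range.

```rust
/// Hypothetical helper: prefix sums of per-vertex values laid out in `ein` order,
/// where `rin[k]` is the vertex whose `ein` is `k`.
fn prefix_in_ein_order(val: &[i64], rin: &[usize]) -> Vec<i64> {
    let mut prefix = vec![0; rin.len() + 1];
    for k in 0..rin.len() {
        prefix[k + 1] = prefix[k] + val[rin[k]];
    }
    prefix
}

/// Hypothetical helper: sums values over the ranges produced by `Hld::chains`.
/// With `edges == true`, each vertex is assumed to store the weight of the edge to its
/// parent, so the LCA (the left end of the last range) is excluded.
fn path_sum(prefix: &[i64], ranges: &[(usize, usize)], edges: bool) -> i64 {
    let mut total = 0;
    for (k, &(l, r)) in ranges.iter().enumerate() {
        // Truncate only the final range, which is the one containing the LCA.
        let l = if edges && k + 1 == ranges.len() { l + 1 } else { l };
        total += prefix[r] - prefix[l];
    }
    total
}
```

With the snippet, the inputs would be `prefix_in_ein_order(&val, &hld.rin)` and `hld.chains(u, v).collect::<Vec<_>>()`. If values change between queries, the prefix sums would be replaced by a segment tree or Fenwick tree over the same ein positions.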

Flow

Dinic's Algorithm

Dinic's algorithm is a practically fast algorithm for computing the maximum flow in a flow network.

Create a network instance with Dinic::new(n), where n is the number of vertices in the network, and add edges with Dinic::add_edge(&mut self, src, dst, cap), where cap is the capacity of the edge being added.

Calling Dinic::max_flow(&mut self, src, dst) computes the maximum flow from src to dst. Afterwards, the remaining capacity of any edge can be read directly from the cap field of that edge in the adjacency list g.
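Because max_flow leaves the residual capacities in cap, a minimum cut can be read off afterwards: by the max-flow min-cut theorem, the vertices reachable from src through edges with nonzero remaining capacity form the source side of a minimum cut. The sketch below is a hypothetical standalone helper; taking the residual graph as `(dst, remaining_cap)` pairs is an assumption for self-containment (with the snippet you would build them from `dn.g`, using `e.dst as usize` and `e.cap`, reverse edges included).

```rust
/// Hypothetical helper: `side[v] == true` iff `v` is reachable from `src` in the residual
/// graph, i.e. `v` is on the source side of a minimum cut.
/// `residual[u]` lists `(v, remaining_capacity)` pairs, including reverse edges.
fn min_cut_side(residual: &[Vec<(usize, u64)>], src: usize) -> Vec<bool> {
    let mut side = vec![false; residual.len()];
    let mut stack = vec![src];
    side[src] = true;
    // Iterative DFS over edges that can still carry flow.
    while let Some(u) = stack.pop() {
        for &(v, cap) in &residual[u] {
            if cap > 0 && !side[v] {
                side[v] = true;
                stack.push(v);
            }
        }
    }
    side
}
```

The original edges going from a vertex with side true to a vertex with side false are then exactly the saturated edges of that minimum cut.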

Example

The example below computes the maximum flow of the network shown on page 5 of https://github.com/justiceHui/SSU-SCCC-Study/blob/master/2022-winter-adv/slide/02.pdf.

use dinic::Dinic;
fn main() {
    let mut dn = Dinic::new(4);
    for (src, dst, cap) in [(0, 1, 2), (0, 2, 2), (1, 2, 1), (1, 3, 2), (2, 3, 2)] {
        dn.add_edge(src, dst, cap);
    }

    let (src, dst) = (0, 3);
    let max_flow = dn.max_flow(src, dst);
    println!("{max_flow}"); // 4

    // Remaining capacity of the edge from 0 to 1
    let edge = dn.g[0].iter().find(|e| e.dst == 1).unwrap();
    println!("{}", edge.cap); // 0
}

Code

mod dinic {
    //! Reference: https://github.com/justiceHui/SSU-SCCC-Study/blob/master/2022-winter-adv/slide/04.pdf

    use std::collections::VecDeque;

    #[derive(Clone)]
    pub struct Edge {
        pub dst: u32,
        pub opp: u32,
        pub cap: u64,
    }

    impl Edge {
        fn new(dst: usize, opp: usize, cap: u64) -> Self {
            Self {
                dst: dst as u32,
                opp: opp as u32,
                cap,
            }
        }
    }

    pub struct Dinic {
        pub n: usize,
        pub g: Vec<Vec<Edge>>,
    }

    impl Dinic {
        pub fn new(n: usize) -> Self {
            Self {
                n,
                g: vec![vec![]; n],
            }
        }

        pub fn add_edge(&mut self, s: usize, e: usize, cap: u64) {
            let sl = self.g[s].len();
            let el = self.g[e].len();
            self.g[s].push(Edge::new(e, el, cap));
            self.g[e].push(Edge::new(s, sl, 0));
        }

        fn bfs(&mut self, s: u32, t: u32, lv: &mut [u32]) -> bool {
            lv.fill(0);

            let mut queue = VecDeque::new();
            queue.push_back(s);
            lv[s as usize] = 1;

            while let Some(v) = queue.pop_front() {
                for e in self.g[v as usize].iter() {
                    if lv[e.dst as usize] == 0 && e.cap != 0 {
                        queue.push_back(e.dst);
                        lv[e.dst as usize] = lv[v as usize] + 1;
                    }
                }
            }

            lv[t as usize] != 0
        }

        fn dfs(&mut self, v: u32, t: u32, fl: u64, lv: &[u32], idx: &mut [u32]) -> u64 {
            if v == t || fl == 0 {
                return fl;
            }

            for i in idx[v as usize]..self.g[v as usize].len() as u32 {
                idx[v as usize] = i;

                let Edge { dst, opp, cap } = self.g[v as usize][i as usize];
                if lv[dst as usize] != lv[v as usize] + 1 || cap == 0 {
                    continue;
                }
                let now = self.dfs(dst, t, fl.min(cap), lv, idx);
                if now == 0 {
                    continue;
                }

                self.g[v as usize][i as usize].cap -= now;
                self.g[dst as usize][opp as usize].cap += now;
                return now;
            }

            0
        }

        pub fn max_flow(&mut self, src: usize, dst: usize) -> u64 {
            let mut flow = 0;
            let mut aug;
            let mut lv = vec![0; self.n];
            let mut idx = vec![0; self.n];

            while self.bfs(src as u32, dst as u32, &mut lv) {
                idx.fill(0);
                loop {
                    aug = self.dfs(src as u32, dst as u32, u64::MAX, &mut lv, &mut idx);
                    if aug == 0 {
                        break;
                    }
                    flow += aug;
                }
            }
            flow
        }
    }
}

MCMF

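The API of this snippet is similar to that of Dinic's: add_edge(s, e, c, d) adds an edge from s to e with capacity c and cost d, and min_cost_max_flow(s, t) returns the pair (maximum flow, minimum cost of that flow).

Shortest paths for MCMF are found with SPFA, so negative edge costs are allowed as long as the network has no negative cycle.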
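SPFA is a queue-based variant of Bellman-Ford: a vertex is queued again whenever its distance improves, and a flag (inn in the snippet) keeps it from sitting in the queue twice at once. Unlike Dijkstra's algorithm it tolerates negative edge costs, which appear in the residual graph as the reverse edges with cost -d. Below is a minimal standalone sketch on a plain `(v, cost)` adjacency list without capacities; `spfa` is a hypothetical name, and the sketch assumes no negative cycle is reachable from src (otherwise it would not terminate).

```rust
use std::collections::VecDeque;

/// Minimal SPFA sketch: shortest distances from `src`, `None` if unreachable.
/// Assumes no negative cycle is reachable from `src`.
fn spfa(graph: &[Vec<(usize, i64)>], src: usize) -> Vec<Option<i64>> {
    let n = graph.len();
    let mut dist = vec![i64::MAX; n];
    let mut in_queue = vec![false; n];
    let mut queue = VecDeque::new();
    dist[src] = 0;
    in_queue[src] = true;
    queue.push_back(src);
    while let Some(u) = queue.pop_front() {
        in_queue[u] = false;
        for &(v, cost) in &graph[u] {
            // `dist[u]` is finite here because only reached vertices are ever queued.
            if dist[v] > dist[u] + cost {
                dist[v] = dist[u] + cost;
                if !in_queue[v] {
                    in_queue[v] = true;
                    queue.push_back(v);
                }
            }
        }
    }
    dist.into_iter().map(|d| if d == i64::MAX { None } else { Some(d) }).collect()
}
```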

Code

mod mcmf {
    //! Reference: https://github.com/justiceHui/SSU-SCCC-Study/blob/master/2022-winter-adv/slide/04.pdf

    use std::collections::VecDeque;

    #[derive(Clone)]
    pub struct Edge {
        pub dst: u32,
        pub opp: u32,
        pub cap: u64,
        pub cost: i64,
    }

    impl Edge {
        fn new(dst: usize, opp: usize, flow: u64, cost: i64) -> Self {
            Self {
                dst: dst as u32,
                opp: opp as u32,
                cap: flow,
                cost,
            }
        }
    }

    pub struct Mcmf {
        pub n: usize,
        pub g: Vec<Vec<Edge>>,
    }

    impl Mcmf {
        pub fn new(n: usize) -> Self {
            Self {
                n,
                g: vec![vec![]; n],
            }
        }

        pub fn add_edge(&mut self, s: usize, e: usize, c: u64, d: i64) {
            let (slen, elen) = (self.g[s].len(), self.g[e].len());
            self.g[s].push(Edge::new(e, elen, c, d));
            self.g[e].push(Edge::new(s, slen, 0, -d));
        }

        fn augment(
            &mut self,
            s: usize,
            t: usize,
            prv: &mut [usize],
            idx: &mut [usize],
            dst: &mut [i64],
        ) -> bool {
            let mut inn = vec![false; self.n];
            dst.fill(i64::MAX);
            let mut queue = VecDeque::new();
            inn[s] = true;
            dst[s] = 0;
            queue.push_back(s);

            while let Some(v) = queue.pop_front() {
                inn[v] = false;
                for (i, e) in self.g[v].iter().enumerate() {
                    let src = e.dst as usize;
                    if e.cap != 0 && dst[src] > dst[v] + e.cost {
                        dst[src] = dst[v] + e.cost;
                        prv[src] = v;
                        idx[src] = i;
                        if !inn[src] {
                            inn[src] = true;
                            queue.push_back(src);
                        }
                    }
                }
            }
            dst[t] < i64::MAX
        }

        pub fn min_cost_max_flow(&mut self, s: usize, t: usize) -> (u64, i64) {
            use std::iter::successors;
            let mut flow = 0;
            let mut cost = 0;

            let mut prv = vec![0; self.n];
            let mut idx = vec![0; self.n];
            let mut dst = vec![0; self.n];
            while self.augment(s, t, &mut prv, &mut idx, &mut dst) {
                let path = successors(Some(t), |&i| Some(prv[i]))
                    .take_while(|&i| i != s)
                    .map(|i| self.g[prv[i]][idx[i]].cap)
                    .min()
                    .unwrap();
                flow += path;
                cost += path as i64 * dst[t];
                for i in successors(Some(t), |&i| Some(prv[i])).take_while(|&i| i != s) {
                    self.g[prv[i]][idx[i]].cap -= path;
                    let j = self.g[prv[i]][idx[i]].opp as usize;
                    self.g[i][j].cap += path;
                }
            }

            (flow, cost)
        }
    }
}

KMP

Given an array pattern, failure_function(pattern) returns the failure function of pattern, referred to as failure below.

Given an array haystack, kmp_search(haystack, pattern, failure) returns the result of searching for pattern in haystack, provided that failure is the failure function of pattern. Denoting the returned array as result and the length of pattern as n, if result[i + n] == n, then haystack[i..i + n] == pattern.

The API is split this way because many algorithm problems need the failure function itself rather than KMP string searching directly.
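Reading the code, for a non-empty pattern failure_function returns a vector of length `pattern.len() + 1` in which `failure[k]` is the length of the longest proper prefix of `pattern[..k]` that is also a suffix of it (its longest proper border). A naive, definition-based version is handy for cross-checking on small inputs; `naive_failure` is a hypothetical name, and it runs in \(O(n^3)\) time, so it is only a testing aid.

```rust
/// Naive reference (hypothetical): `failure[k]` is the length of the longest proper border
/// of `pattern[..k]`, for every `k` in `0..=pattern.len()`.
fn naive_failure<T: PartialEq>(pattern: &[T]) -> Vec<usize> {
    (0..=pattern.len())
        .map(|k| {
            let prefix = &pattern[..k];
            // Try border lengths from the longest proper one (k - 1) down to 0.
            (0..k)
                .rev()
                .find(|&len| prefix[..len] == prefix[k - len..])
                .unwrap_or(0)
        })
        .collect()
}
```

For b"ABCDABC", both this and failure_function give [0, 0, 0, 0, 0, 1, 2, 3].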

Example

fn main() {
    let pattern = b"ABCDABC";
    let targets = ["ABDABCDABCE", "ABCDABCDABCD", "ABBCCABCDABDABCDABC"].map(|b| b.as_bytes());

    let failure = failure_function(pattern);
    for &t in &targets {
        let result = kmp_search(t, pattern, &failure);
        println!("{:?}", result);
        // Print the starting index of every occurrence of `pattern` in `t`.
        for i in 0..result.len() - pattern.len() {
            if result[i + pattern.len()] == pattern.len() {
                print!("{} ", i);
            }
        }
        println!();
    }
}

Code

/// Returns a failure function of `pattern`.
fn failure_function<T: PartialEq>(pattern: &[T]) -> Vec<usize> {
    let n = pattern.len();
    let mut c = vec![0, 0];
    let mut x;
    for i in 1..n {
        x = c[i];
        loop {
            if pattern[i] == pattern[x] {
                c.push(x + 1);
                break;
            }
            if x == 0 {
                c.push(0);
                break;
            }
            x = c[x];
        }
    }
    c
}

/// Returns a result of KMP search.
/// For `n = pattern.len()`, if `result[i] == n`, then `haystack[i-n..i] == pattern`.
fn kmp_search<T: PartialEq>(haystack: &[T], pattern: &[T], failure: &[usize]) -> Vec<usize> {
    let m = haystack.len();
    let mut d = vec![0];
    let mut x;
    for i in 0..m {
        x = d[i];
        if x == pattern.len() {
            x = failure[x];
        }
        loop {
            if haystack[i] == pattern[x] {
                d.push(x + 1);
                break;
            }
            if x == 0 {
                d.push(0);
                break;
            }
            x = failure[x];
        }
    }
    d
}

Last modified on 231008.

Manacher

For an array \(A\) of length \(n\), manacher(A) returns a vector \(M\) where, for every \(i \in \left[0, n\right)\), \(A_{i-j} = A_{i+j}\) holds for every \(j \in \left[0, M_i \right)\), and \(M_i\) is the largest value with this property.

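manacher only measures odd-length palindromes centered at each element. To find every palindromic substring of a string, including even-length ones, the user has to adapt the input, e.g. by inserting a separator between adjacent elements.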
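This is a minimal sketch of that preprocessing, assuming the manacher function above. It interleaves None separators with the elements (wrapped as Some(x)). Every palindrome of the original, odd or even, then becomes an odd-length palindrome of the new array, centered at an element or at a separator. A radius M in the interleaved array corresponds to a palindrome of length M - 1 in the original. longest_palindrome is a hypothetical helper, and manacher is repeated so the block compiles alone.

```rust
/// Same `manacher` as in this section.
fn manacher<T: Eq>(arr: &[T]) -> Vec<usize> {
    let n = arr.len();
    let mut mana: Vec<usize> = vec![1; n];
    let mut r: usize = 1;
    let mut p: usize = 0;
    for i in 1..n {
        if i + 1 >= r {
            mana[i] = 1;
        } else {
            let j = 2 * p - i;
            mana[i] = mana[j].min(r - i);
        }
        while mana[i] <= i && i + mana[i] < n {
            if arr[i - mana[i]] != arr[i + mana[i]] {
                break;
            }
            mana[i] += 1;
        }
        if r < mana[i] + i {
            r = mana[i] + i;
            p = i;
        }
    }
    mana
}

/// Hypothetical helper: length of the longest palindromic substring of `s`,
/// counting even-length palindromes too.
fn longest_palindrome<T: Eq + Clone>(s: &[T]) -> usize {
    // t = [None, Some(s[0]), None, Some(s[1]), ..., None]
    let mut t: Vec<Option<T>> = Vec::with_capacity(2 * s.len() + 1);
    t.push(None);
    for x in s {
        t.push(Some(x.clone()));
        t.push(None);
    }
    // Every maximal palindrome in `t` starts and ends with `None`,
    // so radius M in `t` means length M - 1 in `s`.
    manacher(&t).into_iter().max().unwrap() - 1
}
```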

Example

fn main() {
    let s = "abracadacabra".as_bytes();
    let man = manacher(s);
    println!("{:?}", man); // [1, 1, 1, 1, 2, 1, 4, 1, 2, 1, 1, 1, 1]

    for i in 0..s.len() {
        println!(
            "{}",
            std::str::from_utf8(&s[i + 1 - man[i]..i + man[i]]).unwrap()
        );
    }
}

Code

fn manacher<T: Eq>(arr: &[T]) -> Vec<usize> {
    let n = arr.len();
    let mut mana: Vec<usize> = vec![1; n];
    let mut r: usize = 1;
    let mut p: usize = 0;

    for i in 1..arr.len() {
        if i + 1 >= r {
            mana[i] = 1;
        } else {
            let j = 2 * p - i;
            mana[i] = mana[j].min(r - i);
        }

        while mana[i] <= i && i + mana[i] < n {
        if arr[i - mana[i]] != arr[i + mana[i]] {
                break;
            }
            mana[i] += 1;
        }

        if r < mana[i] + i {
            r = mana[i] + i;
            p = i;
        }
    }

    mana
}

Suffix Array and LCP Array

For an array \(A\) of length \(n\), sa_lcp(A) returns two vectors \(SA\) and \(LCP\) where,

  • \(A[SA[i] \dots]\) is the \(i\)-th suffix in lexicographical order for every \(i \in \left[0, n\right)\)

and

  • \(LCP[i]\) is the length of the longest common prefix between \(A[SA[i-1] \dots]\) and \(A[SA[i] \dots]\) for every \(i \in \left[1, n \right)\). Also, \(LCP[0] = 0\).

\(SA\) and \(LCP\) are called "suffix array" and "LCP array" of \(A\) respectively.

The suffix array is built with the Manber-Myers algorithm (prefix doubling) combined with counting sort, which takes \(O(n\log{n})\) time. The LCP array is built with Kasai's algorithm in \(O(n)\) time. The total time complexity is \(O(n\log{n})\).
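When testing sa_lcp on small random inputs, a brute-force version that sorts whole suffixes is a convenient reference. The hypothetical sa_lcp_naive below follows the definitions above directly and runs in roughly \(O(n^2\log{n})\) time. The second hypothetical helper shows a common use of the LCP array. The number of distinct non-empty substrings equals \(n(n+1)/2\) minus the sum of the LCP array, because each suffix \(SA[i]\) contributes its length minus the \(LCP[i]\) characters it shares with the previous suffix.

```rust
/// Hypothetical brute-force reference for `sa_lcp`; only for short inputs.
fn sa_lcp_naive<T: Ord>(arr: &[T]) -> (Vec<usize>, Vec<usize>) {
    let n = arr.len();
    let mut sa: Vec<usize> = (0..n).collect();
    // Slices compare lexicographically, and a proper prefix sorts first.
    sa.sort_by(|&a, &b| arr[a..].cmp(&arr[b..]));
    let mut lcp = vec![0; n];
    for i in 1..n {
        let (a, b) = (&arr[sa[i - 1]..], &arr[sa[i]..]);
        lcp[i] = a.iter().zip(b).take_while(|(x, y)| x == y).count();
    }
    (sa, lcp)
}

/// Hypothetical helper: number of distinct non-empty substrings,
/// given the LCP array of a sequence of length `lcp.len()`.
fn count_distinct_substrings(lcp: &[usize]) -> usize {
    let n = lcp.len();
    n * (n + 1) / 2 - lcp.iter().sum::<usize>()
}
```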

Example

fn main() {
    let s = "asdsdasd";
    let (sa, lcp) = sa_lcp(s.as_bytes());
    println!("{:?}", sa);  // [5, 0, 7, 4, 2, 6, 3, 1]
    println!("{:?}", lcp); // [0, 3, 0, 1, 1, 0, 2, 2]
}

Code

// Suffix array and LCP array
// Reference: http://www.secmem.org/blog/2021/07/18/suffix-array-and-lcp/

fn suffix_array<T: Ord>(s: &[T]) -> Vec<usize> {
    use std::collections::*;

    if s.len() == 0 {
        return vec![];
    } else if s.len() == 1 {
        return vec![0];
    }

    let n = s.len();

    let mut r: Vec<usize> = vec![0; n * 2];
    let map: BTreeMap<_, _> = {
        let mut sorted: Vec<_> = s.iter().collect();
        sorted.sort_unstable();
        sorted
            .into_iter()
            .enumerate()
            .map(|x| (x.1, x.0 + 1))
            .collect()
    };
    for i in 0..n {
        r[i] = *map.get(&s[i]).unwrap();
    }

    let m = n.max(map.len()) + 1;
    let mut sa: Vec<usize> = (0..n).collect();
    let mut nr: Vec<usize> = vec![0; n * 2];
    let mut cnt: Vec<usize> = vec![0; m];
    let mut idx: Vec<usize> = vec![0; n];

    for d in (0..).map(|x| 1 << x).take_while(|&d| d < n) {
        macro_rules! key {
            ($i:expr) => {
                if $i + d >= n {
                    (r[$i], 0)
                } else {
                    (r[$i], r[$i + d])
                }
            };
        }

        (0..m).for_each(|i| cnt[i] = 0);
        (0..n).for_each(|i| cnt[r[i + d]] += 1);
        (1..m).for_each(|i| cnt[i] += cnt[i - 1]);
        for i in (0..n).rev() {
            cnt[r[i + d]] -= 1;
            idx[cnt[r[i + d]]] = i;
        }

        (0..m).for_each(|i| cnt[i] = 0);
        (0..n).for_each(|i| cnt[r[i]] += 1);
        (1..m).for_each(|i| cnt[i] += cnt[i - 1]);
        for i in (0..n).rev() {
            cnt[r[idx[i]]] -= 1;
            sa[cnt[r[idx[i]]]] = idx[i];
        }

        nr[sa[0]] = 1;
        for i in 1..n {
            nr[sa[i]] = nr[sa[i - 1]] + if key!(sa[i - 1]) < key!(sa[i]) { 1 } else { 0 };
        }
        std::mem::swap(&mut r, &mut nr);

        if r[sa[n - 1]] == n {
            break;
        }
    }

    sa
}

fn sa_lcp<T: Ord>(arr: &[T]) -> (Vec<usize>, Vec<usize>) {
    let n = arr.len();
    let sa = suffix_array(arr);
    let mut lcp: Vec<usize> = vec![0; n];
    let mut isa: Vec<usize> = vec![0; n];
    for i in 0..n {
        isa[sa[i]] = i;
    }
    let mut k = 0;
    for i in 0..n {
        if isa[i] != 0 {
            let j = sa[isa[i] - 1];
            while i + k < n && j + k < n && arr[i + k] == arr[j + k] {
                k += 1;
            }
            lcp[isa[i]] = k;
            // Kasai's invariant: the next suffix shares at least k - 1 characters with its predecessor.
            k = k.saturating_sub(1);
        }
    }
    (sa, lcp)
}

2D Geometry Base

Code

type I = i64;
type P = [I; 2];
type L = [P; 2];

fn scale(s: I, a: P) -> P { a.map(|x| x * s) }
fn add(a: P, b: P) -> P { [a[0] + b[0], a[1] + b[1]] }
fn sub(a: P, b: P) -> P { [a[0] - b[0], a[1] - b[1]] }
fn dot(a: P, b: P) -> I { a[0] * b[0] + a[1] * b[1] }
fn cross(a: P, b: P) -> I { a[0] * b[1] - a[1] * b[0] }
fn ccw(a: P, b: P, c: P) -> I { cross(sub(b, a), sub(c, b)) }
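The following snippets rely on the sign of ccw. With these definitions it is positive when the turn a, b, c is counter-clockwise (a left turn), negative for a clockwise turn, and zero for collinear points. This sketch checks that convention. It also adds a hypothetical twice_area helper (the shoelace formula built on cross), which is positive for counter-clockwise polygons. Only the needed definitions are repeated.

```rust
type I = i64;
type P = [I; 2];

fn sub(a: P, b: P) -> P { [a[0] - b[0], a[1] - b[1]] }
fn cross(a: P, b: P) -> I { a[0] * b[1] - a[1] * b[0] }
fn ccw(a: P, b: P, c: P) -> I { cross(sub(b, a), sub(c, b)) }

/// Hypothetical helper: twice the signed area of a simple polygon (shoelace formula).
/// Positive when the vertices are listed counter-clockwise.
fn twice_area(poly: &[P]) -> I {
    let n = poly.len();
    (0..n).map(|i| cross(poly[i], poly[(i + 1) % n])).sum()
}
```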

Last modified on 231225.

Angle Comparator

Code

use std::cmp::Ordering;

/// Returns a comparator for points relative to `x_axis`: first by CCW angle, then by distance from the origin for equal angles.
/// The origin is considered the smallest.
fn angle_cmp(x_axis: P) -> impl Fn(&P, &P) -> Ordering {
	move |&a, &b| {
		let ud = |p| cross(p, x_axis) > 0 || (cross(p, x_axis) == 0 && dot(p, x_axis) < 0);
		ud(a).cmp(&ud(b)).then_with(|| cross(b, a).cmp(&0).then_with(|| dot(a, a).cmp(&dot(b, b))))
	}
}
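Here is a usage sketch that assumes the definitions above. Sorting with angle_cmp([1, 0]) starts at the positive x-axis and proceeds counter-clockwise. Ties are broken by distance, and the origin comes first. To sort around a center other than the origin, subtract the center from each point before comparing.

```rust
use std::cmp::Ordering;

type I = i64;
type P = [I; 2];

fn dot(a: P, b: P) -> I { a[0] * b[0] + a[1] * b[1] }
fn cross(a: P, b: P) -> I { a[0] * b[1] - a[1] * b[0] }

/// Same comparator as above.
fn angle_cmp(x_axis: P) -> impl Fn(&P, &P) -> Ordering {
    move |&a, &b| {
        // `ud` is true for angles in [pi, 2*pi), i.e. the "lower" half.
        let ud = |p| cross(p, x_axis) > 0 || (cross(p, x_axis) == 0 && dot(p, x_axis) < 0);
        ud(a).cmp(&ud(b)).then_with(|| cross(b, a).cmp(&0).then_with(|| dot(a, a).cmp(&dot(b, b))))
    }
}
```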

Last modified on 231225.

Convex Hull

convex_hull finds the convex hull of a given array of points. It is based on the monotone chain algorithm, which is much more intuitive and straightforward than the well-known Graham scan.

If COLLINEAR is set to false, every point that lies on an edge of the convex hull but is not one of its vertices is excluded. If it is set to true, those points are all included.

The result is sorted in counter-clockwise order, starting from the lexicographically smallest point.

If the input includes duplicates and those happen to be on the convex hull, then only one of them is included in the result.

Code

/// Returns a convex hull of a set of 2D points `arr` in CCW order.
/// Set `COLLINEAR` to `true` to include, or `false` to exclude, collinear edge points.
fn convex_hull<const COLLINEAR: bool>(arr: &[P]) -> Vec<P> {
	let mut arr = arr.to_vec();
	arr.sort_unstable();
	arr.dedup();
	if arr.len() <= 1 {
		return arr;
	}
	let mut ret = vec![];

	fn monotone<const COLLINEAR: bool>(it: impl Iterator<Item = P>) -> Vec<P> {
		let mut dl = vec![];
		for p in it {
			while dl.len() >= 2 {
				let n = dl.len();
				let v = ccw(dl[n - 2], dl[n - 1], p);
				if v < 0 || (!COLLINEAR && v == 0) {
					dl.pop();
				} else {
					break;
				}
			}
			dl.push(p);
		}
		dl
	}

	ret.extend(monotone::<COLLINEAR>(arr.iter().copied()));
	ret.pop();
	ret.extend(monotone::<COLLINEAR>(arr.iter().copied().rev()));
	ret.pop();

	ret
}
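This small usage sketch exercises both modes on a 2 × 2 square. The input also contains an interior point, a point on the bottom edge, and a duplicated corner. The definitions are repeated so the block compiles alone. Following the description above, the duplicate is dropped and the interior point never appears. The edge point [1, 0] appears only when COLLINEAR is true.

```rust
type I = i64;
type P = [I; 2];

fn sub(a: P, b: P) -> P { [a[0] - b[0], a[1] - b[1]] }
fn cross(a: P, b: P) -> I { a[0] * b[1] - a[1] * b[0] }
fn ccw(a: P, b: P, c: P) -> I { cross(sub(b, a), sub(c, b)) }

/// Same `convex_hull` as above: CCW order, starting from the smallest point.
fn convex_hull<const COLLINEAR: bool>(arr: &[P]) -> Vec<P> {
    let mut arr = arr.to_vec();
    arr.sort_unstable();
    arr.dedup();
    if arr.len() <= 1 {
        return arr;
    }
    let mut ret = vec![];

    // Keeps only left turns (and straight steps if COLLINEAR).
    fn monotone<const COLLINEAR: bool>(it: impl Iterator<Item = P>) -> Vec<P> {
        let mut dl: Vec<P> = vec![];
        for p in it {
            while dl.len() >= 2 {
                let n = dl.len();
                let v = ccw(dl[n - 2], dl[n - 1], p);
                if v < 0 || (!COLLINEAR && v == 0) {
                    dl.pop();
                } else {
                    break;
                }
            }
            dl.push(p);
        }
        dl
    }

    // Lower hull left to right, then upper hull right to left.
    ret.extend(monotone::<COLLINEAR>(arr.iter().copied()));
    ret.pop();
    ret.extend(monotone::<COLLINEAR>(arr.iter().copied().rev()));
    ret.pop();
    ret
}
```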

Last modified on 231225.

Line Intersection

Code

type Frac = [i128; 2];
/// Calculates the intersection of line segments `p` and `q`.
/// Returns `None` for no intersection.
/// Returns `Some(Ok([x, y]))` for a point intersection, where `x` and `y` are each a fraction `[numerator, denominator]`. The fractions are not reduced, and the denominator may be negative.
/// Returns `Some(Err([x, y]))` for an overlapping line segment, where `[x, y]` represents endpoints of the overlap.
///
/// Note: intended for coordinates that fit in `i32`; coordinates beyond that range may overflow.
fn intersect(p: L, q: L) -> Option<Result<[Frac; 2], L>> {
	use std::cmp::Ordering::*;
	let u = cross(sub(p[1], p[0]), sub(q[1], q[0]));
	let sn = cross(sub(q[0], p[0]), sub(q[1], q[0]));
	let tn = cross(sub(q[0], p[0]), sub(p[1], p[0]));
	if u != 0 {
		let int = if u >= 0 { 0..=u } else { u..=0 };
		if int.contains(&sn) && int.contains(&tn) {
			let [s, r] = [sn, u - sn].map(|x| x as i128);
			let [g, h] = p.map(|f| f.map(|x| x as i128));
			let [x, y] = [0, 1].map(|i| [r * g[i] + s * h[i], u as i128]);
			Some(Ok([x, y]))
		} else {
			None
		}
	} else {
		if sn != 0 || tn != 0 {
			return None;
		}
		let (a0, a1) = (p[0].min(p[1]), p[0].max(p[1]));
		let (b0, b1) = (q[0].min(q[1]), q[0].max(q[1]));
		let (l, r) = (a0.max(b0), a1.min(b1));
		match l.cmp(&r) {
			Less => Some(Err([l, r])),
			Equal => Some(Ok(l.map(|x| [x.into(), 1]))),
			Greater => None,
		}
	}
}

The code below only checks whether two line segments meet, without calculating where they do.

/// Checks if line segments `p` and `q` intersect.
/// Returns `true` if they intersect at any point, `false` otherwise.
fn meets(p: L, q: L) -> bool {
	let u = cross(sub(p[1], p[0]), sub(q[1], q[0]));
	let sn = cross(sub(q[0], p[0]), sub(q[1], q[0]));
	let tn = cross(sub(q[0], p[0]), sub(p[1], p[0]));
	if u != 0 {
		let int = if u >= 0 { 0..=u } else { u..=0 };
		int.contains(&sn) && int.contains(&tn)
	} else {
		if sn != 0 || tn != 0 {
			return false;
		}
		let (a0, a1) = (p[0].min(p[1]), p[0].max(p[1]));
		let (b0, b1) = (q[0].min(q[1]), q[0].max(q[1]));
		let (l, r) = (a0.max(b0), a1.min(b1));
		l <= r
	}
}
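The point-intersection case of intersect returns fractions exactly as computed, so they are neither reduced nor normalized. For the two diagonals of a 2 × 2 square, the crossing point (1, 1) comes back as [-8, -8] for both coordinates. The sketch below repeats intersect. It adds a hypothetical reduce helper that divides by the gcd and makes the denominator positive. It also exercises the overlap, touching, and parallel cases.

```rust
type I = i64;
type P = [I; 2];
type L = [P; 2];
type Frac = [i128; 2];

fn sub(a: P, b: P) -> P { [a[0] - b[0], a[1] - b[1]] }
fn cross(a: P, b: P) -> I { a[0] * b[1] - a[1] * b[0] }

/// Same `intersect` as above.
fn intersect(p: L, q: L) -> Option<Result<[Frac; 2], L>> {
    use std::cmp::Ordering::*;
    let u = cross(sub(p[1], p[0]), sub(q[1], q[0]));
    let sn = cross(sub(q[0], p[0]), sub(q[1], q[0]));
    let tn = cross(sub(q[0], p[0]), sub(p[1], p[0]));
    if u != 0 {
        let int = if u >= 0 { 0..=u } else { u..=0 };
        if int.contains(&sn) && int.contains(&tn) {
            let [s, r] = [sn, u - sn].map(|x| x as i128);
            let [g, h] = p.map(|f| f.map(|x| x as i128));
            let [x, y] = [0, 1].map(|i| [r * g[i] + s * h[i], u as i128]);
            Some(Ok([x, y]))
        } else {
            None
        }
    } else {
        if sn != 0 || tn != 0 {
            return None;
        }
        let (a0, a1) = (p[0].min(p[1]), p[0].max(p[1]));
        let (b0, b1) = (q[0].min(q[1]), q[0].max(q[1]));
        let (l, r) = (a0.max(b0), a1.min(b1));
        match l.cmp(&r) {
            Less => Some(Err([l, r])),
            Equal => Some(Ok(l.map(|x| [x.into(), 1]))),
            Greater => None,
        }
    }
}

/// Hypothetical helper: reduces `[numerator, denominator]` (denominator non-zero)
/// to lowest terms with a positive denominator.
fn reduce([num, den]: Frac) -> Frac {
    fn gcd(a: i128, b: i128) -> i128 {
        if b == 0 { a.abs() } else { gcd(b, a % b) }
    }
    let g = gcd(num, den); // positive because den != 0
    let sign = if den < 0 { -1 } else { 1 };
    [sign * num / g, sign * den / g]
}
```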

Last modified on 240119.

Closest Points

Code

/// Returns the indices of the two closest points in `p`, or `None` if `p.len() <= 1`.
fn closest_pair(p: &[P]) -> Option<[usize; 2]> {
	let mut u: Vec<usize> = (0..p.len()).collect();
	u.sort_unstable_by_key(|&i| p[i]);
	closest_pair_recur(p, &mut u).map(|x| x.1)
}

/// Sorts the indices in `u` by the y coordinate of the points they refer to, and returns the squared distance of the closest pair together with its indices.
fn closest_pair_recur(p: &[P], u: &mut [usize]) -> Option<(I, [usize; 2])> {
	if u.len() <= 1 {
		return None;
	}
	let dist = |a: P, b: P| dot(sub(a, b), sub(a, b));

	// Divide
	let m = u.len() / 2;
	let pivot = p[u[m]][0];
	let (l, r) = u.split_at_mut(m);
	let [xl, xr] = [l, r].map(|x| closest_pair_recur(p, x));
	let mut d = [xl, xr].into_iter().flatten().min_by_key(|x| x.0);
	// Now l and r are sorted by y coords

	// Merge two lists into one while keeping the y coords sorted
	let (l, r) = u.split_at(m);
	let [mut lit, mut rit] = [l, r].map(|x| x.to_vec().into_iter().peekable());
	let mut x = u.iter_mut();
	while let (Some(&l), Some(&r)) = (lit.peek(), rit.peek()) {
		*(x.next().unwrap()) = if p[l][1] <= p[r][1] {
			lit.next();
			l
		} else {
			rit.next();
			r
		};
	}
	lit.for_each(|l| *(x.next().unwrap()) = l);
	rit.for_each(|r| *(x.next().unwrap()) = r);
	// Now u is sorted by y coords

	// Conquer
	let mut hist: [Option<usize>; 3] = [None; 3];
	let mut ptr = (0..3).cycle();
	for &i in u.iter() {
		if d.map_or(true, |(d, _)| (p[i][0] - pivot).pow(2) < d) {
			let j = hist.iter().flatten().min_by_key(|&&j| dist(p[i], p[j]));
			if let Some(&j) = j {
				let nd = dist(p[i], p[j]);
				if d.map_or(true, |(d, _)| nd < d) {
					d = Some((nd, [i, j]));
				}
			}
			hist[ptr.next().unwrap()] = Some(i);
		}
	}

	d
}
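A brute-force \(O(n^2)\) scan is a handy reference for cross-checking closest_pair on small random inputs. The hypothetical helper below returns only the minimum squared distance, because with ties the two functions may legitimately return different pairs. Compare it against the squared distance of the pair that closest_pair returns.

```rust
type I = i64;
type P = [I; 2];

/// Hypothetical O(n^2) reference: smallest squared distance between two
/// distinct indices, or `None` if `p.len() <= 1`.
fn closest_dist2_naive(p: &[P]) -> Option<I> {
    let dist = |a: P, b: P| (a[0] - b[0]).pow(2) + (a[1] - b[1]).pow(2);
    (0..p.len())
        .flat_map(|i| (i + 1..p.len()).map(move |j| (i, j)))
        .map(|(i, j)| dist(p[i], p[j]))
        .min()
}
```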

Last modified on 231225.

Farthest Points

Code

/// Returns two points of `p` that are farthest apart, or `None` if `p` is empty.
fn farthest_points(p: &[P]) -> Option<[P; 2]> {
	let dist = |a: P, b: P| dot(sub(a, b), sub(a, b));

	let cvh = convex_hull::<false>(p);
	if cvh.len() == 1 {
		return Some([cvh[0], cvh[0]]);
	} else if cvh.len() == 2 {
		return Some([cvh[0], cvh[1]]);
	}

	fn line_iterator(poly: &[P]) -> impl Iterator<Item = [P; 2]> + '_ {
		let mut it = poly.iter().copied().cycle().peekable();
		std::iter::from_fn(move || {
			let u = it.next().unwrap();
			let v = *it.peek().unwrap();
			Some([u, v])
		})
	}

	let mut r = line_iterator(&cvh).skip(1).peekable();
	let mut rec = None;
	for [la, lb] in line_iterator(&cvh).take(cvh.len()) {
		let lv = sub(lb, la);
		while {
			let &[ra, rb] = r.peek().unwrap();
			cross(lv, sub(rb, ra)) > 0
		} {
			r.next();
		}
		let &[ra, _] = r.peek().unwrap();
		let ch = [la, ra];
		let d = dist(ch[0], ch[1]);
		if rec.map_or(true, |(x, _)| x < d) {
			rec = Some((d, ch));
		}
	}

	rec.map(|x| x.1)
}
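A brute-force reference works the same way for farthest_points. The hypothetical helper below returns the maximum squared distance, mirroring the contract above: 0 for a single point and None for an empty input. On small random inputs, compare it against the squared distance between the two returned points.

```rust
type I = i64;
type P = [I; 2];

/// Hypothetical O(n^2) reference: largest squared distance between two points
/// of `p` (a point paired with itself counts), or `None` if `p` is empty.
fn farthest_dist2_naive(p: &[P]) -> Option<I> {
    let dist = |a: P, b: P| (a[0] - b[0]).pow(2) + (a[1] - b[1]).pow(2);
    (0..p.len())
        .flat_map(|i| (i..p.len()).map(move |j| (i, j)))
        .map(|(i, j)| dist(p[i], p[j]))
        .max()
}
```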

Last modified on 231225.

Point in a Polygon

Code

/// Checks if point `p` is within the polygon `poly`.
/// Returns `None` for points on edges, `Some(true)` for inside, and `Some(false)` for outside.
/// Requires `poly` to be a non-self-intersecting polygon. Orientation does not matter.
fn is_inside(p: P, poly: &[P]) -> Option<bool> {
	let n = poly.len();
	let mut it = poly.iter().copied().cycle().peekable();
	let mut nxt = || [it.next().unwrap(), *it.peek().unwrap()];

	for l in (0..n).map(|_| nxt()) {
		if meets(l, [p, p]) {
			return None;
		}
	}

	let cnt = (0..n)
		.map(|_| nxt())
		.filter(|&l| {
			let half = (l[0][1] < p[1]) != (l[1][1] < p[1]);
			let touch = meets(l, [p, [p[0].max(l[0][0]).max(l[1][0]), p[1]]]);
			half && touch
		})
		.count();
	Some(cnt % 2 == 1)
}
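This usage sketch runs is_inside on a square and on a concave L-shaped polygon. It repeats is_inside and the meets function it depends on. Points on an edge give None. The point in the notch of the L is reported as outside, even though a ray cast from it crosses the polygon's bounding box.

```rust
type I = i64;
type P = [I; 2];
type L = [P; 2];

fn sub(a: P, b: P) -> P { [a[0] - b[0], a[1] - b[1]] }
fn cross(a: P, b: P) -> I { a[0] * b[1] - a[1] * b[0] }

/// Same `meets` as in the Line Intersection section.
fn meets(p: L, q: L) -> bool {
    let u = cross(sub(p[1], p[0]), sub(q[1], q[0]));
    let sn = cross(sub(q[0], p[0]), sub(q[1], q[0]));
    let tn = cross(sub(q[0], p[0]), sub(p[1], p[0]));
    if u != 0 {
        let int = if u >= 0 { 0..=u } else { u..=0 };
        int.contains(&sn) && int.contains(&tn)
    } else {
        if sn != 0 || tn != 0 {
            return false;
        }
        let (a0, a1) = (p[0].min(p[1]), p[0].max(p[1]));
        let (b0, b1) = (q[0].min(q[1]), q[0].max(q[1]));
        let (l, r) = (a0.max(b0), a1.min(b1));
        l <= r
    }
}

/// Same `is_inside` as above: edge check first, then a ray cast to the right.
fn is_inside(p: P, poly: &[P]) -> Option<bool> {
    let n = poly.len();
    let mut it = poly.iter().copied().cycle().peekable();
    let mut nxt = || [it.next().unwrap(), *it.peek().unwrap()];
    for l in (0..n).map(|_| nxt()) {
        if meets(l, [p, p]) {
            return None;
        }
    }
    let cnt = (0..n)
        .map(|_| nxt())
        .filter(|&l| {
            let half = (l[0][1] < p[1]) != (l[1][1] < p[1]);
            let touch = meets(l, [p, [p[0].max(l[0][0]).max(l[1][0]), p[1]]]);
            half && touch
        })
        .count();
    Some(cnt % 2 == 1)
}
```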

Last modified on 231225.

Point in a Convex Polygon

Code

/// Checks if point `p` is within the convex polygon `poly`.
/// Returns `None` for points on edges, `Some(true)` for inside, `Some(false)` for outside.
/// Requires `poly` to be a convex polygon oriented in CCW.
fn is_inside_convex(p: P, poly: &[P]) -> Option<bool> {
	use std::cmp::Ordering::*;
	if poly.len() == 1 {
		return if poly[0] == p { None } else { Some(false) };
	}
	let cmp = |a, b| angle_cmp(sub(poly[1], poly[0]))(&sub(a, poly[0]), &sub(b, poly[0]));

	let i = poly.partition_point(|&c| cmp(p, c) != Less);
	if i == poly.len() {
		if p == poly[i - 1] {
			None
		} else {
			Some(false)
		}
	} else if i == poly.len() - 1 && cross(sub(poly[i], poly[0]), sub(p, poly[0])) == 0 {
		// `p` lies on the closing edge between `poly[n - 1]` and `poly[0]`
		None
	} else {
		match ccw(poly[i - 1], p, poly[i]).cmp(&0) {
			Less => Some(true),
			Equal => None,
			Greater => Some(false),
		}
	}
}
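This usage sketch runs is_inside_convex on a counter-clockwise square. The copy includes an explicit check for the closing edge. A point strictly inside the edge from poly[n - 1] back to poly[0] sorts just before poly[n - 1] by angle, so it lands in the i == n - 1 case. There, a plain ccw test against poly[n - 2] and poly[n - 1] would report it as inside. The extra branch returns None for such a point, matching the documented contract for points on edges.

```rust
use std::cmp::Ordering;

type I = i64;
type P = [I; 2];

fn sub(a: P, b: P) -> P { [a[0] - b[0], a[1] - b[1]] }
fn dot(a: P, b: P) -> I { a[0] * b[0] + a[1] * b[1] }
fn cross(a: P, b: P) -> I { a[0] * b[1] - a[1] * b[0] }
fn ccw(a: P, b: P, c: P) -> I { cross(sub(b, a), sub(c, b)) }

fn angle_cmp(x_axis: P) -> impl Fn(&P, &P) -> Ordering {
    move |&a, &b| {
        let ud = |p| cross(p, x_axis) > 0 || (cross(p, x_axis) == 0 && dot(p, x_axis) < 0);
        ud(a).cmp(&ud(b)).then_with(|| cross(b, a).cmp(&0).then_with(|| dot(a, a).cmp(&dot(b, b))))
    }
}

fn is_inside_convex(p: P, poly: &[P]) -> Option<bool> {
    use std::cmp::Ordering::*;
    if poly.len() == 1 {
        return if poly[0] == p { None } else { Some(false) };
    }
    let cmp = |a, b| angle_cmp(sub(poly[1], poly[0]))(&sub(a, poly[0]), &sub(b, poly[0]));
    // First vertex whose angle (seen from poly[0]) is larger than p's.
    let i = poly.partition_point(|&c| cmp(p, c) != Less);
    if i == poly.len() {
        if p == poly[i - 1] { None } else { Some(false) }
    } else if i == poly.len() - 1 && cross(sub(poly[i], poly[0]), sub(p, poly[0])) == 0 {
        // `p` lies on the closing edge between `poly[n - 1]` and `poly[0]`.
        None
    } else {
        match ccw(poly[i - 1], p, poly[i]).cmp(&0) {
            Less => Some(true),
            Equal => None,
            Greater => Some(false),
        }
    }
}
```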

Last modified on 231225.

Minkowski Sum of Convex Polygons

Code

/// Returns the Minkowski sum of two convex polygons `p` and `q`.
/// Both `p` and `q` must be oriented in CCW and should not contain two equal points.
/// The returned polygon is also oriented in CCW.
fn minkowski_convex(p: &[P], q: &[P]) -> Vec<P> {
	fn line_iterator(poly: &[P]) -> impl Iterator<Item = [P; 2]> + '_ {
		let mut it = poly.iter().copied().cycle().peekable();
		std::iter::from_fn(move || {
			let u = it.next().unwrap();
			let v = *it.peek().unwrap();
			Some([u, v])
		})
	}

	let pi = (0..p.len()).min_by_key(|&i| p[i]).unwrap();
	let qi = (0..q.len()).min_by_key(|&i| q[i]).unwrap();
	let mut pl = line_iterator(p).skip(pi).enumerate().peekable();
	let mut ql = line_iterator(q).skip(qi).enumerate().peekable();

	let mut ret = vec![];
	while let (Some(&(pc, pp)), Some(&(qc, qq))) = (pl.peek(), ql.peek()) {
		if pc >= p.len() && qc >= q.len() {
			break;
		}
		ret.push(add(pp[0], qq[0]));
		let pcq = cross(sub(pp[1], pp[0]), sub(qq[1], qq[0]));
		if pcq >= 0 {
			pl.next();
		}
		if pcq <= 0 {
			ql.next();
		}
	}
	ret
}
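This worked example was traced by hand through the algorithm above. It sums the unit square and the right triangle with legs of length 2. The two parallel edge pairs (both bottom edges and both left edges) merge, so the result has 5 vertices rather than 4 + 3. They are listed counter-clockwise, starting from the lexicographically smallest point. The block repeats minkowski_convex with std::iter::from_fn spelled out so it compiles alone.

```rust
type I = i64;
type P = [I; 2];

fn add(a: P, b: P) -> P { [a[0] + b[0], a[1] + b[1]] }
fn sub(a: P, b: P) -> P { [a[0] - b[0], a[1] - b[1]] }
fn cross(a: P, b: P) -> I { a[0] * b[1] - a[1] * b[0] }

/// Same `minkowski_convex` as above.
fn minkowski_convex(p: &[P], q: &[P]) -> Vec<P> {
    // Endless iterator over the edges [u, v] of a polygon.
    fn line_iterator(poly: &[P]) -> impl Iterator<Item = [P; 2]> + '_ {
        let mut it = poly.iter().copied().cycle().peekable();
        std::iter::from_fn(move || {
            let u = it.next().unwrap();
            let v = *it.peek().unwrap();
            Some([u, v])
        })
    }

    let pi = (0..p.len()).min_by_key(|&i| p[i]).unwrap();
    let qi = (0..q.len()).min_by_key(|&i| q[i]).unwrap();
    let mut pl = line_iterator(p).skip(pi).enumerate().peekable();
    let mut ql = line_iterator(q).skip(qi).enumerate().peekable();

    let mut ret = vec![];
    while let (Some(&(pc, pp)), Some(&(qc, qq))) = (pl.peek(), ql.peek()) {
        if pc >= p.len() && qc >= q.len() {
            break;
        }
        ret.push(add(pp[0], qq[0]));
        // Advance whichever polygon's current edge has the smaller angle (both if parallel).
        let pcq = cross(sub(pp[1], pp[0]), sub(qq[1], qq[0]));
        if pcq >= 0 {
            pl.next();
        }
        if pcq <= 0 {
            ql.next();
        }
    }
    ret
}
```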

Last modified on 231225.

Value Compression

Example

fn main() {
    let arr: Vec<i32> = vec![1, 2, 4, 7, 9, 7, 4, 2, 1];

    let (compressor, reevaluator) = compress_value(&arr);

    let compr: Vec<usize> = arr.iter().map(|x| *compressor.get(x).unwrap()).collect();
    println!("{:?}", compr);    // [0, 1, 2, 3, 4, 3, 2, 1, 0]

    let original: Vec<i32> = compr.iter().map(|&i| *reevaluator[i]).collect();
    println!("{:?}", original); // [1, 2, 4, 7, 9, 7, 4, 2, 1]
}

Code

/// compressor[original_value] = compressed_value
/// reevaluator[compressed_value] = original_value
fn compress_value<T: Ord>(arr: &[T]) -> (std::collections::BTreeMap<&T, usize>, Vec<&T>) {
    use std::collections::*;
    let compressor: BTreeMap<&T, usize> = {
        let mut sorted: Vec<_> = arr.iter().collect();
        sorted.sort_unstable();
        sorted.dedup();
        sorted.into_iter().enumerate().map(|x| (x.1, x.0)).collect()
    };
    let reevaluator: Vec<&T> = compressor.iter().map(|x| *x.0).collect();
    (compressor, reevaluator)
}
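The compressed value of x is just its position among the sorted distinct values. So the same result can also be obtained without a BTreeMap: sort, deduplicate, and binary_search each element. This is a sketch of that approach. compress_sorted is a hypothetical name, and unlike compress_value it returns owned values rather than references.

```rust
/// Hypothetical alternative to `compress_value`: returns the compressed array
/// and the sorted distinct values (so `vals[compr[i]] == arr[i]`).
fn compress_sorted<T: Ord + Clone>(arr: &[T]) -> (Vec<usize>, Vec<T>) {
    let mut vals: Vec<T> = arr.to_vec();
    vals.sort_unstable();
    vals.dedup();
    // Every element of `arr` is in `vals`, so `binary_search` always returns `Ok`.
    let compr: Vec<usize> = arr.iter().map(|x| vals.binary_search(x).unwrap()).collect();
    (compr, vals)
}
```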

Longest Increasing Subsequence

Length

fn lis_len(arr: &[i64]) -> usize {
    // table[k] is the smallest tail of a strictly increasing subsequence of length k + 1
    let mut table: Vec<i64> = Vec::new();
    for &v in arr {
        let p = table.partition_point(|&x| x < v);
        if p == table.len() {
            table.push(v);
        } else {
            table[p] = v;
        }
    }
    table.len()
}

Sequence

fn lis(arr: &[i64]) -> Vec<i64> {
    let n = arr.len();
    let mut seq: Vec<i64> = Vec::with_capacity(n + 1);
    seq.push(i64::MIN);
    seq.extend(arr.iter().copied());

    let mut back = vec![0usize; n + 1];
    let mut table = vec![0usize];

    for (i, &v) in seq.iter().enumerate().skip(1) {
        let p = table.partition_point(|&x| seq[x] < v);
        if p == table.len() {
            table.push(i);
        } else {
            table[p] = i;
        }
        back[i] = table[p - 1];
    }

    let mut ptr = *table.last().unwrap();
    let mut ans: Vec<i64> = Vec::with_capacity(table.len() - 1);
    while ptr != 0 {
        ans.push(seq[ptr]);
        ptr = back[ptr];
    }

    ans.reverse();
    ans
}
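The classic \(O(n^2)\) dynamic program is a simple reference for checking either function on small inputs. The hypothetical lis_len_quadratic below also counts strictly increasing subsequences, matching the x < v predicate that lis_len and lis pass to partition_point.

```rust
/// Hypothetical O(n^2) reference: `best[i]` is the length of the longest
/// strictly increasing subsequence ending at `arr[i]`.
fn lis_len_quadratic(arr: &[i64]) -> usize {
    let mut best = vec![1usize; arr.len()];
    for i in 0..arr.len() {
        for j in 0..i {
            if arr[j] < arr[i] {
                best[i] = best[i].max(best[j] + 1);
            }
        }
    }
    best.into_iter().max().unwrap_or(0)
}
```

If non-decreasing subsequences are wanted instead, the predicate can be changed to x <= v (seq[x] <= v in lis). partition_point then acts as an upper bound, and the quadratic reference would use arr[j] <= arr[i].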

Mo's

Mo's with Hilbert Curve Optimization

Reference: https://codeforces.com/blog/entry/61203

/// max_n: maximum value of l and r
/// queries: Vec<(id, l, r)>
fn mos_sort(max_n: u32, queries: &[(u32, u32, u32)]) -> Vec<&(u32, u32, u32)> {
    let n_bit = ceil_pow_2(max_n + 1).trailing_zeros();
    let mut arr: Vec<(u64, &(u32, u32, u32))> = queries.iter().map(|q| (0, q)).collect();
    for q in arr.iter_mut() {
        q.0 = hilbert_order(q.1 .1, q.1 .2, n_bit, 0);
    }
    arr.sort_unstable_by_key(|q| q.0);
    arr.into_iter().map(|x| x.1).collect()
}

#[inline(always)]
fn hilbert_order(x: u32, y: u32, pow: u32, rotate: u32) -> u64 {
    if pow == 0 {
        return 0;
    }
    let hpow: u32 = 1 << (pow - 1);
    let mut seg: u32 = if x < hpow {
        if y < hpow {
            0
        } else {
            3
        }
    } else {
        if y < hpow {
            1
        } else {
            2
        }
    };
    seg = (seg + rotate) & 3;

    let (nx, ny) = (x & (x ^ hpow), y & (y ^ hpow));
    let nrot = (rotate + ROTATE_DELTA[seg as usize]) & 3;
    let sub_square_size = 1u64 << (2 * pow - 2);
    let ans = seg as u64 * sub_square_size;
    let add = hilbert_order(nx, ny, pow - 1, nrot);
    if seg == 1 || seg == 2 {
        ans + add
    } else {
        ans + sub_square_size - add - 1
    }
}

const ROTATE_DELTA: [u32; 4] = [3, 0, 0, 1];

#[inline(always)]
fn ceil_pow_2(y: u32) -> u32 {
    let mut x = y;
    while x != (x & ((!x) + 1)) {
        x -= x & ((!x) + 1);
    }
    if x == y {
        x
    } else {
        x << 1
    }
}
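This sketch shows the order mos_sort uses. It lists the cells of a 4 × 4 grid (pow = 2) sorted by hilbert_order. The recursion is the same as above, written with a match and with explicit parentheses in the nrot line (Rust's + already binds tighter than &). Consecutive cells are grid neighbours. That is why queries that are close in (l, r) tend to be processed one after another. hilbert_cells is a hypothetical helper.

```rust
const ROTATE_DELTA: [u32; 4] = [3, 0, 0, 1];

/// Same recursion as `hilbert_order` above.
fn hilbert_order(x: u32, y: u32, pow: u32, rotate: u32) -> u64 {
    if pow == 0 {
        return 0;
    }
    let hpow: u32 = 1 << (pow - 1);
    // Quadrant index before rotation: 0 = bottom-left, 1 = bottom-right, 2 = top-right, 3 = top-left.
    let mut seg: u32 = match (x < hpow, y < hpow) {
        (true, true) => 0,
        (true, false) => 3,
        (false, true) => 1,
        (false, false) => 2,
    };
    seg = (seg + rotate) & 3;
    let (nx, ny) = (x & (x ^ hpow), y & (y ^ hpow));
    let nrot = (rotate + ROTATE_DELTA[seg as usize]) & 3;
    let sub_square_size = 1u64 << (2 * pow - 2);
    let ans = seg as u64 * sub_square_size;
    let add = hilbert_order(nx, ny, pow - 1, nrot);
    if seg == 1 || seg == 2 {
        ans + add
    } else {
        ans + sub_square_size - add - 1
    }
}

/// Hypothetical helper: every cell of a `2^pow x 2^pow` grid, sorted by `hilbert_order`.
fn hilbert_cells(pow: u32) -> Vec<(u32, u32)> {
    let side = 1u32 << pow;
    let mut cells: Vec<(u32, u32)> = (0..side).flat_map(|x| (0..side).map(move |y| (x, y))).collect();
    cells.sort_by_key(|&(x, y)| hilbert_order(x, y, pow, 0));
    cells
}
```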

Standard Mo's

/// queries: Vec<(id, l, r)>
fn mos_sort(queries: &mut [(u32, u32, u32)]) {
    let nsq = isqrt(queries.len() as u32);
    queries.sort_unstable_by(|&(_, l1, r1), &(_, l2, r2)| {
        if l1 / nsq == l2 / nsq {
            r1.cmp(&r2)
        } else {
            (l1 / nsq).cmp(&(l2 / nsq))
        }
    });
}

fn isqrt(s: u32) -> u32 {
    let mut x0 = s / 2;
    if x0 != 0 {
        let mut x1 = (x0 + s / x0) / 2;
        while x1 < x0 {
            x0 = x1;
            x1 = (x0 + s / x0) / 2;
        }
        x0
    } else {
        s
    }
}
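mos_sort only reorders the queries; the answers come from moving two pointers over the array. This is a minimal sketch of that driver for the classic "number of distinct values in a range" query. It assumes inclusive 0-based bounds, ids in 0..queries.len(), and non-negative values small enough to index a counting array. distinct_counts is a hypothetical name. The window is grown before it is shrunk, so it never becomes inverted.

```rust
/// Same `mos_sort` and `isqrt` as above.
fn mos_sort(queries: &mut [(u32, u32, u32)]) {
    let nsq = isqrt(queries.len() as u32);
    queries.sort_unstable_by(|&(_, l1, r1), &(_, l2, r2)| {
        if l1 / nsq == l2 / nsq {
            r1.cmp(&r2)
        } else {
            (l1 / nsq).cmp(&(l2 / nsq))
        }
    });
}

fn isqrt(s: u32) -> u32 {
    let mut x0 = s / 2;
    if x0 != 0 {
        let mut x1 = (x0 + s / x0) / 2;
        while x1 < x0 {
            x0 = x1;
            x1 = (x0 + s / x0) / 2;
        }
        x0
    } else {
        s
    }
}

/// Hypothetical driver: for each query `(id, l, r)`, the number of distinct values in `arr[l..=r]`.
fn distinct_counts(arr: &[usize], queries: &[(u32, u32, u32)]) -> Vec<usize> {
    let mut qs = queries.to_vec();
    mos_sort(&mut qs);
    let mut cnt = vec![0usize; arr.iter().max().map_or(0, |&m| m + 1)];
    let mut distinct = 0;
    let mut ans = vec![0; queries.len()];
    // The current window is arr[lo..hi] (half-open), initially empty.
    let (mut lo, mut hi) = (0usize, 0usize);
    for &(id, l, r) in &qs {
        let (l, r) = (l as usize, r as usize + 1);
        while hi < r {
            cnt[arr[hi]] += 1;
            if cnt[arr[hi]] == 1 { distinct += 1; }
            hi += 1;
        }
        while lo > l {
            lo -= 1;
            cnt[arr[lo]] += 1;
            if cnt[arr[lo]] == 1 { distinct += 1; }
        }
        while hi > r {
            hi -= 1;
            cnt[arr[hi]] -= 1;
            if cnt[arr[hi]] == 0 { distinct -= 1; }
        }
        while lo < l {
            cnt[arr[lo]] -= 1;
            if cnt[arr[lo]] == 0 { distinct -= 1; }
            lo += 1;
        }
        ans[id as usize] = distinct;
    }
    ans
}
```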

Arbitrary-Precision Integer

Example

use bigint::*;
use std::str::FromStr;

fn main() {
    let _a = Int::from(0i8);
    let b = Int::from(4i16);
    let c = Int::from(11i32);

    let mut x = Int::from(30i64);
    x *= &b;
    println!("{}", x);

    let y = Int::from_str("123456789123456789123456789123456789").unwrap();
    let z = &y * &c;
    println!("{}", z);
}
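Uint stores a number as little-endian chunks of CHUNK = 5 decimal digits, i.e. in base TENS = 100000. flatten! keeps each chunk in [0, TENS) and drops leading zero chunks, so zero is an empty vector. The standalone sketch below mirrors only the string conversions done by FromStr and Display, so the layout can be inspected in isolation. to_chunks and from_chunks are hypothetical names, not part of the module.

```rust
/// Mirrors `CHUNK` and `TENS` in `bigint`.
const CHUNK: usize = 5;
const TENS: i64 = 100000;

/// Hypothetical helper: splits a decimal string (ASCII digits only) into
/// little-endian base-`TENS` chunks, the way `Uint::from_str` does.
fn to_chunks(s: &str) -> Vec<i64> {
    let mut s = s.trim_start_matches('0');
    let mut arr: Vec<i64> = vec![];
    while s.len() > CHUNK {
        let (l, r) = s.split_at(s.len() - CHUNK);
        arr.push(r.parse().unwrap());
        s = l;
    }
    if !s.is_empty() {
        arr.push(s.parse().unwrap());
    }
    debug_assert!(arr.iter().all(|&v| (0..TENS).contains(&v)));
    arr
}

/// Hypothetical helper: prints chunks back as decimal, zero-padding every chunk
/// except the most significant one, the way `Display for Uint` does.
fn from_chunks(arr: &[i64]) -> String {
    let mut out = arr.last().unwrap_or(&0).to_string();
    for &v in arr.iter().rev().skip(1) {
        out.push_str(&format!("{:0width$}", v, width = CHUNK));
    }
    out
}
```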

Code

mod bigint {
    use core::{
        fmt::Display,
        num::ParseIntError,
        ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
        str::FromStr,
    };

    const CHUNK: usize = 5;
    const TENS: i64 = 100000;

    #[derive(Clone, Default, Debug, PartialEq, Eq)]
    pub struct Uint(Vec<i64>);

    macro_rules! flatten {
        ($uint:expr) => {
            let mut carry: i64 = 0;
            for i in 0..$uint.0.len() {
                $uint.0[i] += carry;
                carry = $uint.0[i].div_euclid(TENS);
                $uint.0[i] -= carry * TENS;
            }
            while carry != 0 {
                $uint.0.push(carry.rem_euclid(TENS));
                carry = carry.div_euclid(TENS);
            }
            while let Some(&x) = $uint.0.last() {
                if x != 0 {
                    break;
                }
                $uint.0.pop();
            }
        };
    }

    macro_rules! impl_from_for_uint {
        ($($t:ty),*) => {
            $(
                impl From<$t> for Uint {
                    fn from(x: $t) -> Self {
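                        // Split in u128 so values above i64::MAX are not truncated by `as i64`.
                        let mut x = x as u128;
                        let mut chunks = vec![];
                        while x > 0 {
                            chunks.push((x % TENS as u128) as i64);
                            x /= TENS as u128;
                        }
                        Self(chunks)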
                    }
                }
            )*
        };
    }
    impl_from_for_uint!(u8, u16, u32, u64, u128, usize);

    impl FromStr for Uint {
        type Err = ParseIntError;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            let s = s.trim_start_matches("0");
            if s.is_empty() {
                return Ok(Self(vec![]));
            }
            let mut arr: Vec<i64> = Vec::with_capacity(s.len() / CHUNK + 2);
            let mut s = s;
            while s.len() > CHUNK {
                let (l, r) = s.split_at(s.len() - CHUNK);
                arr.push(r.parse()?);
                s = l;
            }
            arr.push(s.parse()?);
            Ok(Self(arr))
        }
    }

    impl Display for Uint {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            write!(f, "{}", *self.0.last().unwrap_or(&0))?;
            for &v in self.0.iter().rev().skip(1) {
                write!(f, "{:0CHUNK$}", v)?;
            }
            Ok(())
        }
    }

    impl PartialOrd for Uint {
        fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Ord for Uint {
        fn cmp(&self, other: &Self) -> core::cmp::Ordering {
            use core::cmp::Ordering;
            match self.0.len().cmp(&other.0.len()) {
                Ordering::Equal => {
                    for i in (0..self.0.len()).rev() {
                        let x = self.0[i].cmp(&other.0[i]);
                        if x != Ordering::Equal {
                            return x;
                        }
                    }
                    Ordering::Equal
                }
                x => x,
            }
        }
    }

    impl AddAssign<&Uint> for Uint {
        fn add_assign(&mut self, rhs: &Uint) {
            if self.0.len() < rhs.0.len() {
                for i in 0..self.0.len() {
                    self.0[i] += rhs.0[i];
                }
                self.0.extend_from_slice(&rhs.0[self.0.len()..]);
            } else {
                for i in 0..rhs.0.len() {
                    self.0[i] += rhs.0[i];
                }
            }

            flatten!(self);
        }
    }

    impl Add for &Uint {
        type Output = Uint;
        fn add(self, rhs: Self) -> Self::Output {
            let mut c = self.clone();
            c += rhs;
            c
        }
    }

    impl SubAssign<&Uint> for Uint {
        fn sub_assign(&mut self, rhs: &Uint) {
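            // Requires self >= rhs. Without this check an underflow would either panic on
            // indexing or make the carry loop in flatten! run forever.
            assert!(*self >= *rhs, "Uint subtraction underflow");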
            for (i, &v) in rhs.0.iter().enumerate() {
                self.0[i] -= v;
            }

            flatten!(self);
        }
    }

    impl Sub for &Uint {
        type Output = Uint;
        fn sub(self, rhs: Self) -> Self::Output {
            let mut c = self.clone();
            c -= rhs;
            c
        }
    }

    const NTT_THRES: usize = 5000;
    const KARAT_THRES: usize = 30;

    impl Mul for &Uint {
        type Output = Uint;
        fn mul(self, rhs: Self) -> Self::Output {
            let max_len = self.0.len().max(rhs.0.len());
            let max_2len = polymul::ceil_pow2(max_len);

            // For performance reasons regarding vector copying, we determine whether to use
            // NTT or not here.
            let mut ans = Uint(if max_2len > NTT_THRES {
                polymul::convolute(&self.0, &rhs.0)
            } else {
                let f: Vec<i64> = self
                    .0
                    .iter()
                    .copied()
                    .chain(core::iter::repeat(0))
                    .take(max_2len)
                    .collect();

                let g: Vec<i64> = rhs
                    .0
                    .iter()
                    .copied()
                    .chain(core::iter::repeat(0))
                    .take(max_2len)
                    .collect();

                polymul::mult_2pow(&f, &g)
            });

            flatten!(ans);
            ans
        }
    }

    impl MulAssign<&Uint> for Uint {
        fn mul_assign(&mut self, rhs: &Uint) {
            let x = &*self * rhs;
            *self = x;
        }
    }

    mod polymul {
        pub fn ceil_pow2(n: usize) -> usize {
            if n == 0 {
                return 0;
            }
            let mut m = n;
            while m != m & (!m + 1) {
                m -= m & (!m + 1);
            }
            if n == m {
                n
            } else {
                m * 2
            }
        }

        pub fn mult_2pow(f: &[i64], g: &[i64]) -> Vec<i64> {
            if f.len() > super::KARAT_THRES {
                return karatsuba(f, g);
            }

            let mut ans = vec![0; 2 * f.len()];
            for (i, &a) in f.iter().enumerate() {
                for (j, &b) in g.iter().enumerate() {
                    ans[i + j] += a * b;
                }
            }

            ans
        }

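        // Requires f.len() == g.len() == 2n == 2^(k+1); returns 2 * f.len() coefficients,
        // the same output length as the naive branch of mult_2pow.
        fn karatsuba(f: &[i64], g: &[i64]) -> Vec<i64> {
            if f.len() == 1 {
                return vec![f[0] * g[0], 0];
            }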
            let n = f.len() / 2;
            let k = n.trailing_zeros();
            debug_assert_eq!(n, 1 << k);

            let (fl, fr) = (&f[..n], &f[n..]);
            let (gl, gr) = (&g[..n], &g[n..]);

            let flgl = mult_2pow(fl, gl);
            let frgr = mult_2pow(fr, gr);

            let fsum: Vec<_> = fl.iter().zip(fr.iter()).map(|(&a, &b)| a + b).collect();
            let gsum: Vec<_> = gl.iter().zip(gr.iter()).map(|(&a, &b)| a + b).collect();
            let fsgs = mult_2pow(&fsum, &gsum);

            let mut ans: Vec<_> = flgl.iter().copied().chain(frgr.iter().copied()).collect();
            for i in 0..fsgs.len() {
                ans[i + n] += fsgs[i];
            }
            for (i, v) in flgl
                .iter()
                .zip(frgr.iter())
                .map(|(&a, &b)| a + b)
                .enumerate()
            {
                ans[i + n] -= v;
            }

            ans
        }

        // Inverse of ntt2::NTT_P (998244353) modulo ntt1::NTT_P (2281701377), used by the CRT step below.
        const P2INV: i64 = 253522377;

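        // Exact convolution of non-negative coefficient vectors: convolve modulo the two NTT primes
        // P1 = ntt1::NTT_P and P2 = ntt2::NTT_P, then recover each coefficient x from a1 = x mod P1
        // and a2 = x mod P2 with the CRT, x = P2 * j + a2 where j = (a1 - a2) * P2^{-1} mod P1.
        // The result is exact only while every true coefficient is below P1 * P2 (about 2.28e18).
        pub fn convolute(a: &[i64], b: &[i64]) -> Vec<i64> {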
            let c1 = ntt1::convolute(a, b);
            let c2 = ntt2::convolute(a, b);

            c1.into_iter()
                .zip(c2.into_iter())
                .map(|(a1, a2)| {
                    let j = ((a1 + ntt1::NTT_P as i64 - a2) * P2INV) % ntt1::NTT_P as i64;
                    ntt2::NTT_P as i64 * j + a2
                })
                .collect()
        }

        // Constant names follow https://algoshitpo.github.io/2020/05/20/fft-ntt/:
        // NTT_P = NTT_A * 2^NTT_B + 1 is prime and NTT_W is a primitive root modulo NTT_P.
        macro_rules! impl_ntt {
            ($modname:ident, $nttp:expr, $ntta:expr, $nttb:expr, $nttw:expr) => {
                mod $modname {
                    pub const NTT_P: u64 = $nttp;
                    const NTT_A: u64 = $ntta;
                    const NTT_B: u32 = $nttb;
                    const NTT_W: u64 = $nttw;

                    // Returns the smallest x such that (1 << x) >= n.
                    fn ceil_log2(n: usize) -> usize {
                        let mut x: usize = 0;
                        while (1 << x) < n {
                            x += 1;
                        }
                        x
                    }

                    pub fn convolute(a: &[i64], b: &[i64]) -> Vec<i64> {
                        let nlen = 1 << ceil_log2(a.len() + b.len());
                        let mut arr = vec![0; nlen];
                        let mut brr = vec![0; nlen];
                        for (i, &a) in a.iter().enumerate() {
                            arr[i] = a as u64;
                        }
                        for (i, &b) in b.iter().enumerate() {
                            brr[i] = b as u64;
                        }

                        inplace_ntt(&mut arr);
                        inplace_ntt(&mut brr);
                        // Reduce each pointwise product (below NTT_P^2 < 2^64) before the inverse NTT.
                        let mut crr: Vec<_> =
                            arr.iter().zip(brr.iter()).map(|(&a, &b)| a * b % NTT_P).collect();
                        inplace_intt(&mut crr);
                        crr.iter().map(|&x| x as i64).collect()
                    }

                    #[inline(always)]
                    fn rem_pow(mut base: u64, exp: u64) -> u64 {
                        let mut result = 1u64;
                        for exp in core::iter::successors(Some(exp), |x| Some(x >> 1))
                            .take_while(|&v| v != 0)
                        {
                            if exp & 1 != 0 {
                                result *= base;
                                result %= NTT_P;
                            }
                            base *= base;
                            base %= NTT_P;
                        }
                        result
                    }

                    // Returns w^k, where w = unity(n, 1) is a primitive 2^n-th root of unity,
                    // so unity(n, 1)^(2^n) = 1. Requires n <= NTT_B.
                    fn unity(n: u32, k: u64) -> u64 {
                        rem_pow(rem_pow(NTT_W, NTT_A), k << (NTT_B - n))
                    }

                    fn recip(x: u64) -> u64 {
                        rem_pow(x, NTT_P - 2)
                    }

                    // Reverses k trailing bits of n
                    fn reverse_trailing_bits(n: usize, k: u32) -> usize {
                        let mut r: usize = 0;
                        for i in 0..k {
                            r |= ((n >> i) & 1) << (k - i - 1);
                        }
                        r
                    }

                    fn inplace_ntt(arr: &mut [u64]) {
                        let n: usize = arr.len();
                        let k = n.trailing_zeros();
                        assert_eq!(n, 1 << k);

                        for i in 0..n {
                            let j = reverse_trailing_bits(i, k);
                            if i < j {
                                arr.swap(i, j);
                            }
                        }

                        for x in 0..k {
                            let base: u64 = unity(x + 1, 1);
                            let s = 1 << x;
                            for i in (0..n).step_by(s << 1) {
                                let mut mult: u64 = 1;
                                for j in 0..s {
                                    let tmp = (arr[i + j + s] * mult) % NTT_P;
                                    arr[i + j + s] = (arr[i + j] + NTT_P - tmp) % NTT_P;
                                    arr[i + j] = (arr[i + j] + tmp) % NTT_P;
                                    mult *= base;
                                    mult %= NTT_P;
                                }
                            }
                        }
                    }

                    fn inplace_intt(arr: &mut [u64]) {
                        let n: usize = arr.len();
                        let k = n.trailing_zeros();
                        assert_eq!(n, 1 << k);

                        for i in 0..n {
                            let j = reverse_trailing_bits(i, k);
                            if i < j {
                                arr.swap(i, j);
                            }
                        }

                        for x in 0..k {
                            let base: u64 = recip(unity(x + 1, 1));
                            let s = 1 << x;
                            for i in (0..n).step_by(s << 1) {
                                let mut mult: u64 = 1;
                                for j in 0..s {
                                    let tmp = (arr[i + j + s] * mult) % NTT_P;
                                    arr[i + j + s] = (arr[i + j] + NTT_P - tmp) % NTT_P;
                                    arr[i + j] = (arr[i + j] + tmp) % NTT_P;
                                    mult *= base;
                                    mult %= NTT_P;
                                }
                            }
                        }

                        let r = recip(n as u64);
                        for f in arr.iter_mut() {
                            *f *= r;
                            *f %= NTT_P;
                        }
                    }
                }
            };
        }

        impl_ntt!(ntt1, 2281701377, 17, 27, 3);
        impl_ntt!(ntt2, 998244353, 119, 23, 3);
    }

    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
    enum Sign {
        Neg,
        Pos, // Includes 0
    }
    use Sign::*;

    #[derive(Clone, Debug, PartialEq, Eq)]
    pub struct Int {
        sign: Sign,
        nat: Uint,
    }

    macro_rules! impl_from_for_int {
        ($($u:ty, $s:ty);*) => {
            $(
                impl From<$u> for Int {
                    fn from(x: $u) -> Self {
                        Self { sign: Pos, nat: x.into() }
                    }
                }
                impl From<$s> for Int {
                    fn from(x: $s) -> Self {
                        if x < 0 {
                            // unsigned_abs() avoids the overflow of -x when x is the type's MIN.
                            Self { sign: Neg, nat: x.unsigned_abs().into() }
                        } else {
                            Self { sign: Pos, nat: (x as $u).into() }
                        }
                    }
                }
            )*
        };
    }

    impl_from_for_int!(u8, i8; u16, i16; u32, i32; u64, i64; u128, i128; usize, isize);

    impl FromStr for Int {
        type Err = ParseIntError;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            if s.is_empty() {
                // ParseIntError has no public constructor; parsing "" yields one with an "empty" kind.
                return Err("".parse::<u8>().unwrap_err());
            }
            let mut x = match s.strip_prefix("-") {
                Some(t) => Self {
                    sign: Neg,
                    nat: t.parse()?,
                },
                None => Self {
                    sign: Pos,
                    nat: s.parse()?,
                },
            };
            if x.sign == Neg && x.nat.0.len() == 0 {
                x.sign = Pos;
            }
            Ok(x)
        }
    }

    impl Display for Int {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            if let Neg = self.sign {
                write!(f, "-")?;
            }
            write!(f, "{}", self.nat)
        }
    }

    impl PartialOrd for Int {
        fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Ord for Int {
        fn cmp(&self, other: &Self) -> core::cmp::Ordering {
            use core::cmp::Ordering;
            match (self.sign, other.sign) {
                (Neg, Neg) => other.nat.cmp(&self.nat),
                (Neg, Pos) => Ordering::Less,
                (Pos, Neg) => Ordering::Greater,
                (Pos, Pos) => self.nat.cmp(&other.nat),
            }
        }
    }

    impl AddAssign<&Int> for Int {
        fn add_assign(&mut self, rhs: &Int) {
            match (self.sign, rhs.sign) {
                (Neg, Neg) => {
                    self.nat += &rhs.nat;
                }
                (Neg, Pos) => {
                    if self.nat >= rhs.nat {
                        self.nat -= &rhs.nat;
                    } else {
                        let c = &rhs.nat - &self.nat;
                        self.nat = c;
                        self.sign = Pos;
                    }
                    if self.nat.0.len() == 0 {
                        self.sign = Pos;
                    }
                }
                (Pos, Neg) => {
                    if self.nat >= rhs.nat {
                        self.nat -= &rhs.nat;
                    } else {
                        let c = &rhs.nat - &self.nat;
                        self.nat = c;
                        self.sign = Neg;
                    }
                    if self.nat.0.len() == 0 {
                        self.sign = Pos;
                    }
                }
                (Pos, Pos) => {
                    self.nat += &rhs.nat;
                }
            }
        }
    }

    impl Add for &Int {
        type Output = Int;
        fn add(self, rhs: Self) -> Self::Output {
            let mut ans = self.clone();
            ans += rhs;
            ans
        }
    }

    impl SubAssign<&Int> for Int {
        fn sub_assign(&mut self, rhs: &Int) {
            match (self.sign, rhs.sign) {
                (Neg, Pos) => {
                    self.nat += &rhs.nat;
                }
                (Neg, Neg) => {
                    if self.nat >= rhs.nat {
                        self.nat -= &rhs.nat;
                    } else {
                        let c = &rhs.nat - &self.nat;
                        self.nat = c;
                        self.sign = Pos;
                    }
                    if self.nat.0.len() == 0 {
                        self.sign = Pos;
                    }
                }
                (Pos, Pos) => {
                    if self.nat >= rhs.nat {
                        self.nat -= &rhs.nat;
                    } else {
                        let c = &rhs.nat - &self.nat;
                        self.nat = c;
                        self.sign = Neg;
                    }
                    if self.nat.0.len() == 0 {
                        self.sign = Pos;
                    }
                }
                (Pos, Neg) => {
                    self.nat += &rhs.nat;
                }
            }
        }
    }

    impl Sub for &Int {
        type Output = Int;
        fn sub(self, rhs: Self) -> Self::Output {
            let mut ans = self.clone();
            ans -= rhs;
            ans
        }
    }

    impl Mul for &Int {
        type Output = Int;
        fn mul(self, rhs: Self) -> Self::Output {
            let x = &self.nat * &rhs.nat;
            if x.0.len() == 0 || self.sign == rhs.sign {
                Int { sign: Pos, nat: x }
            } else {
                Int { sign: Neg, nat: x }
            }
        }
    }

    impl MulAssign<&Int> for Int {
        fn mul_assign(&mut self, rhs: &Int) {
            let x = &self.nat * &rhs.nat;
            self.nat = x;
            if self.nat.0.len() == 0 || self.sign == rhs.sign {
                self.sign = Pos;
            } else {
                self.sign = Neg;
            }
        }
    }
}
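
The Karatsuba step in `polymul::karatsuba` uses three half-size products instead of four. With f = fl + x^n fr and g = gl + x^n gr, the middle term is (fl + fr)(gl + gr) - fl gl - fr gr. Below is a minimal standalone sketch of the same recurrence, checked against schoolbook multiplication. The names `naive_mul`, `karatsuba_sketch` and the cutoff `THRESHOLD` are hypothetical and not part of the snippet. The sketch assumes equal power-of-two input lengths and, like `mult_2pow`, returns 2 * len coefficients.

```rust
/// Schoolbook product; returns a.len() + b.len() - 1 coefficients (empty if an input is empty).
fn naive_mul(a: &[i64], b: &[i64]) -> Vec<i64> {
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let mut out = vec![0; a.len() + b.len() - 1];
    for (i, &x) in a.iter().enumerate() {
        for (j, &y) in b.iter().enumerate() {
            out[i + j] += x * y;
        }
    }
    out
}

/// Hypothetical cutoff below which the sketch falls back to schoolbook multiplication.
const THRESHOLD: usize = 4;

/// Karatsuba product of two equal-length inputs whose length is a power of two.
/// Like `mult_2pow`, it returns 2 * len coefficients (the last one is always 0).
fn karatsuba_sketch(f: &[i64], g: &[i64]) -> Vec<i64> {
    let len = f.len();
    assert_eq!(len, g.len());
    if len <= THRESHOLD {
        let mut out = naive_mul(f, g);
        out.resize(2 * len, 0);
        return out;
    }
    debug_assert!(len.is_power_of_two());
    let n = len / 2;
    let (fl, fr) = f.split_at(n);
    let (gl, gr) = g.split_at(n);
    let low = karatsuba_sketch(fl, gl); // fl * gl, 2n coefficients
    let high = karatsuba_sketch(fr, gr); // fr * gr, 2n coefficients
    let fs: Vec<i64> = fl.iter().zip(fr).map(|(a, b)| a + b).collect();
    let gs: Vec<i64> = gl.iter().zip(gr).map(|(a, b)| a + b).collect();
    let mid = karatsuba_sketch(&fs, &gs); // (fl + fr) * (gl + gr)
    // f * g = low + x^n * (mid - low - high) + x^(2n) * high
    let mut out = vec![0; 2 * len];
    for i in 0..2 * n {
        out[i] += low[i];
        out[i + n] += mid[i] - low[i] - high[i];
        out[i + 2 * n] += high[i];
    }
    out
}
```

With the snippet's own KARAT_THRES = 30, `karatsuba` only ever sees lengths of at least 32, so its length-1 base case is not reached. The sketch's small cutoff exists only so that a short test exercises the recursion.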

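`polymul::convolute` recovers each coefficient from its residues modulo the two NTT primes. Below is a small sketch of that CRT reconstruction. It uses u128 intermediates for clarity, whereas the snippet itself works in i64. It also checks that the hard-coded P2INV really is the inverse of 998244353 modulo 2281701377. `P1`, `P2` and `crt_two` are hypothetical names for this illustration.

```rust
// The two NTT primes used by the snippet (values from the impl_ntt! calls).
const P1: u64 = 2_281_701_377; // 17 * 2^27 + 1
const P2: u64 = 998_244_353; // 119 * 2^23 + 1
// Hard-coded inverse of P2 modulo P1 (P2INV in the snippet).
const P2INV: u64 = 253_522_377;

/// Recovers x in [0, P1 * P2) from a1 = x mod P1 and a2 = x mod P2.
/// Writing x = P2 * j + a2 and reducing modulo P1 gives j = (a1 - a2) * P2^{-1} mod P1.
fn crt_two(a1: u64, a2: u64) -> u64 {
    // (a1 - a2) mod P1, kept non-negative; a2 < P2 < P1, so a1 + P1 - a2 > 0.
    let diff = (a1 + P1 - a2) % P1;
    let j = (diff as u128 * P2INV as u128) % P1 as u128;
    // P2 * j + a2 <= P1 * P2 - 1, which fits in u64.
    (P2 as u128 * j + a2 as u128) as u64
}
```

The snippet computes the same j directly in i64. This is safe because (a1 + P1 - a2) * P2INV stays below about 1.2e18.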
Fraction

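`Frac` is a fraction type with an `i64` numerator and a `u64` denominator. It supports:

- the four arithmetic operations, and comparisons between fractions and with `i64` integers;
- conversion to `f64`, plus `floor`, `ceil` and `round`;
- bounded-denominator approximation, through `Frac::wrap` (from an `f64`) and `Frac::lower_den` (from an existing fraction).

Arithmetic results are not reduced automatically, so call `simplify` when the numerator or denominator may grow large.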
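
`Frac::lower_den` (see the Code section) walks toward the value with batched Stern-Brocot mediant steps. When testing it, a brute-force reference that scans every denominator up to `max_den` can be handy. The following sketch works on plain (numerator, denominator) pairs rather than `Frac`, and the helper name `bounds_brute` is hypothetical. It runs in O(max_den) time and assumes `num * max_den` does not overflow i64.

```rust
/// Brute-force bounds for num/den (den > 0) among fractions with denominator 1..=max_den.
/// Returns ((ln, ld), (rn, rd)) with ln/ld <= num/den <= rn/rd, each bound as tight as possible;
/// ties keep the smallest denominator. Assumes max_den >= 1.
fn bounds_brute(num: i64, den: i64, max_den: i64) -> ((i64, i64), (i64, i64)) {
    let mut lo = (num.div_euclid(den), 1); // floor(num / den) / 1
    let mut hi = (-(-num).div_euclid(den), 1); // ceil(num / den) / 1
    for q in 2..=max_den {
        // Largest p with p/q <= num/den is floor(num * q / den).
        let p = (num * q).div_euclid(den);
        if p * lo.1 > lo.0 * q {
            lo = (p, q);
        }
        // Smallest p with p/q >= num/den is ceil(num * q / den).
        let p = -(-num * q).div_euclid(den);
        if p * hi.1 < hi.0 * q {
            hi = (p, q);
        }
    }
    (lo, hi)
}
```

For example, 314159/100000 with `max_den = 10` gives the bounds 25/8 and 22/7. These should agree in value with what `lower_den` returns for the same input.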

Example

fn main() {
    use frac::*;
    let mut a = Frac::new(8374927, 2983178).simplify();
    println!("{a}");

    let b: i64 = 31;
    a += b;
    println!("{a}");

    let af: f64 = a.into();
    println!("{af:.5}");

    let (l, r) = a.lower_den(100);
    println!("{l} <= {a} <= {r}");
    let (lf, rf): (f64, f64) = (l.into(), r.into());
    println!("{lf:.5} {rf:.5}");
}

Code

mod frac {
    //! Note: arithmetic on `Frac` does not simplify its results, to avoid repeated GCD calls.
    //! Call `simplify` manually when needed, e.g. before the numerator or denominator can overflow.
    use std::{fmt::Display, ops::*};

    /// Numerator type
    pub type I = i64;
    /// Denominator type
    pub type U = u64;

    /// Fraction type.
    #[derive(Clone, Copy, Debug)]
    pub struct Frac {
        /// Numerator
        pub num: I,
        /// Denominator
        pub den: U,
    }

    impl Frac {
        /// Reduces the fraction to lowest terms (divides numerator and denominator by their GCD).
        pub fn simplify(self) -> Self {
            fn gcd(mut a: U, mut b: U) -> U {
                while b != 0 {
                    (a, b) = (b, a % b);
                }
                a
            }
            let g = gcd(self.num.unsigned_abs(), self.den);
            Self {
                num: self.num / g as I,
                den: self.den / g,
            }
        }

        /// Creates a fraction from the given numerator and denominator. The sign of a negative
        /// denominator is moved to the numerator; `den` must be nonzero.
        pub fn new(num: I, den: I) -> Self {
            debug_assert_ne!(den, 0);
            if den < 0 {
                Self {
                    num: -num,
                    den: (-den) as U,
                }
            } else {
                Self { num, den: den as U }
            }
        }

        /// Returns the reciprocal of the fraction. Panics if the fraction is zero.
        pub fn recip(self) -> Self {
            use std::cmp::Ordering::*;
            match self.num.cmp(&0) {
                Less => Self {
                    num: -(self.den as I),
                    den: self.num.unsigned_abs(),
                },
                Equal => panic!("Reciprocal of zero"),
                Greater => Self {
                    num: self.den as I,
                    den: self.num as U,
                },
            }
        }

        /// Returns the floor of the fraction as an integer.
        pub fn floor(self) -> I {
            let Self { num, den } = self;
            let den = den as I;
            num.div_euclid(den)
        }

        /// Returns the ceiling of the fraction as an integer.
        pub fn ceil(self) -> I {
            let Self { num, den } = self;
            let den = den as I;
            (num + den - 1).div_euclid(den)
        }

        /// Rounds the fraction to the nearest integer; ties (e.g. 5/2) round toward positive infinity.
        pub fn round(self) -> I {
            let Self { num, den } = self;
            let den = den as I;
            (2 * num + den).div_euclid(2 * den)
        }

        /// Returns the fractional part, `self - self.floor()`, which lies in [0, 1).
        pub fn fract(self) -> Self {
            self - self.floor()
        }

        /// Returns the closest fractions to `x` from below and from above whose denominators do not
        /// exceed `max_den`. If either bound equals `x` when converted to f64, both bounds are set to
        /// that fraction. This behavior is subject to change for more accurate approximation.
        pub fn wrap(x: f64, max_den: U) -> (Self, Self) {
            let ipart = x.floor() as I;
            let d = x.fract();
            if d == 0. {
                return (ipart.into(), ipart.into());
            }

            let [(mut ln, mut ld), (mut rn, mut rd)]: [(U, U); 2] = [(0, 1), (1, 1)];
            while (ln, ld) != (rn, rd) {
                let (pl, pr) = ((ln, ld), (rn, rd));

                // Update l
                let k1 = (ld as f64 * d - ln as f64).div_euclid(rn as f64 - rd as f64 * d) as U;
                let k2 = (max_den - ld).div_euclid(rd);
                let k = k1.min(k2);
                ln += k * rn;
                ld += k * rd;

                // Update r
                let k1 = (rn as f64 - rd as f64 * d).div_euclid(ld as f64 * d - ln as f64) as U;
                let k2 = (max_den - rd).div_euclid(ld);
                let k = k1.min(k2);
                rn += k * ln;
                rd += k * ld;

                if pl == (ln, ld) && pr == (rn, rd) {
                    break;
                }
            }

            let l = Self::new(ln as I, ld as I) + ipart;
            let r = Self::new(rn as I, rd as I) + ipart;
            if x == l.into() {
                (l, l)
            } else if x == r.into() {
                (r, r)
            } else {
                (l, r)
            }
        }

        /// Returns two fractions `(l, r)` with `l <= self <= r`, each as close to `self` as possible
        /// among fractions whose denominator does not exceed `max_den` (which must be at least 1).
        /// This is useful for approximating a fraction whose numerator or denominator is growing
        /// too large when an exact value is not needed.
        pub fn lower_den(self, max_den: U) -> (Self, Self) {
            if self.den <= max_den {
                return (self, self);
            }
            // Reduce first: if the reduced denominator already fits, return it. Otherwise a bound
            // could land exactly on `self`, and the next update would divide by zero.
            let reduced = self.simplify();
            if reduced.den <= max_den {
                return (reduced, reduced);
            }

            let ipart = reduced.floor();
            let Self { num: dn, den: dd } = reduced.fract();
            let dn = dn as U;

            let [(mut ln, mut ld), (mut rn, mut rd)]: [(U, U); 2] = [(0, 1), (1, 1)];
            while (ln, ld) != (rn, rd) {
                let (pl, pr) = ((ln, ld), (rn, rd));

                // Update l
                let k1 = (ld * dn - ln * dd).div_euclid(rn * dd - rd * dn);
                let k2 = (max_den - ld).div_euclid(rd);
                let k = k1.min(k2);
                ln += k * rn;
                ld += k * rd;

                // Update r
                let k1 = (rn * dd - rd * dn).div_euclid(ld * dn - ln * dd);
                let k2 = (max_den - rd).div_euclid(ld);
                let k = k1.min(k2);
                rn += k * ln;
                rd += k * ld;

                if pl == (ln, ld) && pr == (rn, rd) {
                    break;
                }
            }

            let l = Self::new(ln as I, ld as I) + ipart;
            let r = Self::new(rn as I, rd as I) + ipart;
            (l, r)
        }
    }

    impl Default for Frac {
        fn default() -> Self {
            Frac { num: 0, den: 1 }
        }
    }

    impl Display for Frac {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}/{}", self.num, self.den)
        }
    }

    impl From<I> for Frac {
        fn from(num: I) -> Self {
            Self { num, den: 1 }
        }
    }

    impl From<Frac> for f64 {
        fn from(value: Frac) -> Self {
            value.num as f64 / value.den as f64
        }
    }

    impl Neg for Frac {
        type Output = Self;
        fn neg(self) -> Self::Output {
            Self {
                num: -self.num,
                den: self.den,
            }
        }
    }

    impl Add for Frac {
        type Output = Self;
        fn add(self, rhs: Self) -> Self::Output {
            Self {
                num: self.num * rhs.den as I + self.den as I * rhs.num,
                den: self.den * rhs.den,
            }
        }
    }

    impl Sub for Frac {
        type Output = Self;
        fn sub(self, rhs: Self) -> Self::Output {
            Self {
                num: self.num * rhs.den as I - self.den as I * rhs.num,
                den: self.den * rhs.den,
            }
        }
    }

    impl Mul for Frac {
        type Output = Self;
        fn mul(self, rhs: Self) -> Self::Output {
            Self {
                num: self.num * rhs.num,
                den: self.den * rhs.den,
            }
        }
    }

    impl Div for Frac {
        type Output = Self;
        fn div(self, rhs: Self) -> Self::Output {
            self * rhs.recip()
        }
    }

    impl Add<I> for Frac {
        type Output = Self;
        fn add(self, rhs: I) -> Self::Output {
            let rhs: Frac = rhs.into();
            self + rhs
        }
    }

    impl Sub<I> for Frac {
        type Output = Self;
        fn sub(self, rhs: I) -> Self::Output {
            let rhs: Frac = rhs.into();
            self - rhs
        }
    }

    impl Mul<I> for Frac {
        type Output = Self;
        fn mul(self, rhs: I) -> Self::Output {
            let rhs: Frac = rhs.into();
            self * rhs
        }
    }

    impl Div<I> for Frac {
        type Output = Self;
        fn div(self, rhs: I) -> Self::Output {
            let rhs: Frac = rhs.into();
            self / rhs
        }
    }

    impl Add<Frac> for I {
        type Output = Frac;
        fn add(self, rhs: Frac) -> Self::Output {
            rhs + self
        }
    }

    impl Sub<Frac> for I {
        type Output = Frac;
        fn sub(self, rhs: Frac) -> Self::Output {
            -rhs + self
        }
    }

    impl Mul<Frac> for I {
        type Output = Frac;
        fn mul(self, rhs: Frac) -> Self::Output {
            rhs * self
        }
    }

    impl Div<Frac> for I {
        type Output = Frac;
        fn div(self, rhs: Frac) -> Self::Output {
            let lhs: Frac = self.into();
            lhs / rhs
        }
    }

    impl AddAssign for Frac {
        fn add_assign(&mut self, rhs: Self) {
            *self = *self + rhs;
        }
    }

    impl SubAssign for Frac {
        fn sub_assign(&mut self, rhs: Self) {
            *self = *self - rhs;
        }
    }

    impl MulAssign for Frac {
        fn mul_assign(&mut self, rhs: Self) {
            *self = *self * rhs;
        }
    }

    impl DivAssign for Frac {
        fn div_assign(&mut self, rhs: Self) {
            *self = *self / rhs;
        }
    }

    impl AddAssign<I> for Frac {
        fn add_assign(&mut self, rhs: I) {
            *self = *self + rhs;
        }
    }

    impl SubAssign<I> for Frac {
        fn sub_assign(&mut self, rhs: I) {
            *self = *self - rhs;
        }
    }

    impl MulAssign<I> for Frac {
        fn mul_assign(&mut self, rhs: I) {
            *self = *self * rhs;
        }
    }

    impl DivAssign<I> for Frac {
        fn div_assign(&mut self, rhs: I) {
            *self = *self / rhs;
        }
    }

    impl PartialEq for Frac {
        fn eq(&self, rhs: &Self) -> bool {
            (self.num * rhs.den as I).eq(&(rhs.num * self.den as I))
        }
    }

    impl Eq for Frac {}

    impl PartialOrd for Frac {
        fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
            (self.num * rhs.den as I).partial_cmp(&(rhs.num * self.den as I))
        }
    }

    impl Ord for Frac {
        fn cmp(&self, rhs: &Self) -> std::cmp::Ordering {
            (self.num * rhs.den as I).cmp(&(rhs.num * self.den as I))
        }
    }

    impl PartialEq<I> for Frac {
        fn eq(&self, rhs: &I) -> bool {
            let rhs: Frac = (*rhs).into();
            self.eq(&rhs)
        }
    }

    impl PartialOrd<I> for Frac {
        fn partial_cmp(&self, rhs: &I) -> Option<std::cmp::Ordering> {
            let rhs: Frac = (*rhs).into();
            self.partial_cmp(&rhs)
        }
    }

    impl PartialEq<Frac> for I {
        fn eq(&self, rhs: &Frac) -> bool {
            let lhs: Frac = (*self).into();
            lhs.eq(rhs)
        }
    }

    impl PartialOrd<Frac> for I {
        fn partial_cmp(&self, rhs: &Frac) -> Option<std::cmp::Ordering> {
            let lhs: Frac = (*self).into();
            lhs.partial_cmp(rhs)
        }
    }
}

Fast IO

Code

#![no_main]

#[allow(unused)]
use std::{cmp::*, collections::*, iter, mem::*, num::*, ops::*};

fn solve<'t, It: Iterator<Item = &'t str>>(sc: &mut fastio::Tokenizer<It>) {}

#[allow(unused)]
mod fastio {
	use super::ioutil::*;

	pub struct Tokenizer<It> {
		it: It,
	}

	impl<'i, 's: 'i, It> Tokenizer<It> {
		pub fn new(text: &'s str, split: impl FnOnce(&'i str) -> It) -> Self {
			Self { it: split(text) }
		}
	}

	impl<'t, It: Iterator<Item = &'t str>> Tokenizer<It> {
		pub fn next_ok<T: IterParse<'t>>(&mut self) -> PRes<'t, T> {
			T::parse_from_iter(&mut self.it)
		}

		pub fn next<T: IterParse<'t>>(&mut self) -> T {
			self.next_ok().unwrap()
		}

		pub fn next_map<T: IterParse<'t>, U, const N: usize>(&mut self, f: impl FnMut(T) -> U) -> [U; N] {
			let x: [T; N] = self.next();
			x.map(f)
		}

		pub fn next_it<T: IterParse<'t>>(&mut self) -> impl Iterator<Item = T> + '_ {
			std::iter::repeat_with(move || self.next_ok().ok()).map_while(|x| x)
		}

		pub fn next_collect<T: IterParse<'t>, V: FromIterator<T>>(&mut self, size: usize) -> V {
			self.next_it().take(size).collect()
		}
	}
}

mod ioutil {
	use std::{fmt::*, num::*};

	pub enum InputError<'t> {
		InputExhaust,
		ParseError(&'t str),
	}
	use InputError::*;

	pub type PRes<'t, T> = std::result::Result<T, InputError<'t>>;

	impl<'t> Debug for InputError<'t> {
		fn fmt(&self, f: &mut Formatter<'_>) -> Result {
			match self {
				InputExhaust => f.debug_struct("InputExhaust").finish(),
				ParseError(s) => f.debug_struct("ParseError").field("str", s).finish(),
			}
		}
	}

	pub trait Atom<'t>: Sized {
		fn parse(text: &'t str) -> PRes<'t, Self>;
	}

	impl<'t> Atom<'t> for &'t str {
		fn parse(text: &'t str) -> PRes<'t, Self> {
			Ok(text)
		}
	}

	impl<'t> Atom<'t> for &'t [u8] {
		fn parse(text: &'t str) -> PRes<'t, Self> {
			Ok(text.as_bytes())
		}
	}

	macro_rules! impl_atom {
        ($($t:ty) *) => { $(impl Atom<'_> for $t { fn parse(text: &str) -> PRes<Self> { text.parse().map_err(|_| ParseError(text)) } })* };
    }
	impl_atom!(u8 u16 u32 u64 u128 usize i8 i16 i32 i64 i128 isize f32 f64 bool char String NonZeroI8 NonZeroI16 NonZeroI32 NonZeroI64 NonZeroI128 NonZeroIsize NonZeroU8 NonZeroU16 NonZeroU32 NonZeroU64 NonZeroU128 NonZeroUsize);

	pub trait IterParse<'t>: Sized {
		fn parse_from_iter<'s, It: Iterator<Item = &'t str>>(it: &'s mut It) -> PRes<'t, Self>
		where
			't: 's;
	}

	impl<'t, A: Atom<'t>> IterParse<'t> for A {
		fn parse_from_iter<'s, It: Iterator<Item = &'t str>>(it: &'s mut It) -> PRes<'t, Self>
		where
			't: 's,
		{
			it.next().map_or(Err(InputExhaust), <Self as Atom>::parse)
		}
	}

	impl<'t, A: IterParse<'t>, const N: usize> IterParse<'t> for [A; N] {
		fn parse_from_iter<'s, It: Iterator<Item = &'t str>>(it: &'s mut It) -> PRes<'t, Self>
		where
			't: 's,
		{
			use std::mem::*;
			let mut x: [MaybeUninit<A>; N] = unsafe { MaybeUninit::uninit().assume_init() };
			for p in x.iter_mut() {
				*p = MaybeUninit::new(A::parse_from_iter(it)?);
			}
			Ok(unsafe { transmute_copy(&x) })
		}
	}

	macro_rules! impl_tuple {
        ($u:ident) => {};
        ($u:ident $($t:ident)+) => { impl<'t, $u: IterParse<'t>, $($t: IterParse<'t>),+> IterParse<'t> for ($u, $($t),+) { fn parse_from_iter<'s, It: Iterator<Item = &'t str>>(_it: &'s mut It) -> PRes<'t, Self> where 't: 's { Ok(($u::parse_from_iter(_it)?, $($t::parse_from_iter(_it)?),+)) } } impl_tuple!($($t) +); };
    }

	impl_tuple!(Q W E R T Y U I O P A S D F G H J K L Z X C V B N M);
}

#[link(name = "c")]
extern "C" {
	fn mmap(addr: usize, len: usize, p: i32, f: i32, fd: i32, o: i64) -> *mut u8;
	fn fstat(fd: i32, stat: *mut usize) -> i32;
}

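fn get_input() -> &'static str {
	// `stat` receives a raw `struct stat`; on x86_64 Linux, element 6 of it viewed as [usize] is st_size.
	let mut stat = [0; 20];
	unsafe { fstat(0, stat.as_mut_ptr()) };
	// mmap(addr = 0, len, PROT_READ = 1, MAP_PRIVATE = 2, fd = 0 (stdin), offset = 0).
	// Stdin must be a regular file: mapping a pipe fails. The bytes are assumed to be valid UTF-8.
	let buffer = unsafe { mmap(0, stat[6], 1, 2, 0, 0) };
	unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(buffer, stat[6])) }
}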

#[no_mangle]
unsafe fn main() -> i32 {
	use std::io::*;
	let mut sc = fastio::Tokenizer::new(get_input(), |s| s.split_ascii_whitespace());
	let stdout = stdout();
	WRITER = Some(BufWriter::new(stdout.lock()));
	solve(&mut sc);
	WRITER.as_mut().unwrap_unchecked().flush().ok();
	0
}

use std::io::{BufWriter, StdoutLock};
static mut WRITER: Option<BufWriter<StdoutLock>> = None;
#[macro_export]
macro_rules! print {
    ($($t:tt)*) => {{ use std::io::*; write!(unsafe{ WRITER.as_mut().unwrap_unchecked() }, $($t)*).unwrap(); }};
}
#[macro_export]
macro_rules! println {
    ($($t:tt)*) => {{ use std::io::*; writeln!(unsafe{ WRITER.as_mut().unwrap_unchecked() }, $($t)*).unwrap(); }};
}
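The mmap-based get_input above depends on Linux and on stdin being a regular file, so it is expected to fail when the input is piped. A minimal portable fallback, sketched under the assumption that reading all of stdin into a leaked String is fast enough, could look like this; read_all_static is a hypothetical name, not part of the original template.

```rust
use std::io::Read;

// Hypothetical portable replacement for `get_input`: read all of `reader` into a
// String and leak it, so the returned &str lives until the program exits.
fn read_all_static<R: Read>(mut reader: R) -> &'static str {
    let mut buf = String::new();
    reader
        .read_to_string(&mut buf)
        .expect("failed to read input (or input is not valid UTF-8)");
    Box::leak(buf.into_boxed_str())
}
```

With this helper, get_input could simply return read_all_static(std::io::stdin().lock()), and the rest of the template stays unchanged.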

Last modified on 231008.

Iterator Tools

Grid Iteration

gen_diriter(r, c, tr, tc) returns an iterator of (usize, usize) over the neighbors of (r, c) to the right, down, left, and up, in a grid of tr rows and tc columns; neighbors that fall outside the grid are skipped. gen_torusiter(r, c, tr, tc) acts similarly, but on a grid whose edges wrap around, so stepping past one edge re-enters from the opposite edge.

The direction of iteration can be customized by modifying DR and DC.

const DR: [usize; 4] = [0, 1, 0, !0];
const DC: [usize; 4] = [1, 0, !0, 0];

fn gen_diriter(r: usize, c: usize, tr: usize, tc: usize) -> impl Iterator<Item = (usize, usize)> {
    std::iter::zip(DR.iter(), DC.iter())
        .map(move |(&dr, &dc)| (r.wrapping_add(dr), c.wrapping_add(dc)))
        .filter(move |&(nr, nc)| nr < tr && nc < tc)
}

fn gen_torusiter(r: usize, c: usize, tr: usize, tc: usize) -> impl Iterator<Item = (usize, usize)> {
    std::iter::zip(DR.iter(), DC.iter())
        .map(move |(&dr, &dc)| (r.wrapping_add(dr), c.wrapping_add(dc)))
        .map(move |(nr, nc)| {
            let r = if nr > usize::MAX / 2 {
                let delta = (usize::MAX - nr) % tr;
                tr - 1 - delta
            } else {
                nr % tr
            };
            let c = if nc > usize::MAX / 2 {
                let delta = (usize::MAX - nc) % tc;
                tc - 1 - delta
            } else {
                nc % tc
            };
            (r, c)
        })
}
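As a usage sketch, gen_diriter fits naturally into a breadth-first search on a grid. The grid format (rows of bytes with b'#' as walls) and the helper grid_bfs are assumptions made for illustration; DR, DC and gen_diriter are copied from above so the block stands alone.

```rust
use std::collections::VecDeque;

// Copied from the snippet above: right, down, left, up.
const DR: [usize; 4] = [0, 1, 0, !0];
const DC: [usize; 4] = [1, 0, !0, 0];

fn gen_diriter(r: usize, c: usize, tr: usize, tc: usize) -> impl Iterator<Item = (usize, usize)> {
    std::iter::zip(DR.iter(), DC.iter())
        .map(move |(&dr, &dc)| (r.wrapping_add(dr), c.wrapping_add(dc)))
        .filter(move |&(nr, nc)| nr < tr && nc < tc)
}

// Hypothetical helper: shortest path length from (0, 0) to the bottom-right cell,
// where b'#' marks a wall. Returns None if the target is unreachable.
fn grid_bfs(grid: &[&[u8]]) -> Option<usize> {
    let (tr, tc) = (grid.len(), grid[0].len());
    let mut dist = vec![vec![usize::MAX; tc]; tr];
    let mut queue = VecDeque::new();
    dist[0][0] = 0;
    queue.push_back((0, 0));
    while let Some((r, c)) = queue.pop_front() {
        // gen_diriter already drops neighbors outside the grid.
        for (nr, nc) in gen_diriter(r, c, tr, tc) {
            if grid[nr][nc] != b'#' && dist[nr][nc] == usize::MAX {
                dist[nr][nc] = dist[r][c] + 1;
                queue.push_back((nr, nc));
            }
        }
    }
    let d = dist[tr - 1][tc - 1];
    if d == usize::MAX { None } else { Some(d) }
}
```

Switching to 8-directional movement would only require extending DR and DC; the BFS itself is unchanged.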

Cartesian Product

cart_prod(a, b) returns an iterator over the Cartesian product of two iterators a and b. Since b is restarted for every element of a, b must be Clone, and so must the items of a.

fn cart_prod<I, J, S, T>(a: I, b: J) -> impl Iterator<Item = (S, T)>
where I: Iterator<Item = S>, J: Iterator<Item = T> + Clone, S: Clone {
    a.flat_map(move |a| b.clone().map(move |b| (a.clone(), b)))
}
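For example, cart_prod can replace a pair of nested loops over row and column indices. The helper cells below is hypothetical and mainly shows the order in which pairs are produced (row-major, like the nested loops).

```rust
fn cart_prod<I, J, S, T>(a: I, b: J) -> impl Iterator<Item = (S, T)>
where I: Iterator<Item = S>, J: Iterator<Item = T> + Clone, S: Clone {
    a.flat_map(move |a| b.clone().map(move |b| (a.clone(), b)))
}

// Hypothetical helper: all cells of an h x w grid in row-major order,
// equivalent to `for r in 0..h { for c in 0..w { ... } }`.
fn cells(h: usize, w: usize) -> Vec<(usize, usize)> {
    cart_prod(0..h, 0..w).collect()
}
```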

Intersperse

intersperse(iter, with) returns an iterator which inserts a clone of with between consecutive elements of iter.

fn intersperse<T: Clone>(iter: impl Iterator<Item = T>, with: T) -> impl Iterator<Item = T> {
    iter.map(move |v| [with.clone(), v]).flatten().skip(1)
}
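A typical use is formatting output where values must be separated by single spaces. join_spaced is a hypothetical helper written for this sketch; intersperse is copied from above.

```rust
fn intersperse<T: Clone>(iter: impl Iterator<Item = T>, with: T) -> impl Iterator<Item = T> {
    iter.map(move |v| [with.clone(), v]).flatten().skip(1)
}

// Hypothetical helper: "1 2 3" for [1, 2, 3], with no trailing separator.
fn join_spaced(xs: &[i64]) -> String {
    intersperse(xs.iter().map(|x| x.to_string()), " ".to_string()).collect()
}
```

Because the separator is emitted before every element and the first one is skipped, an empty input produces an empty string.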

Last modified on 231008.

Macros

Conversion between Enum and u32

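Macro from https://stackoverflow.com/questions/28028854/how-do-i-match-enum-values-with-an-integer. It defines the enum as written and additionally implements TryFrom<u32> for it. Every variant, including the last one, must be followed by a comma, because the macro pattern requires it.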

Example

use std::convert::{TryFrom, TryInto};
macro_rules! back_to_enum {
    {$(#[$meta:meta])* $vis:vis enum $name:ident {
        $($(#[$vmeta:meta])* $vname:ident $(= $val:expr)?,)*
    }} => {
        $(#[$meta])*
        $vis enum $name {
            $($(#[$vmeta])* $vname $(= $val)?,)*
        }

        impl TryFrom<u32> for $name {
            type Error = ();

            fn try_from(v: u32) -> Result<Self, Self::Error> {
                match v {
                    $(x if x == $name::$vname as u32 => Ok($name::$vname),)*
                    _ => Err(()),
                }
            }
        }
    }
}

back_to_enum! {
    #[derive(Clone, Copy, Debug)]
    enum Number {
        Zero,
        One,
        Two,
    }
}

fn main() {
    let i: u32 = 1;
    let n: Number = i.try_into().unwrap();
    println!("{:?}", n); // One
    let i = n as i32;
    println!("{}", i); // 1
}

Code

use std::convert::{TryFrom, TryInto};
macro_rules! back_to_enum {
    {$(#[$meta:meta])* $vis:vis enum $name:ident {
        $($(#[$vmeta:meta])* $vname:ident $(= $val:expr)?,)*
    }} => {
        $(#[$meta])*
        $vis enum $name {
            $($(#[$vmeta])* $vname $(= $val)?,)*
        }
        impl TryFrom<u32> for $name {
            type Error = ();
            fn try_from(v: u32) -> Result<Self, Self::Error> {
                match v {
                    $(x if x == $name::$vname as u32 => Ok($name::$vname),)*
                    _ => Err(()),
                }
            }
        }
    }
}
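The macro also accepts explicit discriminants, and try_from returns Err(()) for values that match no variant. The Color enum below is a made-up example; the macro is copied unchanged from above.

```rust
use std::convert::TryFrom;

macro_rules! back_to_enum {
    {$(#[$meta:meta])* $vis:vis enum $name:ident {
        $($(#[$vmeta:meta])* $vname:ident $(= $val:expr)?,)*
    }} => {
        $(#[$meta])*
        $vis enum $name {
            $($(#[$vmeta])* $vname $(= $val)?,)*
        }
        impl TryFrom<u32> for $name {
            type Error = ();
            fn try_from(v: u32) -> Result<Self, Self::Error> {
                match v {
                    $(x if x == $name::$vname as u32 => Ok($name::$vname),)*
                    _ => Err(()),
                }
            }
        }
    }
}

back_to_enum! {
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum Color {
        Red = 1,
        Green = 10,
        Blue, // implicitly 11, one more than the previous discriminant
    }
}
```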

Last modified on 231008.

Adjacency List Graph Representation

Deprecated, because Vec<Vec<(usize, T)>> is much easier to use, and the performance gain isn't worth giving up that usability.

Credits to kiwiyou

#[derive(Debug)]
struct Graph<T> {
    n: usize,
    first: Vec<u32>,
    edge: Vec<(u32, u32, T)>, // (to, prev, data)
}

impl<T> Graph<T> {
    fn new(n: usize, e: usize) -> Self {
        Self {
            n,
            first: vec![u32::MAX; n],
            edge: Vec::with_capacity(e),
        }
    }

    fn add_edge(&mut self, from: usize, to: usize, data: T) {
        let prev = std::mem::replace(&mut self.first[from], self.edge.len() as u32);
        self.edge.push((to as u32, prev, data));
    }

    fn neighbor(&self, of: usize) -> Neighbor<T> {
        Neighbor {
            graph: self,
            next_edge: self.first[of],
        }
    }
}

struct Neighbor<'g, T> {
    graph: &'g Graph<T>,
    next_edge: u32,
}

impl<'g, T> Iterator for Neighbor<'g, T> {
    type Item = (usize, &'g T);

    fn next(&mut self) -> Option<Self::Item> {
        let (to, next_edge, data) = self.graph.edge.get(self.next_edge as usize)?;
        self.next_edge = *next_edge;
        Some((*to as usize, data))
    }
}
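For comparison, the Vec<Vec<(usize, T)>> representation recommended above can be built from an edge list in a few lines. build_adj is a hypothetical helper; note that it yields neighbors in insertion order, whereas Graph::neighbor walks its linked list starting from the most recently added edge, i.e. in reverse insertion order.

```rust
// Hypothetical helper: directed adjacency list from (from, to, data) triples.
// For an undirected graph, push the reversed edge as well.
fn build_adj<T: Clone>(n: usize, edges: &[(usize, usize, T)]) -> Vec<Vec<(usize, T)>> {
    let mut adj = vec![Vec::new(); n];
    for (from, to, data) in edges {
        adj[*from].push((*to, data.clone()));
    }
    adj
}
```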

Last modified on 231008.

Zero/One Trait

Deprecated, because bounding the generic type by either From<u8> or From<i8> is enough: T::from(0u8) and T::from(1u8) (or their i8 counterparts) then give zero and one.

pub trait ZeroOne: Sized + Copy {
    fn zero() -> Self;
    fn one() -> Self;
}

macro_rules! impl_zero_one {
    ($($ty:ty) *) => { $(
        impl ZeroOne for $ty {
            #[inline(always)]
            fn one() -> Self {1}
            #[inline(always)]
            fn zero() -> Self {0}
        }
    )* };
}

impl_zero_one!(isize i8 i16 i32 i64 i128 usize u8 u16 u32 u64 u128);
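A sketch of the From<u8> approach: the generic functions sum and pow below (hypothetical helpers, not from the original) obtain zero and one via T::from without any custom trait. i8 does not implement From<u8>, so code that must also accept i8 would bound on From<i8> instead, which in turn excludes the unsigned types.

```rust
use std::ops::{Add, Mul};

// Sum of a slice; `T::from(0u8)` plays the role of `ZeroOne::zero()`.
fn sum<T: Copy + From<u8> + Add<Output = T>>(xs: &[T]) -> T {
    xs.iter().fold(T::from(0u8), |acc, &x| acc + x)
}

// Binary exponentiation; `T::from(1u8)` plays the role of `ZeroOne::one()`.
fn pow<T: Copy + From<u8> + Mul<Output = T>>(mut base: T, mut exp: u64) -> T {
    let mut acc = T::from(1u8);
    loop {
        if exp & 1 == 1 {
            acc = acc * base;
        }
        exp >>= 1;
        if exp == 0 {
            break acc;
        }
        // Square only while bits remain, avoiding one needless (possibly overflowing) multiplication.
        base = base * base;
    }
}
```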

Last modified on 231008.

Deprecated Macros

HashMap

Deprecated, because HashMap<K, V> implements From<[(K, V); N]>, which makes hardcoding the contents of a hashmap simple without any macro.

macro_rules! count_tts {
    () => { 0 };
    ($odd:tt $($a:tt $b:tt)*) => { (count_tts!($($a)*) << 1) | 1 };
    ($($a:tt $even:tt)*) => { count_tts!($($a)*) << 1 };
}

// let map: HashMap<i64, i64> = hashmap![1,1; 2,2; 3,3];
macro_rules! hashmap {
    ($($k:expr,$v:expr);*) => {{
        let mut map = std::collections::HashMap::with_capacity(count_tts![$($k )*]);
        $( map.insert($k, $v); )*
        map
    }}
}
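The From<[(K, V); N]> replacement for the commented hashmap![1,1; 2,2; 3,3] example is a single call; make_table is a hypothetical wrapper used only to keep the sketch self-contained.

```rust
use std::collections::HashMap;

// Hardcoded lookup table built with `From<[(K, V); N]>` (stable since Rust 1.56).
fn make_table() -> HashMap<i64, i64> {
    HashMap::from([(1, 1), (2, 2), (3, 3)])
}
```

When the target type is already annotated, [(1, 1), (2, 2)].into() works as well.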

Last modified on 231008.