Number Theoretic Transform (NTT)
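The Number Theoretic Transform (NTT) is an alternative to the FFT that computes over \(\mathbb{Z}_p\) instead of the complex numbers, where \(p\) is a prime of the form \(p = a \times 2^b + 1\). This form guarantees that \(\mathbb{Z}_p\) contains a primitive \(2^k\)-th root of unity for every \(k \le b\), so transforms of any power-of-two length up to \(2^b\) are possible.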
ntt::convolute(&[u64], &[u64]) -> Vec<u64>
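Convolves two slices using NTTs modulo two different primes (\(P = 2{,}281{,}701{,}377\) and \(Q = 998{,}244{,}353\)) and combines the results with the Chinese Remainder Theorem (CRT). The result is exact as long as every coefficient of the true convolution is less than \(P \times Q \approx 2.28 \times 10^{18}\), which is smaller than the range of `u64`; larger coefficients are returned modulo \(P \times Q\). The returned vector is zero-padded to the smallest power of two that is at least `a.len() + b.len()`.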
Example
fn main() {
    let a: Vec<u64> = vec![1, 2];
    let b: Vec<u64> = vec![3, 4, 5];
    println!("{:?}", ntt::convolute(&a, &b)); // [3, 10, 13, 10, 0, 0, 0, 0]
}

mod ntt {
    // FFT_constname convention following https://algoshitpo.github.io/2020/05/20/fft-ntt/
    // p: prime for modulo
    // w: primitive root of p
    // p = a * 2^b + 1
    // p                ntt_a   ntt_b   ntt_w
    // 998,244,353      119     23      3
    // 2,281,701,377    17      27      3
    // 2,483,027,969    37      26      3
    // 2,113,929,217    63      25      5
    // 104,857,601      25      22      3
    // 1,092,616,193    521     21      3
    //
    // Every prime above is below 2^32, so the product of two residues always fits in a u64;
    // the modular arithmetic below relies on that.

    /// Returns `k` such that `2^k` is the smallest power of two that is at least `n`
    /// (`0` for `n <= 1`).
    fn ceil_pow2(n: usize) -> u32 {
        n.next_power_of_two().trailing_zeros()
    }

    /// Reverses the `k` trailing bits of `n`.
    ///
    /// Assumes that the remaining `usize::BITS - k` high bits of `n` are all zero. For
    /// `k == 0` the only such `n` is `0`, which is returned directly because shifting a
    /// `usize` right by `usize::BITS` would overflow.
    const fn reverse_trailing_bits(n: usize, k: u32) -> usize {
        if k == 0 {
            0
        } else {
            n.reverse_bits() >> (usize::BITS - k)
        }
    }

    /// A sequence of values modulo the prime `P` that can be transformed in place.
    #[derive(Clone, Debug)]
    pub struct Ntt<const P: u64> {
        pub arr: Vec<u64>,
    }

    impl<const P: u64> Ntt<P> {
        /// Odd part `a` of `P - 1 = a * 2^b`.
        pub const fn ntt_a() -> u64 {
            let p = P - 1;
            p >> p.trailing_zeros()
        }

        /// Exponent `b` of `P - 1 = a * 2^b`; transforms of length up to `2^b` are possible.
        pub const fn ntt_b() -> u32 {
            (P - 1).trailing_zeros()
        }

        /// Primitive root of `P`. Only the primes from the table above are known;
        /// any other `P` panics via `todo!()`.
        pub const fn ntt_w() -> u64 {
            match P {
                998244353 | 2281701377 | 2483027969 | 104857601 | 1092616193 => 3,
                2113929217 => 5,
                _ => todo!(),
            }
        }

        /// Computes `base^exp mod P` by square-and-multiply.
        ///
        /// `base` is reduced first, so any `u64` is accepted; `P < 2^32` keeps every
        /// intermediate product within `u64`.
        const fn pow(mut base: u64, mut exp: u64) -> u64 {
            base %= P;
            let mut ret = 1;
            while exp != 0 {
                if exp & 1 != 0 {
                    ret = ret * base % P;
                }
                base = base * base % P;
                exp >>= 1;
            }
            ret
        }

        /// Returns `r^k`, where `r = (w^a)^(2^(b - n))` with `w = ntt_w()`, `a = ntt_a()`,
        /// `b = ntt_b()`. Then `r^(2^n) == 1 mod P`, and `r` is a primitive `2^n`-th root of
        /// unity provided `w` is a primitive root of `P`.
        ///
        /// # Panics
        /// If `n > ntt_b()`, i.e. the requested transform is longer than `P` supports.
        const fn unity(n: u32, k: u64) -> u64 {
            assert!(n <= Self::ntt_b(), "transform length exceeds 2^ntt_b for this prime");
            Self::pow(Self::pow(Self::ntt_w(), Self::ntt_a()), k << (Self::ntt_b() - n))
        }

        /// Multiplicative inverse of `x` modulo `P` by Fermat's little theorem
        /// (returns `0` when `x` is a multiple of `P`).
        const fn recip(x: u64) -> u64 {
            Self::pow(x, P - 2)
        }

        /// Wraps `arr` without transforming it.
        pub fn new(arr: Vec<u64>) -> Self {
            Self { arr }
        }

        /// Forward transform of `arr` in place.
        ///
        /// `arr.len()` must be a power of two (checked in debug builds) no larger than
        /// `2^ntt_b()`; lengths 0 and 1 leave the values unchanged apart from reduction
        /// modulo `P`. All outputs are reduced modulo `P`.
        pub fn ntt(&mut self) {
            self.transform(false);
        }

        /// Inverse of [`Ntt::ntt`], including the final division by `arr.len()`.
        pub fn intt(&mut self) {
            self.transform(true);
            let r = Self::recip(self.arr.len() as u64);
            for f in self.arr.iter_mut() {
                *f = *f * r % P;
            }
        }

        /// Iterative radix-2 butterflies (bit-reversal permutation first) shared by `ntt`
        /// and `intt`; `inverse` selects the inverse root of unity.
        fn transform(&mut self, inverse: bool) {
            let n = self.arr.len();
            // Reduce up front so that every butterfly operates on values below `P`.
            for f in self.arr.iter_mut() {
                *f %= P;
            }
            // Transforms of length 0 or 1 are the identity; there are no butterflies to do.
            if n <= 1 {
                return;
            }
            debug_assert!(n.is_power_of_two(), "NTT length must be a power of two");
            let k = n.trailing_zeros();
            for i in 0..n {
                let j = reverse_trailing_bits(i, k);
                if i < j {
                    self.arr.swap(i, j);
                }
            }
            let root = Self::unity(k, 1);
            let root = if inverse { Self::recip(root) } else { root };
            // roots[x] = root^(2^(k-1-x)), a 2^(x+1)-th root of unity, used by the stage
            // whose butterflies pair elements 2^x apart. Exactly k stages are needed.
            let mut roots = vec![0; k as usize];
            let mut w = root;
            for r in roots.iter_mut().rev() {
                *r = w;
                w = w * w % P;
            }
            for (x, &base) in roots.iter().enumerate() {
                let s = 1usize << x;
                for i in (0..n).step_by(s << 1) {
                    let mut mult = 1;
                    for j in i..i + s {
                        let tmp = self.arr[j + s] * mult % P;
                        self.arr[j + s] = (self.arr[j] + P - tmp) % P;
                        self.arr[j] = (self.arr[j] + tmp) % P;
                        mult = mult * base % P;
                    }
                }
            }
        }

        /// Returns the linear convolution of `a` and `b` with every coefficient reduced
        /// modulo `P`, zero-padded to length `2^ceil_pow2(a.len() + b.len())`. That length
        /// exceeds `a.len() + b.len() - 1`, so the cyclic transform never wraps around.
        pub fn convolute(a: &[u64], b: &[u64]) -> Self {
            let nlen = 1usize << ceil_pow2(a.len() + b.len());
            let pad = |v: &[u64]| -> Vec<u64> {
                v.iter().copied().chain(std::iter::repeat(0)).take(nlen).collect()
            };
            let mut arr = Self::new(pad(a));
            let mut brr = Self::new(pad(b));
            arr.ntt();
            brr.ntt();
            let crr = arr.arr.iter().zip(&brr.arr).map(|(&x, &y)| x * y % P).collect();
            let mut crr = Self::new(crr);
            crr.intt();
            crr
        }
    }

    /// Combines residues modulo `P` (`one`) and modulo `Q` (`two`) into the unique residue
    /// modulo `P * Q` via the Chinese Remainder Theorem.
    ///
    /// `P` and `Q` must be distinct primes below `2^32` (so `P * Q` fits in `u64`). The
    /// entries are expected to be reduced modulo their prime, as `Ntt::convolute` produces,
    /// which keeps the `u128` arithmetic from overflowing.
    fn merge<const P: u64, const Q: u64>(one: &[u64], two: &[u64]) -> Vec<u64> {
        debug_assert_eq!(one.len(), two.len());
        let (pp, qq) = (P as u128, Q as u128);
        let modulus = pp * qq;
        // p_inv = P^{-1} mod Q and q_inv = Q^{-1} mod P.
        let p_inv = Ntt::<Q>::recip(P) as u128;
        let q_inv = Ntt::<P>::recip(Q) as u128;
        one.iter()
            .zip(two)
            .map(|(&a1, &a2)| {
                let (a1, a2) = (a1 as u128, a2 as u128);
                // a1 * Q * q_inv is a1 (mod P) and 0 (mod Q); symmetrically for a2.
                ((a1 * q_inv * qq + a2 * p_inv * pp) % modulus) as u64
            })
            .collect()
    }

    /// Linear convolution of `a` and `b`, computed with NTTs modulo two primes whose
    /// results are combined by the CRT.
    ///
    /// The result has length `2^ceil_pow2(a.len() + b.len())`; entries past the true
    /// convolution are zero. It is exact as long as every true coefficient is below
    /// `2_281_701_377 * 998_244_353` (about 2.28e18); larger coefficients come back reduced
    /// modulo that product.
    ///
    /// # Panics
    /// If the padded length exceeds `2^23`, the largest transform 998,244,353 supports.
    pub fn convolute(a: &[u64], b: &[u64]) -> Vec<u64> {
        const P: u64 = 2281701377;
        const Q: u64 = 998244353;
        let arr = Ntt::<P>::convolute(a, b);
        let brr = Ntt::<Q>::convolute(a, b);
        merge::<P, Q>(&arr.arr, &brr.arr)
    }
}
Code
mod ntt {
    // FFT_constname convention following https://algoshitpo.github.io/2020/05/20/fft-ntt/
    // p: prime for modulo
    // w: primitive root of p
    // p = a * 2^b + 1
    // p                ntt_a   ntt_b   ntt_w
    // 998,244,353      119     23      3
    // 2,281,701,377    17      27      3
    // 2,483,027,969    37      26      3
    // 2,113,929,217    63      25      5
    // 104,857,601      25      22      3
    // 1,092,616,193    521     21      3
    //
    // Every prime above is below 2^32, so the product of two residues always fits in a u64;
    // the modular arithmetic below relies on that.

    /// Returns `k` such that `2^k` is the smallest power of two that is at least `n`
    /// (`0` for `n <= 1`).
    fn ceil_pow2(n: usize) -> u32 {
        n.next_power_of_two().trailing_zeros()
    }

    /// Reverses the `k` trailing bits of `n`.
    ///
    /// Assumes that the remaining `usize::BITS - k` high bits of `n` are all zero. For
    /// `k == 0` the only such `n` is `0`, which is returned directly because shifting a
    /// `usize` right by `usize::BITS` would overflow.
    const fn reverse_trailing_bits(n: usize, k: u32) -> usize {
        if k == 0 {
            0
        } else {
            n.reverse_bits() >> (usize::BITS - k)
        }
    }

    /// A sequence of values modulo the prime `P` that can be transformed in place.
    #[derive(Clone, Debug)]
    pub struct Ntt<const P: u64> {
        pub arr: Vec<u64>,
    }

    impl<const P: u64> Ntt<P> {
        /// Odd part `a` of `P - 1 = a * 2^b`.
        pub const fn ntt_a() -> u64 {
            let p = P - 1;
            p >> p.trailing_zeros()
        }

        /// Exponent `b` of `P - 1 = a * 2^b`; transforms of length up to `2^b` are possible.
        pub const fn ntt_b() -> u32 {
            (P - 1).trailing_zeros()
        }

        /// Primitive root of `P`. Only the primes from the table above are known;
        /// any other `P` panics via `todo!()`.
        pub const fn ntt_w() -> u64 {
            match P {
                998244353 | 2281701377 | 2483027969 | 104857601 | 1092616193 => 3,
                2113929217 => 5,
                _ => todo!(),
            }
        }

        /// Computes `base^exp mod P` by square-and-multiply.
        ///
        /// `base` is reduced first, so any `u64` is accepted; `P < 2^32` keeps every
        /// intermediate product within `u64`.
        const fn pow(mut base: u64, mut exp: u64) -> u64 {
            base %= P;
            let mut ret = 1;
            while exp != 0 {
                if exp & 1 != 0 {
                    ret = ret * base % P;
                }
                base = base * base % P;
                exp >>= 1;
            }
            ret
        }

        /// Returns `r^k`, where `r = (w^a)^(2^(b - n))` with `w = ntt_w()`, `a = ntt_a()`,
        /// `b = ntt_b()`. Then `r^(2^n) == 1 mod P`, and `r` is a primitive `2^n`-th root of
        /// unity provided `w` is a primitive root of `P`.
        ///
        /// # Panics
        /// If `n > ntt_b()`, i.e. the requested transform is longer than `P` supports.
        const fn unity(n: u32, k: u64) -> u64 {
            assert!(n <= Self::ntt_b(), "transform length exceeds 2^ntt_b for this prime");
            Self::pow(Self::pow(Self::ntt_w(), Self::ntt_a()), k << (Self::ntt_b() - n))
        }

        /// Multiplicative inverse of `x` modulo `P` by Fermat's little theorem
        /// (returns `0` when `x` is a multiple of `P`).
        const fn recip(x: u64) -> u64 {
            Self::pow(x, P - 2)
        }

        /// Wraps `arr` without transforming it.
        pub fn new(arr: Vec<u64>) -> Self {
            Self { arr }
        }

        /// Forward transform of `arr` in place.
        ///
        /// `arr.len()` must be a power of two (checked in debug builds) no larger than
        /// `2^ntt_b()`; lengths 0 and 1 leave the values unchanged apart from reduction
        /// modulo `P`. All outputs are reduced modulo `P`.
        pub fn ntt(&mut self) {
            self.transform(false);
        }

        /// Inverse of [`Ntt::ntt`], including the final division by `arr.len()`.
        pub fn intt(&mut self) {
            self.transform(true);
            let r = Self::recip(self.arr.len() as u64);
            for f in self.arr.iter_mut() {
                *f = *f * r % P;
            }
        }

        /// Iterative radix-2 butterflies (bit-reversal permutation first) shared by `ntt`
        /// and `intt`; `inverse` selects the inverse root of unity.
        fn transform(&mut self, inverse: bool) {
            let n = self.arr.len();
            // Reduce up front so that every butterfly operates on values below `P`.
            for f in self.arr.iter_mut() {
                *f %= P;
            }
            // Transforms of length 0 or 1 are the identity; there are no butterflies to do.
            if n <= 1 {
                return;
            }
            debug_assert!(n.is_power_of_two(), "NTT length must be a power of two");
            let k = n.trailing_zeros();
            for i in 0..n {
                let j = reverse_trailing_bits(i, k);
                if i < j {
                    self.arr.swap(i, j);
                }
            }
            let root = Self::unity(k, 1);
            let root = if inverse { Self::recip(root) } else { root };
            // roots[x] = root^(2^(k-1-x)), a 2^(x+1)-th root of unity, used by the stage
            // whose butterflies pair elements 2^x apart. Exactly k stages are needed.
            let mut roots = vec![0; k as usize];
            let mut w = root;
            for r in roots.iter_mut().rev() {
                *r = w;
                w = w * w % P;
            }
            for (x, &base) in roots.iter().enumerate() {
                let s = 1usize << x;
                for i in (0..n).step_by(s << 1) {
                    let mut mult = 1;
                    for j in i..i + s {
                        let tmp = self.arr[j + s] * mult % P;
                        self.arr[j + s] = (self.arr[j] + P - tmp) % P;
                        self.arr[j] = (self.arr[j] + tmp) % P;
                        mult = mult * base % P;
                    }
                }
            }
        }

        /// Returns the linear convolution of `a` and `b` with every coefficient reduced
        /// modulo `P`, zero-padded to length `2^ceil_pow2(a.len() + b.len())`. That length
        /// exceeds `a.len() + b.len() - 1`, so the cyclic transform never wraps around.
        pub fn convolute(a: &[u64], b: &[u64]) -> Self {
            let nlen = 1usize << ceil_pow2(a.len() + b.len());
            let pad = |v: &[u64]| -> Vec<u64> {
                v.iter().copied().chain(std::iter::repeat(0)).take(nlen).collect()
            };
            let mut arr = Self::new(pad(a));
            let mut brr = Self::new(pad(b));
            arr.ntt();
            brr.ntt();
            let crr = arr.arr.iter().zip(&brr.arr).map(|(&x, &y)| x * y % P).collect();
            let mut crr = Self::new(crr);
            crr.intt();
            crr
        }
    }

    /// Combines residues modulo `P` (`one`) and modulo `Q` (`two`) into the unique residue
    /// modulo `P * Q` via the Chinese Remainder Theorem.
    ///
    /// `P` and `Q` must be distinct primes below `2^32` (so `P * Q` fits in `u64`). The
    /// entries are expected to be reduced modulo their prime, as `Ntt::convolute` produces,
    /// which keeps the `u128` arithmetic from overflowing.
    fn merge<const P: u64, const Q: u64>(one: &[u64], two: &[u64]) -> Vec<u64> {
        debug_assert_eq!(one.len(), two.len());
        let (pp, qq) = (P as u128, Q as u128);
        let modulus = pp * qq;
        // p_inv = P^{-1} mod Q and q_inv = Q^{-1} mod P.
        let p_inv = Ntt::<Q>::recip(P) as u128;
        let q_inv = Ntt::<P>::recip(Q) as u128;
        one.iter()
            .zip(two)
            .map(|(&a1, &a2)| {
                let (a1, a2) = (a1 as u128, a2 as u128);
                // a1 * Q * q_inv is a1 (mod P) and 0 (mod Q); symmetrically for a2.
                ((a1 * q_inv * qq + a2 * p_inv * pp) % modulus) as u64
            })
            .collect()
    }

    /// Linear convolution of `a` and `b`, computed with NTTs modulo two primes whose
    /// results are combined by the CRT.
    ///
    /// The result has length `2^ceil_pow2(a.len() + b.len())`; entries past the true
    /// convolution are zero. It is exact as long as every true coefficient is below
    /// `2_281_701_377 * 998_244_353` (about 2.28e18); larger coefficients come back reduced
    /// modulo that product.
    ///
    /// # Panics
    /// If the padded length exceeds `2^23`, the largest transform 998,244,353 supports.
    pub fn convolute(a: &[u64], b: &[u64]) -> Vec<u64> {
        const P: u64 = 2281701377;
        const Q: u64 = 998244353;
        let arr = Ntt::<P>::convolute(a, b);
        let brr = Ntt::<Q>::convolute(a, b);
        merge::<P, Q>(&arr.arr, &brr.arr)
    }
}
Last modified on 231203.