Source code

Revision control

Copy as Markdown

Other Tools

// Copyright (c) 2020 Apple Inc.
// SPDX-License-Identifier: MPL-2.0
//! Functions for polynomial interpolation and evaluation
#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish};
use crate::field::FftFriendlyFieldElement;
use std::convert::TryFrom;
/// Temporary memory used for FFT
///
/// Bundles the three scratch buffers that `fft_interpolate_raw` hands to `fft_recurse`.
/// `PolyFFTTempMemory::new` allocates all three with the same length.
#[derive(Clone, Debug)]
pub struct PolyFFTTempMemory<F> {
/// Passed to `fft_recurse` as `tmp`; receives the output of each half-size sub-transform.
fft_tmp: Vec<F>,
/// Passed to `fft_recurse` as `y_sub`; holds the combined inputs fed to each sub-transform.
fft_y_sub: Vec<F>,
/// Passed to `fft_recurse` as `roots_sub`; holds every second root for each sub-transform.
fft_roots_sub: Vec<F>,
}
impl<F: FftFriendlyFieldElement> PolyFFTTempMemory<F> {
    /// Creates scratch space whose three buffers each hold `length` copies of `F::zero()`.
    fn new(length: usize) -> Self {
        // Every buffer starts out identical, so build them from one recipe.
        let zeroed = || vec![F::zero(); length];
        Self {
            fft_tmp: zeroed(),
            fft_y_sub: zeroed(),
            fft_roots_sub: zeroed(),
        }
    }
}
/// Auxiliary memory for polynomial interpolation and evaluation
///
/// `PolyAuxMemory::new(n)` fills the four root tables with `fft_get_roots` and sizes the
/// remaining buffers to `2 * n` elements.
#[derive(Clone, Debug)]
pub struct PolyAuxMemory<F> {
/// Result of `fft_get_roots(2 * n, false)`.
pub roots_2n: Vec<F>,
/// Result of `fft_get_roots(2 * n, true)`, i.e. built from the inverted generator.
pub roots_2n_inverted: Vec<F>,
/// Result of `fft_get_roots(n, false)`.
pub roots_n: Vec<F>,
/// Result of `fft_get_roots(n, true)`, i.e. built from the inverted generator.
pub roots_n_inverted: Vec<F>,
/// Buffer of `2 * n` elements, initialised to zero.
pub coeffs: Vec<F>,
/// FFT scratch buffers, each `2 * n` elements long.
pub fft_memory: PolyFFTTempMemory<F>,
}
impl<F: FftFriendlyFieldElement> PolyAuxMemory<F> {
    /// Builds the auxiliary memory for size parameter `n`.
    ///
    /// The root tables of length `2 * n` and `n` (and their inverted-generator counterparts)
    /// come from `fft_get_roots`; `coeffs` and the FFT scratch buffers are zero-filled with
    /// `2 * n` elements each.
    pub fn new(n: usize) -> Self {
        let double_n = 2 * n;
        Self {
            roots_2n: fft_get_roots(double_n, false),
            roots_2n_inverted: fft_get_roots(double_n, true),
            roots_n: fft_get_roots(n, false),
            roots_n_inverted: fft_get_roots(n, true),
            coeffs: vec![F::zero(); double_n],
            fft_memory: PolyFFTTempMemory::new(double_n),
        }
    }
}
/// Recursive transform of `ys[..n]` using the root table `roots`, written to `out[..n]`.
///
/// Each level folds the input into two half-size problems: the sums `ys[i] + ys[i + n/2]`
/// produce the even-indexed outputs, and the differences `(ys[i] - ys[i + n/2]) * roots[i]`
/// produce the odd-indexed outputs. Both sub-problems use every second entry of `roots`.
///
/// `tmp`, `y_sub` and `roots_sub` are scratch buffers. The first `n / 2` entries of each are
/// used at this level and the remainder is handed down to the recursive calls.
fn fft_recurse<F: FftFriendlyFieldElement>(
    out: &mut [F],
    n: usize,
    roots: &[F],
    ys: &[F],
    tmp: &mut [F],
    y_sub: &mut [F],
    roots_sub: &mut [F],
) {
    // A single point is its own transform.
    if n == 1 {
        out[0] = ys[0];
        return;
    }

    let half = n / 2;
    let (tmp_here, tmp_rest) = tmp.split_at_mut(half);
    let (y_here, y_rest) = y_sub.split_at_mut(half);
    let (roots_here, roots_rest) = roots_sub.split_at_mut(half);

    // Even outputs: transform the pairwise sums with every second root.
    for (k, (y, root)) in y_here.iter_mut().zip(roots_here.iter_mut()).enumerate() {
        *y = ys[k] + ys[k + half];
        *root = roots[2 * k];
    }
    fft_recurse(tmp_here, half, roots_here, y_here, tmp_rest, y_rest, roots_rest);
    for (k, &value) in tmp_here.iter().enumerate() {
        out[2 * k] = value;
    }

    // Odd outputs: transform the twisted pairwise differences. `roots_here` is still valid
    // because the recursion above only wrote to `roots_rest`.
    for (k, y) in y_here.iter_mut().enumerate() {
        *y = ys[k] - ys[k + half];
        *y *= roots[k];
    }
    fft_recurse(tmp_here, half, roots_here, y_here, tmp_rest, y_rest, roots_rest);
    for (k, &value) in tmp_here.iter().enumerate() {
        out[2 * k + 1] = value;
    }
}
/// Calculate `count` number of roots of unity of order `count`
///
/// Entry `i` of the result is `omega^i`, where `omega` is `F::generator()` (or its inverse
/// when `invert` is set) raised to `F::generator_order() / count`. The division is an integer
/// division, so `count` is presumably expected to divide the generator order.
///
/// `count == 0` yields an empty vector and `count == 1` yields `[F::one()]`.
///
/// # Panics
///
/// Panics if `count` cannot be converted to `F::Integer`.
fn fft_get_roots<F: FftFriendlyFieldElement>(count: usize, invert: bool) -> Vec<F> {
    let mut roots = vec![F::zero(); count];

    // Nothing to fill in. Returning here also avoids dividing the generator order by zero.
    if count == 0 {
        return roots;
    }
    roots[0] = F::one();
    // A single root is just one; there is no `roots[1]` slot to write the generator into.
    if count == 1 {
        return roots;
    }

    let base = if invert {
        F::generator().inv()
    } else {
        F::generator()
    };
    let step_size = F::generator_order() / F::Integer::try_from(count).unwrap();
    // Generator for the subgroup of order `count`. The local is deliberately not called
    // `gen`, which is a reserved keyword in the 2024 edition.
    let omega = base.pow(step_size);

    roots[1] = omega;
    for i in 2..count {
        roots[i] = omega * roots[i - 1];
    }
    roots
}
fn fft_interpolate_raw<F: FftFriendlyFieldElement>(
out: &mut [F],
ys: &[F],
n_points: usize,
roots: &[F],
invert: bool,
mem: &mut PolyFFTTempMemory<F>,
) {
fft_recurse(
out,
n_points,
roots,
ys,
&mut mem.fft_tmp,
&mut mem.fft_y_sub,
&mut mem.fft_roots_sub,
);
if invert {
let n_inverse = F::from(F::Integer::try_from(n_points).unwrap()).inv();
for out_val in out[0..n_points].iter_mut() {
*out_val *= n_inverse;
}
}
}
/// Transforms `points_in` with the recursive FFT and writes the result to `points_out`.
///
/// This is a thin public wrapper that forwards its arguments to `fft_interpolate_raw`.
///
/// * `points_out` - receives the `n_points` transformed values.
/// * `points_in` - the `n_points` input values.
/// * `scaled_roots` - root table used by the transform (the tests in this file pass
///   `PolyAuxMemory::roots_2n` or `PolyAuxMemory::roots_2n_inverted`).
/// * `n_points` - size of the transform.
/// * `invert` - when `true`, the first `n_points` outputs are also multiplied by the field
///   inverse of `n_points`.
/// * `mem` - scratch buffers used by the recursion.
pub fn poly_fft<F: FftFriendlyFieldElement>(
points_out: &mut [F],
points_in: &[F],
scaled_roots: &[F],
n_points: usize,
invert: bool,
mem: &mut PolyFFTTempMemory<F>,
) {
fft_interpolate_raw(points_out, points_in, n_points, scaled_roots, invert, mem)
}
/// Evaluates the polynomial `poly` at `eval_at` using Horner's method.
///
/// `poly[i]` is the coefficient of `x^i`. An empty slice evaluates to `F::zero()`.
pub fn poly_eval<F: FftFriendlyFieldElement>(poly: &[F], eval_at: F) -> F {
    match poly.split_last() {
        None => F::zero(),
        // Start from the highest coefficient and fold the lower ones in, highest first.
        Some((&leading, lower)) => lower.iter().rev().fold(leading, |mut acc, &coeff| {
            acc *= eval_at;
            acc += coeff;
            acc
        }),
    }
}
/// Returns the degree of polynomial `p`.
///
/// This is the index of the last coefficient that differs from `F::zero()`. An empty or
/// all-zero slice has degree `0`.
pub fn poly_deg<F: FftFriendlyFieldElement>(p: &[F]) -> usize {
    p.iter()
        .rposition(|coeff| *coeff != F::zero())
        .unwrap_or(0)
}
/// Multiplies polynomials `p` and `q` and returns the result.
///
/// Coefficients are stored lowest degree first. Trailing zero coefficients of the inputs are
/// ignored and the result carries no trailing zeros, except that the zero polynomial is
/// returned as the single coefficient `[F::zero()]`.
///
/// An empty slice is treated as the zero polynomial, matching `poly_eval` and `poly_deg`.
pub fn poly_mul<F: FftFriendlyFieldElement>(p: &[F], q: &[F]) -> Vec<F> {
    // `poly_deg` reports degree 0 for an empty slice, so indexing coefficient 0 below would
    // go out of bounds. Return the zero polynomial instead.
    if p.is_empty() || q.is_empty() {
        return vec![F::zero()];
    }

    // Only the coefficients up to each polynomial's degree contribute to the product.
    let p_sig = &p[..=poly_deg(p)];
    let q_sig = &q[..=poly_deg(q)];

    // A product of degrees `a` and `b` has at most `a + b + 1` coefficients.
    let mut out = vec![F::zero(); p_sig.len() + q_sig.len() - 1];
    for (i, &a) in p_sig.iter().enumerate() {
        for (j, &b) in q_sig.iter().enumerate() {
            out[i + j] += a * b;
        }
    }

    // Collapses an all-zero product down to a single zero coefficient.
    out.truncate(poly_deg(&out) + 1);
    out
}
/// Treats `points` as input to `discrete_fourier_transform` and evaluates the resulting
/// coefficients at `eval_at`.
///
/// The transform is written to `tmp_coeffs`, completed in place by
/// `discrete_fourier_transform_inv_finish` with the field inverse of `points.len()`, and the
/// first `points.len()` entries are then evaluated with `poly_eval`.
///
/// # Panics
///
/// Panics if `points.len()` cannot be converted to `F::Integer`, if
/// `discrete_fourier_transform` reports an error, or if `tmp_coeffs` is shorter than `points`.
#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
#[inline]
pub fn poly_interpret_eval<F: FftFriendlyFieldElement>(
    points: &[F],
    eval_at: F,
    tmp_coeffs: &mut [F],
) -> F {
    let n_points = points.len();
    let size_inv = F::from(F::Integer::try_from(n_points).unwrap()).inv();

    discrete_fourier_transform(tmp_coeffs, points, n_points).unwrap();
    discrete_fourier_transform_inv_finish(tmp_coeffs, n_points, size_inv);

    let coeffs = &tmp_coeffs[..n_points];
    poly_eval(coeffs, eval_at)
}
/// Returns a polynomial that evaluates to `0` if the input is in range `[start, end)`.
/// Otherwise, the output is not `0`.
///
/// The result is the product of the linear factors `(x - i)` for every `i` in
/// `start..end`; an empty range yields the constant polynomial `1`.
///
/// # Panics
///
/// Panics if a value in the range cannot be converted to `F::Integer`.
pub(crate) fn poly_range_check<F: FftFriendlyFieldElement>(start: usize, end: usize) -> Vec<F> {
    (start..end).fold(vec![F::one()], |product, root| {
        // Linear factor `x - root`, lowest coefficient first.
        let factor = [-F::from(F::Integer::try_from(root).unwrap()), F::one()];
        poly_mul(&product, &factor)
    })
}
#[cfg(test)]
mod tests {
use crate::{
field::{
FftFriendlyFieldElement, Field64, FieldElement, FieldElementWithInteger, FieldPrio2,
},
polynomial::{
fft_get_roots, poly_deg, poly_eval, poly_fft, poly_mul, poly_range_check, PolyAuxMemory,
},
};
use rand::prelude::*;
use std::convert::TryFrom;
#[test]
fn test_roots() {
    // Every generated root must have order dividing `count`, and the inverted table must
    // hold the element-wise inverses.
    let count = 128;
    let exponent = u32::try_from(count).unwrap();
    let roots = fft_get_roots::<FieldPrio2>(count, false);
    let roots_inv = fft_get_roots::<FieldPrio2>(count, true);

    for (root, root_inv) in roots.iter().zip(roots_inv.iter()) {
        assert_eq!(*root * *root_inv, 1);
        assert_eq!(root.pow(exponent), 1);
        assert_eq!(root_inv.pow(exponent), 1);
    }
}
#[test]
fn test_eval() {
    // Coefficients are lowest degree first: 2 + x + 5x^2 + 4x^3.
    let poly: [FieldPrio2; 4] = [2.into(), 1.into(), 5.into(), 4.into()];

    // 5*3^2 + 3 + 2 = 50
    assert_eq!(poly_eval(&poly[..3], 3.into()), 50);
    // 4*3^3 + 5*3^2 + 3 + 2 = 158
    assert_eq!(poly_eval(&poly[..4], 3.into()), 158);
}
#[test]
fn test_poly_deg() {
    let zero = FieldPrio2::zero();
    let one = FieldPrio2::root(0).unwrap();

    // (coefficients, expected degree); trailing zeros must not raise the degree.
    let cases: [(&[FieldPrio2], usize); 8] = [
        (&[zero][..], 0),
        (&[one][..], 0),
        (&[zero, one][..], 1),
        (&[zero, zero, one][..], 2),
        (&[zero, one, one][..], 2),
        (&[zero, one, one, one][..], 3),
        (&[zero, one, one, one, zero][..], 3),
        (&[zero, one, one, one, zero, zero][..], 3),
    ];
    for &(poly, want) in cases.iter() {
        assert_eq!(poly_deg(poly), want);
    }
}
#[test]
fn test_poly_mul() {
    // (2 + 3x) * (1 + 5x^2) = 2 + 3x + 10x^2 + 15x^3
    let p = [2u64, 3].map(Field64::from);
    let q = [Field64::one(), Field64::zero(), Field64::from(5u64)];
    let want = [2u64, 3, 10, 15].map(Field64::from);

    let got = poly_mul(&p, &q);
    assert_eq!(&got, &want);
}
#[test]
fn test_poly_range_check() {
    let start = 74;
    let end = 112;
    let p = poly_range_check(start, end);
    // Evaluates the range-check polynomial at an integer point.
    let eval_at = |value: usize| poly_eval(&p, Field64::from(value as u64));

    // Check each number in the range.
    for i in start..end {
        assert_eq!(eval_at(i), Field64::zero(), "range check failed for {i}");
    }

    // Check the number below the range.
    assert_ne!(eval_at(start - 1), Field64::zero());

    // Check a number above the range.
    assert_ne!(eval_at(end), Field64::zero());
}
#[test]
fn test_fft() {
    let count = 128;
    let mut mem = PolyAuxMemory::new(count / 2);

    let mut poly = vec![FieldPrio2::from(0); count];
    let mut points2 = vec![FieldPrio2::from(0); count];
    let points: Vec<FieldPrio2> = (0..count)
        .map(|_| FieldPrio2::from(random::<u32>()))
        .collect();

    // Round trip: forward transform, then the inverse transform must restore the input.
    poly_fft(&mut poly, &points, &mem.roots_2n, count, false, &mut mem.fft_memory);
    poly_fft(
        &mut points2,
        &poly,
        &mem.roots_2n_inverted,
        count,
        true,
        &mut mem.fft_memory,
    );
    assert_eq!(points, points2);

    // Interpolation: each output must equal the naive sum over powers of its root.
    poly_fft(&mut poly, &points, &mem.roots_2n, count, false, &mut mem.fft_memory);
    for (poly_coeff, root) in poly[..count].iter().zip(mem.roots_2n[..count].iter()) {
        let should_be = points[..count]
            .iter()
            .enumerate()
            .fold(FieldPrio2::from(0), |acc, (j, point_j)| {
                root.pow(u32::try_from(j).unwrap()) * *point_j + acc
            });
        assert_eq!(should_be, *poly_coeff);
    }
}
}