Revision control
Copy as Markdown
Other Tools
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use alloc::vec::Vec;
use core::{fmt::Debug, hash::Hash};
use mls_rs_codec::{MlsDecode, MlsEncode};
use super::node::LeafIndex;
pub trait TreeIndex:
Send + Sync + Eq + Clone + Debug + Default + MlsEncode + MlsDecode + Hash + Ord
{
fn root(&self) -> Self;
fn left_unchecked(&self) -> Self;
fn right_unchecked(&self) -> Self;
fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>>;
fn is_leaf(&self) -> bool;
fn is_in_tree(&self, root: &Self) -> bool;
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
fn zero() -> Self;
#[cfg(any(feature = "secret_tree_access", feature = "private_message", test))]
fn left(&self) -> Option<Self> {
(!self.is_leaf()).then(|| self.left_unchecked())
}
#[cfg(any(feature = "secret_tree_access", feature = "private_message", test))]
fn right(&self) -> Option<Self> {
(!self.is_leaf()).then(|| self.right_unchecked())
}
fn direct_copath(&self, leaf_count: &Self) -> Vec<CopathNode<Self>> {
let root = leaf_count.root();
if !self.is_in_tree(&root) {
return Vec::new();
}
let mut path = Vec::new();
let mut parent = self.clone();
while let Some(ps) = parent.parent_sibling(leaf_count) {
path.push(CopathNode::new(ps.parent.clone(), ps.sibling));
parent = ps.parent;
}
path
}
}
/// One step of a walk towards the root: a node on the path paired with the
/// matching node on the copath.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CopathNode<T> {
    /// Node on the path.
    pub path: T,
    /// Node on the copath.
    pub copath: T,
}

impl<T> CopathNode<T>
where
    T: Clone + PartialEq + Eq + core::fmt::Debug,
{
    /// Builds a `CopathNode` from its `path` and `copath` components.
    pub fn new(path: T, copath: T) -> CopathNode<T> {
        Self { path, copath }
    }
}
/// The parent of a node together with that node's sibling.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ParentSibling<T> {
    /// Parent node.
    pub parent: T,
    /// Sibling node.
    pub sibling: T,
}

impl<T> ParentSibling<T>
where
    T: Clone + PartialEq + Eq + core::fmt::Debug,
{
    /// Builds a `ParentSibling` from its `parent` and `sibling` components.
    pub fn new(parent: T, sibling: T) -> ParentSibling<T> {
        Self { parent, sibling }
    }
}
/// Implements [`TreeIndex`] for an unsigned integer type.
///
/// Leaves are the even indices. A parent's level is the length of its run of
/// trailing one bits.
macro_rules! impl_tree_stdint {
    ($t:ty) => {
        impl TreeIndex for $t {
            /// Treats `self` as a leaf count and returns `self - 1`.
            fn root(&self) -> $t {
                *self - 1
            }

            /// Flips the highest bit of the trailing run of ones.
            ///
            /// Panics in debug builds if `self` is even, because the shift
            /// amount underflows. In release builds it wraps instead.
            fn left_unchecked(&self) -> Self {
                let shift = self.trailing_ones() - 1;
                *self ^ (0x01 << shift)
            }

            /// Flips the highest bit of the trailing run of ones and the bit
            /// just above it.
            ///
            /// Panics in debug builds if `self` is even, because the shift
            /// amount underflows. In release builds it wraps instead.
            fn right_unchecked(&self) -> Self {
                let shift = self.trailing_ones() - 1;
                *self ^ (0x03 << shift)
            }

            fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>> {
                // The root has neither a parent nor a sibling.
                if *self == leaf_count.root() {
                    return None;
                }

                let level = self.trailing_ones();

                // Clear bit `level + 1` and set bit `level` to get the parent.
                let parent = (*self & !(1 << (level + 1))) | (1 << level);

                // The sibling is whichever child of the parent `self` is not.
                let sibling = if parent > *self {
                    parent.right_unchecked()
                } else {
                    parent.left_unchecked()
                };

                Some(ParentSibling::new(parent, sibling))
            }

            fn is_leaf(&self) -> bool {
                *self % 2 == 0
            }

            fn is_in_tree(&self, root: &Self) -> bool {
                let last_index = 2 * root;
                *self <= last_index
            }

            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
            fn zero() -> Self {
                0
            }
        }
    };
}

impl_tree_stdint!(u32);

#[cfg(test)]
impl_tree_stdint!(u64);
/// Returns how many times `x` and `y` must both be shifted right by one bit
/// before they become equal (`0` when they already are).
///
/// NOTE(review): judging by the name, callers pass two leaf indices and use
/// the result as the level of their lowest common ancestor — confirm.
pub fn leaf_lca_level(x: u32, y: u32) -> u32 {
    // The values agree once every differing bit has been shifted out, so the
    // answer is the position of the highest set bit of `x ^ y`, counted from 1.
    u32::BITS - (x ^ y).leading_zeros()
}
/// Returns a pair of leaf indices computed from node index `x` and the width
/// `1 << x.trailing_ones()`.
///
/// For a leaf node `x = 2k` the result is `(k, k + 1)`, and for node `3` it is
/// `(0, 4)`. The first element is the lowest leaf under `x` and the second is
/// one past the highest.
pub fn subtree(x: u32) -> (LeafIndex, LeafIndex) {
    let span = 1 << x.trailing_ones();

    let first = (x + 1 - span) >> 1;
    let end = ((x + span) >> 1) + 1;

    (LeafIndex(first), LeafIndex(end))
}
/// Iterator over the node indices of a tree in breadth-first order, starting
/// at the root and visiting each level from left to right.
///
/// For 8 leaves it yields `7, 3, 11, 1, 5, 9, 13, 0, 2, ..., 14`.
#[derive(Clone, Debug)]
pub struct BfsIterTopDown {
    /// Shift applied to `ctr` on the current level (level height plus one).
    level: usize,
    /// Low bits shared by every node index on the current level.
    mask: usize,
    /// Number of nodes on the current level.
    level_end: usize,
    /// Position of the next node within the current level.
    ctr: usize,
}

impl BfsIterTopDown {
    /// Creates an iterator for a tree with `num_leaves` leaves.
    ///
    /// The depth is taken from `num_leaves.trailing_zeros()`, so the traversal
    /// only covers the whole tree when `num_leaves` is a power of two.
    /// `num_leaves == 0` produces an iterator that yields nothing.
    pub fn new(num_leaves: usize) -> Self {
        if num_leaves == 0 {
            // `0usize.trailing_zeros()` equals the bit width, which would make
            // the shift below overflow. Start in the exhausted state instead.
            return Self {
                level: 1,
                mask: 0,
                level_end: 0,
                ctr: 0,
            };
        }

        let depth = num_leaves.trailing_zeros() as usize;

        Self {
            level: depth + 1,
            mask: (1 << depth) - 1,
            level_end: 1,
            ctr: 0,
        }
    }
}

impl Iterator for BfsIterTopDown {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ctr == self.level_end {
            // Level 1 is the leaf level, so there is nothing below it.
            if self.level == 1 {
                return None;
            }

            // Descend one level: twice as many nodes, one fewer trailing one.
            self.level_end <<= 1;
            self.level -= 1;
            self.ctr = 0;
            self.mask >>= 1;
        }

        let index = (self.ctr << self.level) | self.mask;
        self.ctr += 1;

        Some(index)
    }
}
// Unit tests for the tree index arithmetic and the breadth-first iterator.
#[cfg(test)]
mod tests {
use super::*;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
// On wasm32 targets `#[test]` below resolves to the wasm_bindgen_test attribute.
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test as test;
// One serialized test vector: the expected tree-math results for a tree with
// `n_leaves` leaves. The `left`, `right`, `parent` and `sibling` vectors are
// indexed by node index, and `None` marks a node with no such relative.
#[derive(Serialize, Deserialize)]
struct TestCase {
n_leaves: u32,
n_nodes: u32,
root: u32,
left: Vec<Option<u32>>,
right: Vec<Option<u32>>,
parent: Vec<Option<u32>>,
sibling: Vec<Option<u32>>,
}
// Number of nodes in a tree with `n` leaves: 0 for an empty tree, otherwise
// 2 * (n - 1) + 1.
pub fn node_width(n: u32) -> u32 {
if n == 0 {
0
} else {
2 * (n - 1) + 1
}
}
// For 8 leaves the iterator must yield the root (7) first and then every
// lower level from left to right, ending with the even (leaf) indices.
#[test]
fn test_bfs_iterator() {
let expected = [7, 3, 11, 1, 5, 9, 13, 0, 2, 4, 6, 8, 10, 12, 14];
let bfs = BfsIterTopDown::new(8);
assert_eq!(bfs.collect::<Vec<_>>(), expected);
}
// Builds test vectors for trees with 1, 2, 4, ..., 128 leaves by evaluating
// the functions under test on every node index of each tree.
#[cfg_attr(coverage_nightly, coverage(off))]
fn generate_tree_math_test_cases() -> Vec<TestCase> {
let mut test_cases = Vec::new();
for log_n_leaves in 0..8 {
let n_leaves = 1 << log_n_leaves;
let n_nodes = node_width(n_leaves);
let left = (0..n_nodes).map(|x| x.left()).collect::<Vec<_>>();
let right = (0..n_nodes).map(|x| x.right()).collect::<Vec<_>>();
// The inner `unzip` splits Option<(parent, sibling)> into a pair of Options;
// the outer `unzip` collects those pairs into two vectors.
let (parent, sibling) = (0..n_nodes)
.map(|x| {
x.parent_sibling(&n_leaves)
.map(|ps| (ps.parent, ps.sibling))
.unzip()
})
.unzip();
test_cases.push(TestCase {
n_leaves,
n_nodes,
root: n_leaves.root(),
left,
right,
parent,
sibling,
})
}
test_cases
}
// NOTE(review): `load_test_case_json!` is defined outside this file. Presumably
// it loads the `tree_math` vectors from a JSON file and uses the generator
// expression to create them when needed — confirm against the macro.
fn load_test_cases() -> Vec<TestCase> {
load_test_case_json!(tree_math, generate_tree_math_test_cases())
}
// Checks node count, root, children, parent and sibling of every node against
// the loaded test vectors.
#[test]
fn test_tree_math() {
let test_cases = load_test_cases();
for case in test_cases {
assert_eq!(node_width(case.n_leaves), case.n_nodes);
assert_eq!(case.n_leaves.root(), case.root);
for x in 0..case.n_nodes {
assert_eq!(x.left(), case.left[x as usize]);
assert_eq!(x.right(), case.right[x as usize]);
let (p, s) = x
.parent_sibling(&case.n_leaves)
.map(|ps| (ps.parent, ps.sibling))
.unzip();
assert_eq!(p, case.parent[x as usize]);
assert_eq!(s, case.sibling[x as usize]);
}
}
}
// `expected[i]` lists the `path` components that `direct_copath` returns for
// node `i` (0..=30) in a tree with 16 leaves. Node 15 is the root, so its
// entry is empty.
#[test]
fn test_direct_path() {
let expected: Vec<Vec<u32>> = [
[0x01, 0x03, 0x07, 0x0f].to_vec(),
[0x03, 0x07, 0x0f].to_vec(),
[0x01, 0x03, 0x07, 0x0f].to_vec(),
[0x07, 0x0f].to_vec(),
[0x05, 0x03, 0x07, 0x0f].to_vec(),
[0x03, 0x07, 0x0f].to_vec(),
[0x05, 0x03, 0x07, 0x0f].to_vec(),
[0x0f].to_vec(),
[0x09, 0x0b, 0x07, 0x0f].to_vec(),
[0x0b, 0x07, 0x0f].to_vec(),
[0x09, 0x0b, 0x07, 0x0f].to_vec(),
[0x07, 0x0f].to_vec(),
[0x0d, 0x0b, 0x07, 0x0f].to_vec(),
[0x0b, 0x07, 0x0f].to_vec(),
[0x0d, 0x0b, 0x07, 0x0f].to_vec(),
[].to_vec(),
[0x11, 0x13, 0x17, 0x0f].to_vec(),
[0x13, 0x17, 0x0f].to_vec(),
[0x11, 0x13, 0x17, 0x0f].to_vec(),
[0x17, 0x0f].to_vec(),
[0x15, 0x13, 0x17, 0x0f].to_vec(),
[0x13, 0x17, 0x0f].to_vec(),
[0x15, 0x13, 0x17, 0x0f].to_vec(),
[0x0f].to_vec(),
[0x19, 0x1b, 0x17, 0x0f].to_vec(),
[0x1b, 0x17, 0x0f].to_vec(),
[0x19, 0x1b, 0x17, 0x0f].to_vec(),
[0x17, 0x0f].to_vec(),
[0x1d, 0x1b, 0x17, 0x0f].to_vec(),
[0x1b, 0x17, 0x0f].to_vec(),
[0x1d, 0x1b, 0x17, 0x0f].to_vec(),
]
.to_vec();
for (i, item) in expected.iter().enumerate() {
let path = (i as u32)
.direct_copath(&16)
.into_iter()
.map(|cp| cp.path)
.collect_vec();
assert_eq!(item, &path)
}
}
// `expected[i]` lists the `copath` components that `direct_copath` returns for
// node `i` (0..=30) in a tree with 16 leaves. Node 15 is the root, so its
// entry is empty.
#[test]
fn test_copath_path() {
let expected: Vec<Vec<u32>> = [
[0x02, 0x05, 0x0b, 0x17].to_vec(),
[0x05, 0x0b, 0x17].to_vec(),
[0x00, 0x05, 0x0b, 0x17].to_vec(),
[0x0b, 0x17].to_vec(),
[0x06, 0x01, 0x0b, 0x17].to_vec(),
[0x01, 0x0b, 0x17].to_vec(),
[0x04, 0x01, 0x0b, 0x17].to_vec(),
[0x17].to_vec(),
[0x0a, 0x0d, 0x03, 0x17].to_vec(),
[0x0d, 0x03, 0x17].to_vec(),
[0x08, 0x0d, 0x03, 0x17].to_vec(),
[0x03, 0x17].to_vec(),
[0x0e, 0x09, 0x03, 0x17].to_vec(),
[0x09, 0x03, 0x17].to_vec(),
[0x0c, 0x09, 0x03, 0x17].to_vec(),
[].to_vec(),
[0x12, 0x15, 0x1b, 0x07].to_vec(),
[0x15, 0x1b, 0x07].to_vec(),
[0x10, 0x15, 0x1b, 0x07].to_vec(),
[0x1b, 0x07].to_vec(),
[0x16, 0x11, 0x1b, 0x07].to_vec(),
[0x11, 0x1b, 0x07].to_vec(),
[0x14, 0x11, 0x1b, 0x07].to_vec(),
[0x07].to_vec(),
[0x1a, 0x1d, 0x13, 0x07].to_vec(),
[0x1d, 0x13, 0x07].to_vec(),
[0x18, 0x1d, 0x13, 0x07].to_vec(),
[0x13, 0x07].to_vec(),
[0x1e, 0x19, 0x13, 0x07].to_vec(),
[0x19, 0x13, 0x07].to_vec(),
[0x1c, 0x19, 0x13, 0x07].to_vec(),
]
.to_vec();
for (i, item) in expected.iter().enumerate() {
let copath = (i as u32)
.direct_copath(&16)
.into_iter()
.map(|cp| cp.copath)
.collect_vec();
assert_eq!(item, &copath)
}
}
}