openmls/binary_tree/array_representation/
kat_treemath.rs

1//! # Known Answer Tests for treemath
2//!
//! This test file generates and reads test vectors for tree math.
4//! See <https://github.com/mlswg/mls-implementations/blob/master/test-vectors.md>
5//! for more description on the test vectors.
6//!
7//! ## Parameter:
8//! Number of leaves `n_leaves`.
9//!
10//! ## Format:
11//! ```text
12//! {
//!     "n_leaves": /* uint32 */,
//!     "n_nodes": /* uint32 */,
14//!     "root": /* uint32 */,
15//!     "left": [ /* array of option<uint32> */ ],
16//!     "right": [ /* array of option<uint32> */ ],
17//!     "parent": [ /* array of option<uint32> */ ],
18//!     "sibling": [ /* array of option<uint32> */ ]
19//! }
20//! ```
21//!
22//! Any value that is invalid is represented as `null`.
23//!
24//! ## Verification:
25//! * `n_nodes` is the number of nodes in the tree with `n_leaves` leaves
26//! * `root` is the root node index of the tree
27//! * `left[i]` is the node index of the left child of the node with index `i`
28//!   in a tree with `n_leaves` leaves
29//! * `right[i]` is the node index of the right child of the node with index `i`
30//!   in a tree with `n_leaves` leaves
31//! * `parent[i]` is the node index of the parent of the node with index `i` in
32//!   a tree with `n_leaves` leaves
33//! * `sibling[i]` is the node index of the sibling of the node with index `i`
34//!   in a tree with `n_leaves` leaves
35
36#[cfg(test)]
37use crate::test_utils::*;
38
39use super::treemath::*;
40
41use serde::{self, Deserialize, Serialize};
42use thiserror::Error;
43
/// A single tree math known-answer test vector.
///
/// The struct is (de)serialized with serde; the field names are the keys of
/// the serialized object. Entries that are not valid for a node are `None`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TreeMathTestVector {
    /// Number of leaves in the tree.
    n_leaves: u32,
    /// Number of nodes in the tree with `n_leaves` leaves.
    n_nodes: u32,
    /// Node index of the root of the tree.
    root: u32,
    /// `left[i]` is the node index of the left child of node `i`
    /// (`None` for leaves).
    left: Vec<Option<u32>>,
    /// `right[i]` is the node index of the right child of node `i`
    /// (`None` for leaves).
    right: Vec<Option<u32>>,
    /// `parent[i]` is the node index of the parent of node `i`
    /// (`None` for the root).
    parent: Vec<Option<u32>>,
    /// `sibling[i]` is the node index of the sibling of node `i`
    /// (`None` for the root).
    sibling: Vec<Option<u32>>,
}
54
/// Generates the tree math test vector for a tree with `n_leaves` leaves.
///
/// For every node index of the tree, in ascending order, one entry is
/// appended to each of `left`, `right`, `parent` and `sibling`:
/// * leaves have no children, so their `left`/`right` entries are `None`;
/// * the root has neither parent nor sibling, so those entries are `None`.
#[cfg(any(feature = "test-utils", test))]
pub fn generate_test_vector(n_leaves: u32) -> TreeMathTestVector {
    let tree_size = TreeSize::new(node_width(n_leaves as usize) as u32);
    let node_count = tree_size.u32();
    let root_index = root(tree_size).test_u32();

    let mut left_children = Vec::new();
    let mut right_children = Vec::new();
    let mut parents = Vec::new();
    let mut siblings = Vec::new();

    for node in 0..node_count {
        let index = TreeNodeIndex::test_new(node);

        // Only parent nodes have children.
        let (left_child, right_child) = match index {
            TreeNodeIndex::Leaf(_) => (None, None),
            TreeNodeIndex::Parent(parent_index) => (
                Some(left(parent_index).test_u32()),
                Some(right(parent_index).test_u32()),
            ),
        };
        left_children.push(left_child);
        right_children.push(right_child);

        // Every node except the root has a parent and a sibling.
        let has_parent = node != root_index;
        parents.push(has_parent.then(|| test_parent(index).test_to_tree_index()));
        siblings.push(has_parent.then(|| test_sibling(index).test_u32()));
    }

    TreeMathTestVector {
        n_leaves,
        n_nodes: node_count,
        root: root_index,
        left: left_children,
        right: right_children,
        parent: parents,
        sibling: siblings,
    }
}
115
/// Generates the test vectors for trees with `2^0` up to `2^9` leaves and
/// writes them to `test_vectors/tree-math-new.json`.
#[test]
fn write_test_vectors() {
    let tests: Vec<TreeMathTestVector> = (0..10)
        .map(|exponent| generate_test_vector(1 << exponent))
        .collect();

    write("test_vectors/tree-math-new.json", &tests);
}
127
/// Checks a tree math test vector against the values computed by this crate.
///
/// The checks run in a fixed order and the first failure is returned: the
/// node count, the root index, and then, for every node index `i` in
/// ascending order, `left[i]`, `right[i]`, `parent[i]` and `sibling[i]`.
/// Leaves must not have children. The `parent` and `sibling` entries of the
/// root are not inspected.
///
/// # Errors
///
/// Returns the [`TmTestVectorError`] variant that names the first value that
/// does not match. An array that has no entry for a node that is checked is
/// reported as a mismatch of that array instead of panicking on an
/// out-of-bounds index.
#[cfg(any(feature = "test-utils", test))]
pub fn run_test_vector(test_vector: TreeMathTestVector) -> Result<(), TmTestVectorError> {
    let n_leaves = test_vector.n_leaves as usize;
    let n_nodes = TreeSize::new(node_width(n_leaves) as u32);
    if test_vector.n_nodes != node_width(n_leaves) as u32 {
        return Err(TmTestVectorError::TreeSizeMismatch);
    }
    if test_vector.root != root(n_nodes).test_u32() {
        return Err(TmTestVectorError::RootIndexMismatch);
    }

    // The root does not change while iterating, so compute it only once.
    let root_index = root(n_nodes).test_usize();

    for i in 0..n_nodes.u32() as usize {
        let tree_index = TreeNodeIndex::test_new(i as u32);

        // Leaves don't have children; parent nodes always have both.
        let (expected_left, expected_right) = match tree_index {
            TreeNodeIndex::Leaf(_) => (None, None),
            TreeNodeIndex::Parent(parent_index) => (
                Some(left(parent_index).test_u32()),
                Some(right(parent_index).test_u32()),
            ),
        };
        // `get` yields `None` for a missing entry, which never equals the
        // expected `Some(..)` and is therefore reported as a mismatch.
        if test_vector.left.get(i) != Some(&expected_left) {
            return Err(TmTestVectorError::LeftIndexMismatch);
        }
        if test_vector.right.get(i) != Some(&expected_right) {
            return Err(TmTestVectorError::RightIndexMismatch);
        }

        // The root has neither a parent nor a sibling, so there is nothing to
        // compute for it and its entries are skipped.
        if i != root_index {
            let expected_parent = Some(test_parent(tree_index).test_to_tree_index());
            if test_vector.parent.get(i) != Some(&expected_parent) {
                return Err(TmTestVectorError::ParentIndexMismatch);
            }

            let expected_sibling = Some(test_sibling(tree_index).test_u32());
            if test_vector.sibling.get(i) != Some(&expected_sibling) {
                return Err(TmTestVectorError::SiblingIndexMismatch);
            }
        }
    }
    Ok(())
}
186
/// Reads the tree math test vectors from `test_vectors/tree-math.json` and
/// checks them one by one, panicking on the first vector that fails.
#[test]
fn read_test_vectors_tm() {
    let vectors: Vec<TreeMathTestVector> = read_json!("../../../test_vectors/tree-math.json");

    // `try_for_each` stops at the first vector that fails verification.
    if let Err(e) = vectors.into_iter().try_for_each(run_test_vector) {
        panic!("Error while checking tree math test vector.\n{e:?}");
    }
}
197
#[cfg(any(feature = "test-utils", test))]
/// TreeMath test vector error
///
/// Returned by [`run_test_vector`] to name the first value of a
/// [`TreeMathTestVector`] that does not agree with the value computed by the
/// tree math functions.
#[derive(Error, Debug, PartialEq, Eq, Clone)]
pub enum TmTestVectorError {
    /// The computed tree size doesn't match the one in the test vector.
    #[error("The computed tree size doesn't match the one in the test vector.")]
    TreeSizeMismatch,
    /// The computed root index doesn't match the one in the test vector.
    #[error("The computed root index doesn't match the one in the test vector.")]
    RootIndexMismatch,
    /// A computed left child index doesn't match the one in the test vector.
    /// This includes a leaf for which the test vector lists a left child.
    #[error("A computed left child index doesn't match the one in the test vector.")]
    LeftIndexMismatch,
    /// A computed right child index doesn't match the one in the test vector.
    /// This includes a leaf for which the test vector lists a right child.
    #[error("A computed right child index doesn't match the one in the test vector.")]
    RightIndexMismatch,
    /// A computed parent index doesn't match the one in the test vector.
    #[error("A computed parent index doesn't match the one in the test vector.")]
    ParentIndexMismatch,
    /// A computed sibling index doesn't match the one in the test vector.
    #[error("A computed sibling index doesn't match the one in the test vector.")]
    SiblingIndexMismatch,
}