Skip to main content

openmls/binary_tree/array_representation/
kat_treemath.rs

1//! # Known Answer Tests for treemath
2//!
//! This test file generates and reads test vectors for tree math.
//! See <https://github.com/mlswg/mls-implementations/blob/master/test-vectors.md>
//! for a detailed description of the test vectors.
6//!
7//! ## Parameter:
8//! Number of leaves `n_leaves`.
9//!
10//! ## Format:
11//! ```text
12//! {
//!     "n_leaves": /* uint32 */,
//!     "n_nodes": /* uint32 */,
14//!     "root": /* uint32 */,
15//!     "left": [ /* array of option<uint32> */ ],
16//!     "right": [ /* array of option<uint32> */ ],
17//!     "parent": [ /* array of option<uint32> */ ],
18//!     "sibling": [ /* array of option<uint32> */ ]
19//! }
20//! ```
21//!
22//! Any value that is invalid is represented as `null`.
23//!
24//! ## Verification:
25//! * `n_nodes` is the number of nodes in the tree with `n_leaves` leaves
26//! * `root` is the root node index of the tree
27//! * `left[i]` is the node index of the left child of the node with index `i`
28//!   in a tree with `n_leaves` leaves
29//! * `right[i]` is the node index of the right child of the node with index `i`
30//!   in a tree with `n_leaves` leaves
31//! * `parent[i]` is the node index of the parent of the node with index `i` in
32//!   a tree with `n_leaves` leaves
33//! * `sibling[i]` is the node index of the sibling of the node with index `i`
34//!   in a tree with `n_leaves` leaves
35
36#[cfg(all(test, feature = "generate-kats"))]
37use crate::test_utils::*;
38
39use super::treemath::*;
40
41use serde::{self, Deserialize, Serialize};
42use thiserror::Error;
43
/// One tree-math known-answer test vector: the expected tree-math results for every node
/// of a tree with `n_leaves` leaves.
///
/// Field names double as the JSON keys via the derived `Serialize`/`Deserialize` impls.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TreeMathTestVector {
    /// Number of leaves in the tree.
    n_leaves: u32,
    /// Number of nodes in a tree with `n_leaves` leaves.
    n_nodes: u32,
    /// Node index of the root of the tree.
    root: u32,
    /// `left[i]` is the left child of node `i`, or `None` when node `i` is a leaf.
    left: Vec<Option<u32>>,
    /// `right[i]` is the right child of node `i`, or `None` when node `i` is a leaf.
    right: Vec<Option<u32>>,
    /// `parent[i]` is the parent of node `i`, or `None` for the root.
    parent: Vec<Option<u32>>,
    /// `sibling[i]` is the sibling of node `i`, or `None` for the root.
    sibling: Vec<Option<u32>>,
}
54
/// Generates the tree-math test vector for a tree with `n_leaves` leaves.
///
/// For every node index `i` in `0..n_nodes` the vector records:
/// * the left and right child (`None` for leaves), and
/// * the parent and the sibling (`None` for the root).
#[cfg(any(feature = "test-utils", test))]
pub fn generate_test_vector(n_leaves: u32) -> TreeMathTestVector {
    let n_nodes = TreeSize::new(node_width(n_leaves as usize) as u32);
    // The root does not depend on the node being visited; compute it once.
    let root_index = root(n_nodes).test_u32();
    let node_count = n_nodes.u32() as usize;

    let mut lefts = Vec::with_capacity(node_count);
    let mut rights = Vec::with_capacity(node_count);
    let mut parents = Vec::with_capacity(node_count);
    let mut siblings = Vec::with_capacity(node_count);

    for i in 0..n_nodes.u32() {
        let tree_index = TreeNodeIndex::test_new(i);

        // Only parent nodes have children.
        match tree_index {
            TreeNodeIndex::Leaf(_) => {
                lefts.push(None);
                rights.push(None);
            }
            TreeNodeIndex::Parent(parent_index) => {
                lefts.push(Some(left(parent_index).test_u32()));
                rights.push(Some(right(parent_index).test_u32()));
            }
        }

        // Every node except the root has a parent and a sibling. This applies to leaves
        // and parent nodes alike, so it is handled once outside the match.
        if i == root_index {
            parents.push(None);
            siblings.push(None);
        } else {
            parents.push(Some(test_parent(tree_index).test_to_tree_index()));
            siblings.push(Some(test_sibling(tree_index).test_u32()));
        }
    }

    TreeMathTestVector {
        n_leaves,
        n_nodes: n_nodes.u32(),
        root: root_index,
        left: lefts,
        right: rights,
        parent: parents,
        sibling: siblings,
    }
}
115
/// Regenerates the tree-math test vectors for trees with 1, 2, 4, ..., 512 leaves and
/// writes them to `test_vectors/tree-math-new.json`.
#[cfg(feature = "generate-kats")]
#[test]
fn write_test_vectors() {
    // Leaf counts are the powers of two 2^0 through 2^9.
    let tests: Vec<TreeMathTestVector> = (0..10)
        .map(|exponent| generate_test_vector(1 << exponent))
        .collect();

    write("test_vectors/tree-math-new.json", &tests);
}
128
/// Checks `test_vector` against this crate's tree-math functions.
///
/// # Errors
///
/// Returns the [`TmTestVectorError`] for the first disagreement found, checking in this order:
/// 1. the tree size;
/// 2. the root index;
/// 3. that each per-node array (`left`, `right`, `parent`, `sibling`) has exactly one entry
///    per node;
/// 4. node by node, the left child, right child, parent and sibling.
///
/// Leaves must have no children, and the root must have neither a parent nor a sibling
/// (both encoded as `null`).
#[cfg(any(feature = "test-utils", test))]
pub fn run_test_vector(test_vector: TreeMathTestVector) -> Result<(), TmTestVectorError> {
    let n_leaves = test_vector.n_leaves as usize;
    let expected_n_nodes = node_width(n_leaves) as u32;
    if test_vector.n_nodes != expected_n_nodes {
        return Err(TmTestVectorError::TreeSizeMismatch);
    }
    let n_nodes = TreeSize::new(expected_n_nodes);
    let root_index = root(n_nodes).test_u32();
    if test_vector.root != root_index {
        return Err(TmTestVectorError::RootIndexMismatch);
    }

    // The vector is deserialized from a file, so its arrays may be truncated or padded.
    // Report that as a mismatch instead of panicking on an out-of-bounds index below.
    let node_count = n_nodes.u32() as usize;
    if test_vector.left.len() != node_count {
        return Err(TmTestVectorError::LeftIndexMismatch);
    }
    if test_vector.right.len() != node_count {
        return Err(TmTestVectorError::RightIndexMismatch);
    }
    if test_vector.parent.len() != node_count {
        return Err(TmTestVectorError::ParentIndexMismatch);
    }
    if test_vector.sibling.len() != node_count {
        return Err(TmTestVectorError::SiblingIndexMismatch);
    }

    for i in 0..n_nodes.u32() {
        let tree_index = TreeNodeIndex::test_new(i);
        let node = i as usize;

        // Only parent nodes have children; leaves must encode both as `null`.
        let (expected_left, expected_right) = match tree_index {
            TreeNodeIndex::Leaf(_) => (None, None),
            TreeNodeIndex::Parent(parent_index) => (
                Some(left(parent_index).test_u32()),
                Some(right(parent_index).test_u32()),
            ),
        };

        // Every non-root node has a parent and a sibling; the root must have neither.
        let (expected_parent, expected_sibling) = if i == root_index {
            (None, None)
        } else {
            (
                Some(test_parent(tree_index).test_to_tree_index()),
                Some(test_sibling(tree_index).test_u32()),
            )
        };

        if test_vector.left[node] != expected_left {
            return Err(TmTestVectorError::LeftIndexMismatch);
        }
        if test_vector.right[node] != expected_right {
            return Err(TmTestVectorError::RightIndexMismatch);
        }
        if test_vector.parent[node] != expected_parent {
            return Err(TmTestVectorError::ParentIndexMismatch);
        }
        if test_vector.sibling[node] != expected_sibling {
            return Err(TmTestVectorError::SiblingIndexMismatch);
        }
    }

    Ok(())
}
187
/// Verifies every tree-math test vector loaded via `read_json!` against this
/// implementation, panicking on the first vector that fails.
#[test]
fn read_test_vectors_tm() {
    let tests: Vec<TreeMathTestVector> = read_json!("../../../test_vectors/tree-math.json");
    tests.into_iter().for_each(|test_vector| {
        if let Err(e) = run_test_vector(test_vector) {
            panic!("Error while checking tree math test vector.\n{e:?}");
        }
    });
}
198
#[cfg(any(feature = "test-utils", test))]
/// Error returned by [`run_test_vector`] when a tree-math test vector does not match the
/// values computed locally.
///
/// Each variant names the part of the test vector that disagreed.
#[derive(Error, Debug, PartialEq, Eq, Clone)]
pub enum TmTestVectorError {
    /// The computed tree size (`n_nodes`) doesn't match the one in the test vector.
    #[error("The computed tree size doesn't match the one in the test vector.")]
    TreeSizeMismatch,
    /// The computed root index doesn't match the one in the test vector.
    #[error("The computed root index doesn't match the one in the test vector.")]
    RootIndexMismatch,
    /// A computed left child index doesn't match the one in the test vector.
    #[error("A computed left child index doesn't match the one in the test vector.")]
    LeftIndexMismatch,
    /// A computed right child index doesn't match the one in the test vector.
    #[error("A computed right child index doesn't match the one in the test vector.")]
    RightIndexMismatch,
    /// A computed parent index doesn't match the one in the test vector.
    #[error("A computed parent index doesn't match the one in the test vector.")]
    ParentIndexMismatch,
    /// A computed sibling index doesn't match the one in the test vector.
    #[error("A computed sibling index doesn't match the one in the test vector.")]
    SiblingIndexMismatch,
}