openmls/binary_tree/array_representation/
treemath.rs

1use std::cmp::Ordering;
2
3use serde::{Deserialize, Serialize};
4use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
5
/// Upper bound for tree sizes: `1 << 30`, i.e. 2^30.
// NOTE(review): whether this bounds the node count or the leaf count is not
// visible in this file — confirm at the use sites.
pub(crate) const MAX_TREE_SIZE: u32 = 1 << 30;
/// Lower bound for tree sizes.
pub(crate) const MIN_TREE_SIZE: u32 = 1;
8
/// LeafNodeIndex references a leaf node in a tree.
///
/// The wrapped `u32` counts leaves only: leaf `n` corresponds to position
/// `2 * n` in the array representation (see `to_tree_index`).
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct LeafNodeIndex(u32);
27
28impl std::fmt::Display for LeafNodeIndex {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.write_fmt(format_args!("{:?}", self.0))
31    }
32}
33
34impl LeafNodeIndex {
35    /// Create a new `LeafNodeIndex` from a `u32`.
36    pub fn new(index: u32) -> Self {
37        LeafNodeIndex(index)
38    }
39
40    /// Return the inner value as `u32`.
41    pub fn u32(&self) -> u32 {
42        self.0
43    }
44
45    /// Return the inner value as `usize`.
46    pub fn usize(&self) -> usize {
47        self.u32() as usize
48    }
49
50    /// Return the index as a TreeNodeIndex value.
51    fn to_tree_index(self) -> u32 {
52        self.0 * 2
53    }
54
55    /// Warning: Only use when the node index represents a leaf node
56    fn from_tree_index(node_index: u32) -> Self {
57        debug_assert!(node_index % 2 == 0);
58        LeafNodeIndex(node_index / 2)
59    }
60}
61
/// ParentNodeIndex references a parent node in a tree.
///
/// The wrapped `u32` counts parents only: parent `n` corresponds to position
/// `2 * n + 1` in the array representation (see `to_tree_index`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ParentNodeIndex(u32);
65
66impl ParentNodeIndex {
67    /// Create a new `ParentNodeIndex` from a `u32`.
68    pub(crate) fn new(index: u32) -> Self {
69        ParentNodeIndex(index)
70    }
71
72    /// Return the inner value as `u32`.
73    pub(crate) fn u32(&self) -> u32 {
74        self.0
75    }
76
77    pub(crate) fn usize(&self) -> usize {
78        self.0 as usize
79    }
80
81    /// Return the index as a TreeNodeIndex value.
82    fn to_tree_index(self) -> u32 {
83        self.0 * 2 + 1
84    }
85
86    /// Warning: Only use when the node index represents a parent node
87    fn from_tree_index(node_index: u32) -> Self {
88        debug_assert!(node_index > 0);
89        debug_assert!(node_index % 2 == 1);
90        ParentNodeIndex((node_index - 1) / 2)
91    }
92}
93
#[cfg(test)]
impl ParentNodeIndex {
    /// Test-only access to the private `from_tree_index` conversion.
    pub(crate) fn test_from_tree_index(node_index: u32) -> Self {
        ParentNodeIndex::from_tree_index(node_index)
    }
}
101
#[cfg(any(feature = "test-utils", test))]
impl ParentNodeIndex {
    /// Test-only access to the private `to_tree_index` conversion.
    pub(crate) fn test_to_tree_index(self) -> u32 {
        ParentNodeIndex::to_tree_index(self)
    }
}
109
110impl From<LeafNodeIndex> for TreeNodeIndex {
111    fn from(leaf_index: LeafNodeIndex) -> Self {
112        TreeNodeIndex::Leaf(leaf_index)
113    }
114}
115
116impl From<ParentNodeIndex> for TreeNodeIndex {
117    fn from(parent_index: ParentNodeIndex) -> Self {
118        TreeNodeIndex::Parent(parent_index)
119    }
120}
121
/// TreeNodeIndex references a node in a tree.
///
/// A node is either a leaf or a parent; each variant carries the index within
/// its own kind, not the position in the array representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TreeNodeIndex {
    /// A leaf node, identified by its leaf index.
    Leaf(LeafNodeIndex),
    /// A parent node, identified by its parent index.
    Parent(ParentNodeIndex),
}
128
129impl TreeNodeIndex {
130    /// Create a new `TreeNodeIndex` from a `u32`.
131    fn new(index: u32) -> Self {
132        if index % 2 == 0 {
133            TreeNodeIndex::Leaf(LeafNodeIndex::from_tree_index(index))
134        } else {
135            TreeNodeIndex::Parent(ParentNodeIndex::from_tree_index(index))
136        }
137    }
138
139    /// Re-exported for testing.
140    #[cfg(any(feature = "test-utils", test))]
141    pub(crate) fn test_new(index: u32) -> Self {
142        Self::new(index)
143    }
144
145    /// Return the inner value as `u32`.
146    fn u32(&self) -> u32 {
147        match self {
148            TreeNodeIndex::Leaf(index) => index.to_tree_index(),
149            TreeNodeIndex::Parent(index) => index.to_tree_index(),
150        }
151    }
152
153    /// Re-exported for testing.
154    #[cfg(any(feature = "test-utils", feature = "crypto-debug", test))]
155    pub(crate) fn test_u32(&self) -> u32 {
156        self.u32()
157    }
158
159    /// Return the inner value as `usize`.
160    #[cfg(any(feature = "test-utils", test))]
161    fn usize(&self) -> usize {
162        self.u32() as usize
163    }
164
165    /// Re-exported for testing.
166    #[cfg(any(feature = "test-utils", test))]
167    pub(crate) fn test_usize(&self) -> usize {
168        self.usize()
169    }
170}
171
172impl Ord for TreeNodeIndex {
173    fn cmp(&self, other: &TreeNodeIndex) -> Ordering {
174        self.u32().cmp(&other.u32())
175    }
176}
177
178impl PartialOrd for TreeNodeIndex {
179    fn partial_cmp(&self, other: &TreeNodeIndex) -> Option<Ordering> {
180        Some(self.cmp(other))
181    }
182}
183
/// Number of nodes in a tree, stored as a `u32`.
///
/// `TreeSize::new` only produces values of the form `2^n - 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub(crate) struct TreeSize(u32);
186
187impl TreeSize {
188    /// Create a new `TreeSize` from `nodes`, which will be rounded up to the
189    /// next power of 2. The tree size then reflects the smallest tree that can
190    /// contain the number of nodes.
191    pub(crate) fn new(nodes: u32) -> Self {
192        let k = log2(nodes);
193        TreeSize((1 << (k + 1)) - 1)
194    }
195
196    /// Creates a new `TreeSize` from a specific leaf count
197    #[cfg(any(feature = "test-utils", test))]
198    pub(crate) fn from_leaf_count(leaf_count: u32) -> Self {
199        TreeSize::new(leaf_count * 2)
200    }
201
202    /// Return the number of leaf nodes in the tree.
203    pub(crate) fn leaf_count(&self) -> u32 {
204        (self.0 / 2) + 1
205    }
206
207    /// Return the number of parent nodes in the tree.
208    pub(crate) fn parent_count(&self) -> u32 {
209        self.0 / 2
210    }
211
212    /// Return the inner value as `u32`.
213    pub(crate) fn u32(&self) -> u32 {
214        self.0
215    }
216
217    /// Returns `true` if the leaf is in the left subtree and `false` otherwise.
218    /// If there is only one leaf in the tree, it returns `false`.
219    pub(crate) fn leaf_is_left(&self, leaf_index: LeafNodeIndex) -> bool {
220        leaf_index.u32() < self.leaf_count() / 2
221    }
222
223    /// Increase the size.
224    pub(super) fn inc(&mut self) {
225        self.0 = self.0 * 2 + 1;
226    }
227
228    /// Decrease the size.
229    pub(super) fn dec(&mut self) {
230        debug_assert!(self.0 >= 2);
231        if self.0 >= 2 {
232            self.0 = self.0.div_ceil(2) - 1;
233        } else {
234            self.0 = 0;
235        }
236    }
237}
238
/// `TreeSize::new` rounds node counts up to the next `2^n - 1`.
#[test]
fn tree_size() {
    // (requested node count, expected tree size)
    let cases = [
        (1u32, 1u32),
        (3, 3),
        (5, 7),
        (7, 7),
        (9, 15),
        (11, 15),
        (13, 15),
        (15, 15),
        (17, 31),
    ];
    for &(nodes, expected) in &cases {
        assert_eq!(TreeSize::new(nodes).u32(), expected);
    }
}
251
/// Test if the leaf is in the left subtree.
#[test]
fn test_leaf_is_left() {
    // (node count passed to `TreeSize::new`, leaf index, expected result)
    let cases = [
        (1u32, 0u32, false),
        (3, 0, true),
        (3, 1, false),
        (5, 0, true),
        (5, 1, true),
        (5, 2, false),
        (5, 3, false),
        (15, 0, true),
        (15, 1, true),
        (15, 2, true),
        (15, 3, true),
        (15, 4, false),
        (15, 5, false),
        (15, 6, false),
        (15, 7, false),
    ];
    for &(nodes, leaf, expected) in &cases {
        let size = TreeSize::new(nodes);
        assert_eq!(size.leaf_is_left(LeafNodeIndex::new(leaf)), expected);
    }
}
274
/// Floor of the base-2 logarithm of `x`, i.e. the position of its highest set
/// bit. An input of `0` maps to `0`.
fn log2(x: u32) -> usize {
    match x {
        0 => 0,
        _ => (u32::BITS - 1 - x.leading_zeros()) as usize,
    }
}
281
/// Level of the node at position `index` in the array representation of a
/// tree.
///
/// Even positions (leaves) are at level `0`. For odd positions the level is
/// the number of consecutive one bits at the low end of `index`, so `1` is at
/// level 1, `3` at level 2, `7` at level 3, and so on.
///
/// Counting the bits with `trailing_ones` also covers `u32::MAX` (level 32),
/// for which a manual shift loop would have to shift by 32 bits.
pub fn level(index: u32) -> usize {
    index.trailing_ones() as usize
}
293
294pub(crate) fn root(size: TreeSize) -> TreeNodeIndex {
295    let size = size.u32();
296    debug_assert!(size > 0);
297    TreeNodeIndex::new((1 << log2(size)) - 1)
298}
299
300pub(crate) fn left(index: ParentNodeIndex) -> TreeNodeIndex {
301    let x = index.to_tree_index();
302    let k = level(x);
303    debug_assert!(k > 0);
304    let index = x ^ (0x01 << (k - 1));
305    TreeNodeIndex::new(index)
306}
307
308pub(crate) fn right(index: ParentNodeIndex) -> TreeNodeIndex {
309    let x = index.to_tree_index();
310    let k = level(x);
311    debug_assert!(k > 0);
312    let index = x ^ (0x03 << (k - 1));
313    TreeNodeIndex::new(index)
314}
315
316/// Warning: There is no check about the tree size and whether the parent is
317/// beyond the root
318fn parent(x: TreeNodeIndex) -> ParentNodeIndex {
319    let x = x.u32();
320    let k = level(x);
321    let b = (x >> (k + 1)) & 0x01;
322    let index = (x | (1 << k)) ^ (b << (k + 1));
323    ParentNodeIndex::from_tree_index(index)
324}
325
/// Re-exported for testing.
///
/// Forwards to the private `parent` function, which performs no check
/// against the tree size.
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn test_parent(index: TreeNodeIndex) -> ParentNodeIndex {
    parent(index)
}
331
332fn sibling(index: TreeNodeIndex) -> TreeNodeIndex {
333    let p = parent(index);
334    match index.u32().cmp(&p.to_tree_index()) {
335        Ordering::Less => right(p),
336        Ordering::Greater => left(p),
337        Ordering::Equal => left(p),
338    }
339}
340
/// Re-exported for testing.
///
/// Forwards to the private `sibling` function.
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn test_sibling(index: TreeNodeIndex) -> TreeNodeIndex {
    sibling(index)
}
346
347/// Direct path from a node to the root.
348/// Does not include the node itself.
349pub(crate) fn direct_path(node_index: LeafNodeIndex, size: TreeSize) -> Vec<ParentNodeIndex> {
350    let r = root(size).u32();
351
352    let mut d = vec![];
353    let mut x = node_index.to_tree_index();
354    while x != r {
355        let parent = parent(TreeNodeIndex::new(x));
356        d.push(parent);
357        x = parent.to_tree_index();
358    }
359    d
360}
361
362/// Copath of a leaf node.
363pub(crate) fn copath(leaf_index: LeafNodeIndex, size: TreeSize) -> Vec<TreeNodeIndex> {
364    // Start with leaf
365    let mut full_path = vec![TreeNodeIndex::Leaf(leaf_index)];
366    let mut direct_path = direct_path(leaf_index, size);
367    if !direct_path.is_empty() {
368        // Remove root
369        direct_path.pop();
370    }
371    full_path.append(
372        &mut direct_path
373            .iter()
374            .map(|i| TreeNodeIndex::Parent(*i))
375            .collect(),
376    );
377
378    full_path.into_iter().map(sibling).collect()
379}
380
/// Common ancestor of two leaf nodes, aka the node where their direct paths
/// intersect.
///
// NOTE(review): the two leaves are presumably expected to be distinct. For
// `x == y` the first early return hands an even (leaf) position to
// `ParentNodeIndex::from_tree_index`, which trips its debug assertions.
pub(super) fn lowest_common_ancestor(x: LeafNodeIndex, y: LeafNodeIndex) -> ParentNodeIndex {
    // Positions in the array representation; both are even because leaf `n`
    // maps to `2 * n`.
    let x = x.to_tree_index();
    let y = y.to_tree_index();
    // Even positions have level 0, so both values are 1 here.
    let (lx, ly) = (level(x) + 1, level(y) + 1);
    // One node lies in the subtree of the other. With two even positions and
    // `lx == ly == 1` this can only match when `x == y`.
    if (lx <= ly) && (x >> ly == y >> ly) {
        return ParentNodeIndex::from_tree_index(y);
    } else if (ly <= lx) && (x >> lx == y >> lx) {
        return ParentNodeIndex::from_tree_index(x);
    }

    // Strip low bits from both positions until the remaining prefixes agree;
    // `k` counts how many bits were stripped.
    let (mut xn, mut yn) = (x, y);
    let mut k = 0;
    while xn != yn {
        xn >>= 1;
        yn >>= 1;
        k += 1;
    }
    // The ancestor's position is the shared prefix, followed by a zero bit
    // and `k - 1` one bits.
    ParentNodeIndex::from_tree_index((xn << k) + (1 << (k - 1)) - 1)
}
402
403/// The common direct path of two leaf nodes, i.e. the path from their common
404/// ancestor to the root.
405pub(crate) fn common_direct_path(
406    x: LeafNodeIndex,
407    y: LeafNodeIndex,
408    size: TreeSize,
409) -> Vec<ParentNodeIndex> {
410    let mut x_path = direct_path(x, size);
411    let mut y_path = direct_path(y, size);
412    x_path.reverse();
413    y_path.reverse();
414
415    let mut common_path = vec![];
416
417    for (x, y) in x_path.iter().zip(y_path.iter()) {
418        if x == y {
419            common_path.push(*x);
420        } else {
421            break;
422        }
423    }
424
425    common_path.reverse();
426    common_path
427}
428
/// Number of nodes in a tree with `n` leaves: `2 * (n - 1) + 1`, or `0` for
/// an empty tree.
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn node_width(n: usize) -> usize {
    match n {
        0 => 0,
        leaves => 2 * (leaves - 1) + 1,
    }
}
437
438pub(crate) fn is_node_in_tree(node_index: TreeNodeIndex, size: TreeSize) -> bool {
439    node_index.u32() < size.u32()
440}
441
/// Positions that lie inside the tree are reported as such.
#[test]
fn test_node_in_tree() {
    // (node position, node count passed to `TreeSize::new`)
    let cases = [(0u32, 3u32), (1, 3), (2, 5), (5, 7), (2, 11)];
    for &(position, nodes) in &cases {
        let node = TreeNodeIndex::new(position);
        let size = TreeSize::new(nodes);
        assert!(is_node_in_tree(node, size));
    }
}
452
/// Positions beyond the tree size are reported as outside the tree.
#[test]
fn test_node_not_in_tree() {
    // (node position, node count passed to `TreeSize::new`)
    let cases = [(3u32, 1u32), (13, 7)];
    for &(position, nodes) in &cases {
        let node = TreeNodeIndex::new(position);
        let size = TreeSize::new(nodes);
        assert!(!is_node_in_tree(node, size));
    }
}
462}