openmls/binary_tree/array_representation/
treemath.rs1use std::cmp::Ordering;
2
3use serde::{Deserialize, Serialize};
4use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
5
/// Upper bound for tree sizes: 2^30.
///
/// NOTE(review): whether this caps nodes or leaves is decided by callers outside
/// this file — confirm before relying on either reading.
pub(crate) const MAX_TREE_SIZE: u32 = 2u32.pow(30);
/// Lower bound for tree sizes.
pub(crate) const MIN_TREE_SIZE: u32 = 1;
8
9#[derive(
11 Debug,
12 Clone,
13 Copy,
14 PartialEq,
15 Eq,
16 PartialOrd,
17 Ord,
18 Hash,
19 Serialize,
20 Deserialize,
21 TlsDeserialize,
22 TlsDeserializeBytes,
23 TlsSerialize,
24 TlsSize,
25)]
26pub struct LeafNodeIndex(u32);
27
28impl std::fmt::Display for LeafNodeIndex {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.write_fmt(format_args!("{:?}", self.0))
31 }
32}
33
34impl LeafNodeIndex {
35 pub fn new(index: u32) -> Self {
37 LeafNodeIndex(index)
38 }
39
40 pub fn u32(&self) -> u32 {
42 self.0
43 }
44
45 pub fn usize(&self) -> usize {
47 self.u32() as usize
48 }
49
50 fn to_tree_index(self) -> u32 {
52 self.0 * 2
53 }
54
55 fn from_tree_index(node_index: u32) -> Self {
57 debug_assert!(node_index % 2 == 0);
58 LeafNodeIndex(node_index / 2)
59 }
60}
61
62#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
64pub struct ParentNodeIndex(u32);
65
66impl ParentNodeIndex {
67 pub(crate) fn new(index: u32) -> Self {
69 ParentNodeIndex(index)
70 }
71
72 pub(crate) fn u32(&self) -> u32 {
74 self.0
75 }
76
77 pub(crate) fn usize(&self) -> usize {
78 self.0 as usize
79 }
80
81 fn to_tree_index(self) -> u32 {
83 self.0 * 2 + 1
84 }
85
86 fn from_tree_index(node_index: u32) -> Self {
88 debug_assert!(node_index > 0);
89 debug_assert!(node_index % 2 == 1);
90 ParentNodeIndex((node_index - 1) / 2)
91 }
92}
93
94#[cfg(test)]
95impl ParentNodeIndex {
96 pub(crate) fn test_from_tree_index(node_index: u32) -> Self {
98 Self::from_tree_index(node_index)
99 }
100}
101
102#[cfg(any(feature = "test-utils", test))]
103impl ParentNodeIndex {
104 pub(crate) fn test_to_tree_index(self) -> u32 {
106 self.to_tree_index()
107 }
108}
109
110impl From<LeafNodeIndex> for TreeNodeIndex {
111 fn from(leaf_index: LeafNodeIndex) -> Self {
112 TreeNodeIndex::Leaf(leaf_index)
113 }
114}
115
116impl From<ParentNodeIndex> for TreeNodeIndex {
117 fn from(parent_index: ParentNodeIndex) -> Self {
118 TreeNodeIndex::Parent(parent_index)
119 }
120}
121
122#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
124pub enum TreeNodeIndex {
125 Leaf(LeafNodeIndex),
126 Parent(ParentNodeIndex),
127}
128
129impl TreeNodeIndex {
130 fn new(index: u32) -> Self {
132 if index % 2 == 0 {
133 TreeNodeIndex::Leaf(LeafNodeIndex::from_tree_index(index))
134 } else {
135 TreeNodeIndex::Parent(ParentNodeIndex::from_tree_index(index))
136 }
137 }
138
139 #[cfg(any(feature = "test-utils", test))]
141 pub(crate) fn test_new(index: u32) -> Self {
142 Self::new(index)
143 }
144
145 fn u32(&self) -> u32 {
147 match self {
148 TreeNodeIndex::Leaf(index) => index.to_tree_index(),
149 TreeNodeIndex::Parent(index) => index.to_tree_index(),
150 }
151 }
152
153 #[cfg(any(feature = "test-utils", feature = "crypto-debug", test))]
155 pub(crate) fn test_u32(&self) -> u32 {
156 self.u32()
157 }
158
159 #[cfg(any(feature = "test-utils", test))]
161 fn usize(&self) -> usize {
162 self.u32() as usize
163 }
164
165 #[cfg(any(feature = "test-utils", test))]
167 pub(crate) fn test_usize(&self) -> usize {
168 self.usize()
169 }
170}
171
172impl Ord for TreeNodeIndex {
173 fn cmp(&self, other: &TreeNodeIndex) -> Ordering {
174 self.u32().cmp(&other.u32())
175 }
176}
177
178impl PartialOrd for TreeNodeIndex {
179 fn partial_cmp(&self, other: &TreeNodeIndex) -> Option<Ordering> {
180 Some(self.cmp(other))
181 }
182}
183
184#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
185pub(crate) struct TreeSize(u32);
186
187impl TreeSize {
188 pub(crate) fn new(nodes: u32) -> Self {
192 let k = log2(nodes);
193 TreeSize((1 << (k + 1)) - 1)
194 }
195
196 #[cfg(any(feature = "test-utils", test))]
198 pub(crate) fn from_leaf_count(leaf_count: u32) -> Self {
199 TreeSize::new(leaf_count * 2)
200 }
201
202 pub(crate) fn leaf_count(&self) -> u32 {
204 (self.0 / 2) + 1
205 }
206
207 pub(crate) fn parent_count(&self) -> u32 {
209 self.0 / 2
210 }
211
212 pub(crate) fn u32(&self) -> u32 {
214 self.0
215 }
216
217 pub(crate) fn leaf_is_left(&self, leaf_index: LeafNodeIndex) -> bool {
220 leaf_index.u32() < self.leaf_count() / 2
221 }
222
223 pub(super) fn inc(&mut self) {
225 self.0 = self.0 * 2 + 1;
226 }
227
228 pub(super) fn dec(&mut self) {
230 debug_assert!(self.0 >= 2);
231 if self.0 >= 2 {
232 self.0 = self.0.div_ceil(2) - 1;
233 } else {
234 self.0 = 0;
235 }
236 }
237}
238
#[test]
fn tree_size() {
    // (node count passed to `TreeSize::new`, expected rounded size)
    let cases = [
        (1, 1),
        (3, 3),
        (5, 7),
        (7, 7),
        (9, 15),
        (11, 15),
        (13, 15),
        (15, 15),
        (17, 31),
    ];
    for (nodes, expected) in cases {
        assert_eq!(TreeSize::new(nodes).u32(), expected, "TreeSize::new({nodes})");
    }
}
251
#[test]
fn test_leaf_is_left() {
    // (node count passed to `TreeSize::new`, leaf index, expected result)
    let cases = [
        (1, 0, false),
        (3, 0, true),
        (3, 1, false),
        (5, 0, true),
        (5, 1, true),
        (5, 2, false),
        (5, 3, false),
        (15, 0, true),
        (15, 1, true),
        (15, 2, true),
        (15, 3, true),
        (15, 4, false),
        (15, 5, false),
        (15, 6, false),
        (15, 7, false),
    ];
    for (nodes, leaf, expected) in cases {
        let size = TreeSize::new(nodes);
        assert_eq!(
            size.leaf_is_left(LeafNodeIndex::new(leaf)),
            expected,
            "leaf {leaf} in a tree built from {nodes} nodes"
        );
    }
}
274
/// Floor of the base-2 logarithm of `x`; both 0 and 1 map to 0.
fn log2(x: u32) -> usize {
    x.checked_ilog2().unwrap_or(0) as usize
}
281
/// Level of the node at flat tree-array position `index`: the number of
/// consecutive 1 bits at the low end of `index`. Leaves (even positions) are
/// at level 0; e.g. 1 → 1, 3 → 2, 7 → 3.
pub fn level(index: u32) -> usize {
    // `trailing_ones` is defined for every input, including `u32::MAX` (32).
    // A bit-by-bit `(index >> k) & 1` scan would shift by 32 there, which
    // panics in debug builds and never terminates in release builds.
    index.trailing_ones() as usize
}
293
294pub(crate) fn root(size: TreeSize) -> TreeNodeIndex {
295 let size = size.u32();
296 debug_assert!(size > 0);
297 TreeNodeIndex::new((1 << log2(size)) - 1)
298}
299
300pub(crate) fn left(index: ParentNodeIndex) -> TreeNodeIndex {
301 let x = index.to_tree_index();
302 let k = level(x);
303 debug_assert!(k > 0);
304 let index = x ^ (0x01 << (k - 1));
305 TreeNodeIndex::new(index)
306}
307
308pub(crate) fn right(index: ParentNodeIndex) -> TreeNodeIndex {
309 let x = index.to_tree_index();
310 let k = level(x);
311 debug_assert!(k > 0);
312 let index = x ^ (0x03 << (k - 1));
313 TreeNodeIndex::new(index)
314}
315
316fn parent(x: TreeNodeIndex) -> ParentNodeIndex {
319 let x = x.u32();
320 let k = level(x);
321 let b = (x >> (k + 1)) & 0x01;
322 let index = (x | (1 << k)) ^ (b << (k + 1));
323 ParentNodeIndex::from_tree_index(index)
324}
325
/// Exposes the private `parent` computation to tests and the `test-utils`
/// feature.
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn test_parent(index: TreeNodeIndex) -> ParentNodeIndex {
    self::parent(index)
}
331
332fn sibling(index: TreeNodeIndex) -> TreeNodeIndex {
333 let p = parent(index);
334 match index.u32().cmp(&p.to_tree_index()) {
335 Ordering::Less => right(p),
336 Ordering::Greater => left(p),
337 Ordering::Equal => left(p),
338 }
339}
340
/// Exposes the private `sibling` computation to tests and the `test-utils`
/// feature.
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn test_sibling(index: TreeNodeIndex) -> TreeNodeIndex {
    self::sibling(index)
}
346
347pub(crate) fn direct_path(node_index: LeafNodeIndex, size: TreeSize) -> Vec<ParentNodeIndex> {
350 let r = root(size).u32();
351
352 let mut d = vec![];
353 let mut x = node_index.to_tree_index();
354 while x != r {
355 let parent = parent(TreeNodeIndex::new(x));
356 d.push(parent);
357 x = parent.to_tree_index();
358 }
359 d
360}
361
362pub(crate) fn copath(leaf_index: LeafNodeIndex, size: TreeSize) -> Vec<TreeNodeIndex> {
364 let mut full_path = vec![TreeNodeIndex::Leaf(leaf_index)];
366 let mut direct_path = direct_path(leaf_index, size);
367 if !direct_path.is_empty() {
368 direct_path.pop();
370 }
371 full_path.append(
372 &mut direct_path
373 .iter()
374 .map(|i| TreeNodeIndex::Parent(*i))
375 .collect(),
376 );
377
378 full_path.into_iter().map(sibling).collect()
379}
380
381pub(super) fn lowest_common_ancestor(x: LeafNodeIndex, y: LeafNodeIndex) -> ParentNodeIndex {
384 let x = x.to_tree_index();
385 let y = y.to_tree_index();
386 let (lx, ly) = (level(x) + 1, level(y) + 1);
387 if (lx <= ly) && (x >> ly == y >> ly) {
388 return ParentNodeIndex::from_tree_index(y);
389 } else if (ly <= lx) && (x >> lx == y >> lx) {
390 return ParentNodeIndex::from_tree_index(x);
391 }
392
393 let (mut xn, mut yn) = (x, y);
394 let mut k = 0;
395 while xn != yn {
396 xn >>= 1;
397 yn >>= 1;
398 k += 1;
399 }
400 ParentNodeIndex::from_tree_index((xn << k) + (1 << (k - 1)) - 1)
401}
402
403pub(crate) fn common_direct_path(
406 x: LeafNodeIndex,
407 y: LeafNodeIndex,
408 size: TreeSize,
409) -> Vec<ParentNodeIndex> {
410 let mut x_path = direct_path(x, size);
411 let mut y_path = direct_path(y, size);
412 x_path.reverse();
413 y_path.reverse();
414
415 let mut common_path = vec![];
416
417 for (x, y) in x_path.iter().zip(y_path.iter()) {
418 if x == y {
419 common_path.push(*x);
420 } else {
421 break;
422 }
423 }
424
425 common_path.reverse();
426 common_path
427}
428
/// Number of flat-array positions spanned by `n` leaves, `2(n - 1) + 1`:
/// everything up to and including the last leaf at position `2(n - 1)`.
/// Returns 0 when `n` is 0.
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn node_width(n: usize) -> usize {
    match n {
        0 => 0,
        leaves => 2 * (leaves - 1) + 1,
    }
}
437
438pub(crate) fn is_node_in_tree(node_index: TreeNodeIndex, size: TreeSize) -> bool {
439 node_index.u32() < size.u32()
440}
441
#[test]
fn test_node_in_tree() {
    // (flat node position, node count passed to `TreeSize::new`)
    let cases = [(0u32, 3u32), (1, 3), (2, 5), (5, 7), (2, 11)];
    for (node, nodes) in cases {
        assert!(
            is_node_in_tree(TreeNodeIndex::new(node), TreeSize::new(nodes)),
            "node {node} should be inside a tree built from {nodes} nodes"
        );
    }
}
452
#[test]
fn test_node_not_in_tree() {
    // (flat node position, node count passed to `TreeSize::new`)
    let cases = [(3u32, 1u32), (13, 7)];
    for (node, nodes) in cases {
        assert!(
            !is_node_in_tree(TreeNodeIndex::new(node), TreeSize::new(nodes)),
            "node {node} should be outside a tree built from {nodes} nodes"
        );
    }
}