Skip to main content

openmls/treesync/
mod.rs

1//! This module implements the ratchet tree component of MLS.
2//!
3//! It exposes the [`Node`] enum that can contain either a [`LeafNode`] or a [`ParentNode`].
4
5// # Internal documentation
6//
7// This module provides the [`TreeSync`] struct, which contains the state
8// shared between a group of MLS clients in the shape of a tree, where each
9// non-blank leaf corresponds to one group member. The functions provided by
10// its implementation allow the creation of a [`TreeSyncDiff`] instance, which
11// in turn can be mutably operated on and merged back into the original
12// [`TreeSync`] instance.
13//
14// The submodules of this module define the nodes of the tree (`nodes`),
15// helper functions and structs for the algorithms used to sync the tree across
16// the group ([`hashes`]) and the diff functionality ([`diff`]).
17//
18// Finally, this module contains the [`treekem`] module, which allows the
19// encryption and decryption of updates to the tree.
20
21#[cfg(any(feature = "test-utils", test))]
22use std::fmt;
23
24use openmls_traits::{
25    crypto::OpenMlsCrypto,
26    signatures::Signer,
27    types::{Ciphersuite, CryptoError},
28};
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
32
33use self::{
34    diff::{StagedTreeSyncDiff, TreeSyncDiff},
35    node::{
36        leaf_node::{
37            Capabilities, NewLeafNodeParams, TreeInfoTbs, TreePosition, VerifiableLeafNode,
38        },
39        NodeIn,
40    },
41    treesync_node::{TreeSyncLeafNode, TreeSyncNode, TreeSyncParentNode},
42};
43use crate::binary_tree::array_representation::ParentNodeIndex;
44#[cfg(any(feature = "test-utils", test))]
45use crate::{binary_tree::array_representation::level, test_utils::bytes_to_hex};
46use crate::{
47    binary_tree::{
48        array_representation::{is_node_in_tree, LeafNodeIndex, TreeSize},
49        MlsBinaryTree, MlsBinaryTreeError,
50    },
51    ciphersuite::{signable::Verifiable, Secret},
52    credentials::CredentialWithKey,
53    error::LibraryError,
54    extensions::Extensions,
55    group::{GroupId, Member},
56    key_packages::Lifetime,
57    messages::{PathSecret, PathSecretError},
58    schedule::CommitSecret,
59    storage::OpenMlsProvider,
60};
61
62// Private
63mod hashes;
64use errors::*;
65
66// Crate
67pub(crate) mod diff;
68pub(crate) mod node;
69pub(crate) mod treekem;
70pub(crate) mod treesync_node;
71
72use node::encryption_keys::EncryptionKeyPair;
73
74// Public
75pub mod errors;
76#[cfg(feature = "test-utils")]
77pub use node::encryption_keys::test_utils;
78pub use node::encryption_keys::EncryptionKey;
79
80// Public re-exports
81pub use node::{
82    leaf_node::{
83        LeafNode, LeafNodeParameters, LeafNodeParametersBuilder, LeafNodeSource,
84        LeafNodeUpdateError,
85    },
86    parent_node::ParentNode,
87    Node,
88};
89
90// Tests
91#[cfg(any(feature = "test-utils", test))]
92pub mod tests_and_kats;
93
/// An exported ratchet tree as used in, e.g., [`GroupInfo`](crate::messages::group_info::GroupInfo).
///
/// The nodes are kept in array representation: even indices hold leaf nodes,
/// odd indices hold parent nodes, and `None` marks a blank node. Verified
/// trees are produced by `RatchetTreeIn::into_verified`; `TreeSync` can also
/// export one via `TreeSync::export_ratchet_tree`.
#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct RatchetTree(Vec<Option<Node>>);
97
/// An error during processing of an incoming ratchet tree.
///
/// Produced while verifying a received list of nodes (see
/// `RatchetTreeIn::into_verified`).
#[derive(Error, Debug, PartialEq, Clone)]
pub enum RatchetTreeError {
    /// The ratchet tree is empty.
    #[error("The ratchet tree has no nodes.")]
    MissingNodes,
    /// The ratchet tree has a trailing blank node (its last node is `None`).
    #[error("The ratchet tree has trailing blank nodes.")]
    TrailingBlankNodes,
    /// A leaf node's signature did not verify against the leaf's own
    /// signature key.
    #[error("Invalid node signature.")]
    InvalidNodeSignature,
    /// A leaf node was found at an odd index or a parent node at an even
    /// index.
    #[error("Wrong node type.")]
    WrongNodeType,
}
114
115impl RatchetTree {
116    /// Create a [`RatchetTree`] from a vector of nodes stripping all trailing blank nodes.
117    ///
118    /// Note: The caller must ensure to call this with a vector that is *not* empty after removing all trailing blank nodes.
119    fn trimmed(mut nodes: Vec<Option<Node>>) -> Self {
120        // Remove all trailing blank nodes.
121        match nodes.iter().enumerate().rfind(|(_, node)| node.is_some()) {
122            Some((rightmost_nonempty_position, _)) => {
123                // We need to add 1 to `rightmost_nonempty_position` to keep the rightmost node.
124                nodes.resize(rightmost_nonempty_position + 1, None);
125            }
126            None => {
127                // If there is no rightmost non-blank node, the vector consist of blank nodes only.
128                nodes.clear();
129            }
130        }
131
132        debug_assert!(!nodes.is_empty(), "Caller should have ensured that `RatchetTree::trimmed` is not called with a vector that is empty after removing all trailing blank nodes.");
133        Self(nodes)
134    }
135
136    /// Create a new [`RatchetTree`] from a vector of nodes.
137    pub(crate) fn try_from_nodes(
138        ciphersuite: Ciphersuite,
139        crypto: &impl OpenMlsCrypto,
140        nodes: Vec<Option<NodeIn>>,
141        group_id: &GroupId,
142    ) -> Result<Self, RatchetTreeError> {
143        // ValSem300: "Exported ratchet trees must not have trailing blank nodes."
144        //
145        // We can check this by only looking at the last node (if any).
146        match nodes.last() {
147            Some(None) => {
148                // The ratchet tree is not empty, i.e., has a last node, *but* the last node *is* blank.
149                Err(RatchetTreeError::TrailingBlankNodes)
150            }
151            None => {
152                // The ratchet tree is empty.
153                Err(RatchetTreeError::MissingNodes)
154            }
155            Some(Some(_)) => {
156                // The ratchet tree is not empty, i.e., has a last node, and the last node is not blank.
157
158                // Verify the nodes.
159                // https://validation.openmls.tech/#valn1407
160                let mut verified_nodes = Vec::new();
161                for (index, node) in nodes.into_iter().enumerate() {
162                    let verified_node = match (index % 2, node) {
163                        // Even indices must be leaf nodes.
164                        (0, Some(NodeIn::LeafNode(leaf_node))) => {
165                            let tree_position = TreePosition::new(
166                                group_id.clone(),
167                                LeafNodeIndex::new((index / 2) as u32),
168                            );
169                            let verifiable_leaf_node = leaf_node.into_verifiable_leaf_node();
170                            let signature_key = verifiable_leaf_node
171                                .signature_key()
172                                .clone()
173                                .into_signature_public_key_enriched(
174                                    ciphersuite.signature_algorithm(),
175                                );
176                            Some(Node::leaf_node(match verifiable_leaf_node {
177                                VerifiableLeafNode::KeyPackage(leaf_node) => leaf_node
178                                    .verify(crypto, &signature_key)
179                                    .map_err(|_| RatchetTreeError::InvalidNodeSignature)?,
180                                VerifiableLeafNode::Update(mut leaf_node) => {
181                                    leaf_node.add_tree_position(tree_position);
182                                    leaf_node
183                                        .verify(crypto, &signature_key)
184                                        .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
185                                }
186                                VerifiableLeafNode::Commit(mut leaf_node) => {
187                                    leaf_node.add_tree_position(tree_position);
188                                    leaf_node
189                                        .verify(crypto, &signature_key)
190                                        .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
191                                }
192                            }))
193                        }
194                        // Odd indices must be parent nodes.
195                        (1, Some(NodeIn::ParentNode(parent_node))) => {
196                            Some(Node::ParentNode(parent_node))
197                        }
198                        // Blank nodes.
199                        (_, None) => None,
200                        // All other cases are invalid.
201                        _ => {
202                            return Err(RatchetTreeError::WrongNodeType);
203                        }
204                    };
205                    verified_nodes.push(verified_node);
206                }
207                Ok(Self::trimmed(verified_nodes))
208            }
209        }
210    }
211}
212
/// A ratchet tree made of unverified nodes. This is used for deserialization
/// and verification.
///
/// Unlike [`RatchetTree`], this type can be TLS-deserialized. Call
/// `RatchetTreeIn::into_verified` to check the nodes and obtain a
/// [`RatchetTree`].
#[derive(
    PartialEq,
    Eq,
    Clone,
    Debug,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct RatchetTreeIn(Vec<Option<NodeIn>>);
228
229impl RatchetTreeIn {
230    /// Create a new [`RatchetTreeIn`] from a vector of nodes after verifying
231    /// the nodes.
232    pub fn into_verified(
233        self,
234        ciphersuite: Ciphersuite,
235        crypto: &impl OpenMlsCrypto,
236        group_id: &GroupId,
237    ) -> Result<RatchetTree, RatchetTreeError> {
238        RatchetTree::try_from_nodes(ciphersuite, crypto, self.0, group_id)
239    }
240
241    fn from_ratchet_tree(ratchet_tree: RatchetTree) -> Self {
242        let nodes = ratchet_tree
243            .0
244            .into_iter()
245            .map(|node| node.map(NodeIn::from))
246            .collect();
247        Self(nodes)
248    }
249
250    #[cfg(test)]
251    pub(crate) fn from_nodes(nodes: Vec<Option<NodeIn>>) -> Self {
252        Self(nodes)
253    }
254}
255
256impl From<RatchetTree> for RatchetTreeIn {
257    fn from(ratchet_tree: RatchetTree) -> Self {
258        RatchetTreeIn::from_ratchet_tree(ratchet_tree)
259    }
260}
261
// The following `From` implementation breaks abstraction layers and MUST
// NOT be made available outside of tests or "test-utils": it builds a
// `RatchetTree` directly from unverified nodes, bypassing the checks in
// `RatchetTree::try_from_nodes`.
#[cfg(any(feature = "test-utils", test))]
impl From<RatchetTreeIn> for RatchetTree {
    fn from(ratchet_tree_in: RatchetTreeIn) -> Self {
        let RatchetTreeIn(unverified_nodes) = ratchet_tree_in;
        let nodes: Vec<Option<Node>> = unverified_nodes
            .into_iter()
            .map(|slot| slot.map(Node::from))
            .collect();
        Self(nodes)
    }
}
276
/// Floor of the base-2 logarithm of `x`; `log2(0)` is defined as 0.
#[cfg(any(feature = "test-utils", test))]
fn log2(x: u32) -> usize {
    // `checked_ilog2` is `None` only for zero, which maps to 0.
    x.checked_ilog2().map_or(0, |exponent| exponent as usize)
}

/// Array index of the root node for a tree whose array representation has
/// `size` slots: `2^floor(log2(size)) - 1`, and 0 when `size` is 0.
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn root(size: u32) -> u32 {
    let depth = log2(size);
    (1u32 << depth) - 1
}
289
#[cfg(any(feature = "test-utils", test))]
impl fmt::Display for RatchetTree {
    /// Renders one line per node: index, node kind (root parent marked with
    /// `(*)`), hex public key and parent hash, then a marker indented by the
    /// node's level (`◼︎` for full nodes, `❑` for blank ones).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Spaces of indentation per tree level in the right-hand column.
        let indent_per_level = 3;
        let root_index = root(self.0.len() as u32);
        // Placeholder for the key and hash columns of blank nodes.
        let blank_hex = "__".repeat(32);

        for (position, maybe_node) in self.0.iter().enumerate() {
            let is_root = position as u32 == root_index;
            let indent = " ".repeat(level(position as u32) * indent_per_level);
            write!(f, "{position:04}")?;

            match maybe_node {
                Some(node) => {
                    let (tag, key_bytes, parent_hash_hex) = match node {
                        Node::LeafNode(leaf_node) => (
                            "\tL      ",
                            leaf_node.encryption_key().as_slice(),
                            leaf_node
                                .parent_hash()
                                .map(bytes_to_hex)
                                .unwrap_or_default(),
                        ),
                        Node::ParentNode(parent_node) => (
                            if is_root { "\tP (*)  " } else { "\tP      " },
                            parent_node.public_key().as_slice(),
                            bytes_to_hex(parent_node.parent_hash()),
                        ),
                    };
                    // Keep the column aligned when there is no parent hash.
                    let parent_hash_column = if parent_hash_hex.is_empty() {
                        "  ".repeat(32)
                    } else {
                        parent_hash_hex
                    };
                    write!(
                        f,
                        "{tag}PK: {}  PH: {parent_hash_column} | ",
                        bytes_to_hex(key_bytes)
                    )?;
                    write!(f, "{indent}◼︎")?;
                }
                None => {
                    let tag = if is_root { "\t_ (*)  " } else { "\t_      " };
                    write!(f, "{tag}PK: {blank_hex}  PH: {blank_hex} | ")?;
                    write!(f, "{indent}❑")?;
                }
            }
            writeln!(f)?;
        }

        Ok(())
    }
}
359
/// The [`TreeSync`] struct holds an `MlsBinaryTree` instance, which contains
/// the state that is synced across the group, together with the cached tree
/// hash of the tree's root node.
///
/// It follows the same pattern of tree and diff as the underlying
/// `MlsBinaryTree`: the [`TreeSync`] instance is immutable save for merging
/// a `TreeSyncDiff`, which can be created, staged and merged (see
/// `TreeSyncDiff`).
///
/// [`TreeSync`] instances guarantee a few invariants that are checked upon
/// creating a new instance from an imported set of nodes, as well as when
/// merging a diff.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq, Clone))]
pub struct TreeSync {
    // The nodes of the tree, stored as separate leaf and parent nodes.
    tree: MlsBinaryTree<TreeSyncLeafNode, TreeSyncParentNode>,
    // Tree hash of the root node; replaced whenever a staged diff is merged.
    tree_hash: Vec<u8>,
}
378
379impl TreeSync {
380    /// Create a new tree with an own leaf for the given credential.
381    ///
382    /// Returns the resulting [`TreeSync`] instance, as well as the
383    /// corresponding [`CommitSecret`].
384    pub(crate) fn new(
385        provider: &impl OpenMlsProvider,
386        signer: &impl Signer,
387        ciphersuite: Ciphersuite,
388        credential_with_key: CredentialWithKey,
389        life_time: Lifetime,
390        capabilities: Capabilities,
391        extensions: Extensions<LeafNode>,
392    ) -> Result<(Self, CommitSecret, EncryptionKeyPair), LibraryError> {
393        let new_leaf_node_params = NewLeafNodeParams {
394            ciphersuite,
395            credential_with_key,
396            // Creation of a group is considered to be from a key package.
397            leaf_node_source: LeafNodeSource::KeyPackage(life_time),
398            capabilities,
399            extensions,
400            tree_info_tbs: TreeInfoTbs::KeyPackage,
401        };
402        let (leaf, encryption_key_pair) = LeafNode::new(provider, signer, new_leaf_node_params)?;
403
404        let node = Node::leaf_node(leaf);
405        let path_secret: PathSecret = Secret::random(ciphersuite, provider.rand())
406            .map_err(LibraryError::unexpected_crypto_error)?
407            .into();
408        let commit_secret: CommitSecret = path_secret
409            .derive_path_secret(provider.crypto(), ciphersuite)?
410            .into();
411        let nodes = vec![TreeSyncNode::from(node).into()];
412        let tree = MlsBinaryTree::new(nodes)
413            .map_err(|_| LibraryError::custom("Unexpected error creating the binary tree."))?;
414        let mut tree_sync = Self {
415            tree,
416            tree_hash: vec![],
417        };
418        // Populate tree hash caches.
419        tree_sync.populate_parent_hashes(provider.crypto(), ciphersuite)?;
420
421        Ok((tree_sync, commit_secret, encryption_key_pair))
422    }
423
424    /// Return the full tree
425    pub(crate) fn tree(&self) -> &MlsBinaryTree<TreeSyncLeafNode, TreeSyncParentNode> {
426        &self.tree
427    }
428
429    /// Return the tree hash of the root node of the tree.
430    pub(crate) fn tree_hash(&self) -> &[u8] {
431        self.tree_hash.as_slice()
432    }
433
434    /// Merge the given diff into this `TreeSync` instance, refreshing the
435    /// `tree_hash` value in the process.
436    pub(crate) fn merge_diff(&mut self, tree_sync_diff: StagedTreeSyncDiff) {
437        let (diff, new_tree_hash) = tree_sync_diff.into_parts();
438        self.tree_hash = new_tree_hash;
439        self.tree.merge_diff(diff);
440    }
441
442    /// Create an empty diff based on this [`TreeSync`] instance all operations
443    /// are created based on an initial, empty [`TreeSyncDiff`].
444    pub(crate) fn empty_diff(&self) -> TreeSyncDiff<'_> {
445        self.into()
446    }
447
448    /// A helper function that generates a [`TreeSync`] instance from the given
449    /// slice of nodes. It verifies that the provided encryption key is present
450    /// in the tree and that the invariants documented in [`TreeSync`] hold.
451    pub(crate) fn from_ratchet_tree(
452        crypto: &impl OpenMlsCrypto,
453        ciphersuite: Ciphersuite,
454        ratchet_tree: RatchetTree,
455    ) -> Result<Self, TreeSyncFromNodesError> {
456        // TODO #800: Unmerged leaves should be checked
457        let total_nodes = ratchet_tree.0.len();
458        let mut leaf_nodes = Vec::with_capacity(total_nodes.div_ceil(2));
459        let mut parent_nodes = Vec::with_capacity(total_nodes / 2);
460
461        // Set the leaf indices in all the leaves and convert the node types.
462        for (node_index, node_option) in ratchet_tree.0.into_iter().enumerate() {
463            if node_index % 2 == 0 {
464                let leaf = match node_option {
465                    Some(node) => match TreeSyncNode::from(node) {
466                        TreeSyncNode::Leaf(l) => *l,
467                        TreeSyncNode::Parent(_) => {
468                            return Err(TreeSyncFromNodesError::from(
469                                PublicTreeError::MalformedTree,
470                            ))
471                        }
472                    },
473                    None => TreeSyncLeafNode::blank(),
474                };
475                leaf_nodes.push(leaf);
476            } else {
477                let parent = match node_option {
478                    Some(node) => match TreeSyncNode::from(node) {
479                        TreeSyncNode::Parent(p) => *p,
480                        TreeSyncNode::Leaf(_) => {
481                            return Err(TreeSyncFromNodesError::from(
482                                PublicTreeError::MalformedTree,
483                            ))
484                        }
485                    },
486                    None => TreeSyncParentNode::blank(),
487                };
488                parent_nodes.push(parent);
489            }
490        }
491
492        let tree = MlsBinaryTree::from_components(leaf_nodes, parent_nodes)
493            .map_err(|_| PublicTreeError::MalformedTree)?;
494        let mut tree_sync = Self {
495            tree,
496            tree_hash: vec![],
497        };
498
499        // Verify all parent hashes.
500        tree_sync
501            .verify_parent_hashes(crypto, ciphersuite)
502            .map_err(|e| match e {
503                TreeSyncParentHashError::LibraryError(e) => e.into(),
504                TreeSyncParentHashError::InvalidParentHash => {
505                    TreeSyncFromNodesError::from(PublicTreeError::InvalidParentHash)
506                }
507            })?;
508
509        // Populate tree hash caches.
510        tree_sync.populate_parent_hashes(crypto, ciphersuite)?;
511        Ok(tree_sync)
512    }
513
514    /// Find the `LeafNodeIndex` which a new leaf would have if it were added to the
515    /// tree. This is either the left-most blank node or, if there are no blank
516    /// leaves, the leaf count, since adding a member would extend the tree by
517    /// one leaf.
518    pub(crate) fn free_leaf_index(&self) -> LeafNodeIndex {
519        let diff = self.empty_diff();
520        diff.free_leaf_index()
521    }
522
523    /// Populate the parent hash caches of all nodes in the tree.
524    fn populate_parent_hashes(
525        &mut self,
526        crypto: &impl OpenMlsCrypto,
527        ciphersuite: Ciphersuite,
528    ) -> Result<(), LibraryError> {
529        let diff = self.empty_diff();
530        // Make the diff into a staged diff. This implicitly computes the
531        // tree hashes and poulates the tree hash caches.
532        let staged_diff = diff.into_staged_diff(crypto, ciphersuite)?;
533        // Merge the diff.
534        self.merge_diff(staged_diff);
535        Ok(())
536    }
537
538    /// Verify the parent hashes of all parent nodes in the tree.
539    ///
540    /// Returns an error if one of the parent nodes in the tree has an invalid
541    /// parent hash.
542    fn verify_parent_hashes(
543        &self,
544        crypto: &impl OpenMlsCrypto,
545        ciphersuite: Ciphersuite,
546    ) -> Result<(), TreeSyncParentHashError> {
547        // The ability to verify parent hashes is required both for diffs and
548        // treesync instances. We choose the computationally slightly more
549        // expensive solution of implementing parent hash verification for the
550        // diff and creating an empty diff whenever we need to verify parent
551        // hashes for a `TreeSync` instance. At the time of writing, this
552        // happens only upon construction of a `TreeSync` instance from a vector
553        // of nodes. The alternative solution would be to create a `TreeLike`
554        // trait, which allows tree navigation and node access. We could then
555        // implement `TreeLike` for both `TreeSync` and `TreeSyncDiff` and
556        // finally implement parent hash verification for any struct that
557        // implements `TreeLike`. We choose the less complex version for now.
558        // Should this turn out to cause too much computational overhead, we
559        // should reconsider and choose the alternative sketched above
560        let diff = self.empty_diff();
561        // No need to merge the diff, since we didn't actually modify any state.
562        diff.verify_parent_hashes(crypto, ciphersuite)
563    }
564
565    /// Returns the tree size
566    pub(crate) fn tree_size(&self) -> TreeSize {
567        self.tree.tree_size()
568    }
569
570    /// Returns an iterator over the (non-blank) [`LeafNode`]s in the tree.
571    pub fn full_leaves(&self) -> impl Iterator<Item = (LeafNodeIndex, &LeafNode)> {
572        self.tree
573            .leaves()
574            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|ln| (index, ln)))
575    }
576
577    /// Returns an iterator over the (non-blank) [`ParentNode`]s in the tree.
578    pub fn full_parents(&self) -> impl Iterator<Item = (ParentNodeIndex, &ParentNode)> {
579        self.tree
580            .parents()
581            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|pn| (index, pn)))
582    }
583
584    /// Returns an iterator over the [`ParentNodeIndex`]es of blank [`ParentNode`]s in the tree.
585    pub fn blank_parents<'a>(&'a self) -> impl Iterator<Item = ParentNodeIndex> + 'a {
586        self.tree
587            .parents()
588            .filter_map(|(index, tsn)| tsn.node().as_ref().map_or(Some(index), |_| None))
589    }
590
591    /// Returns an iterator over the [`LeafNodeIndex`]es of blank [`LeafNode`]s in the tree.
592    pub fn blank_leaves<'a>(&'a self) -> impl Iterator<Item = LeafNodeIndex> + 'a {
593        self.tree
594            .leaves()
595            .filter_map(|(index, tsn)| tsn.node().as_ref().map_or(Some(index), |_| None))
596    }
597
598    /// Returns the index of the last full leaf in the tree.
599    fn rightmost_full_leaf(&self) -> LeafNodeIndex {
600        let mut index = LeafNodeIndex::new(0);
601        for (leaf_index, leaf) in self.tree.leaves() {
602            if leaf.node().as_ref().is_some() {
603                index = leaf_index;
604            }
605        }
606        index
607    }
608
609    /// Returns a list of [`Member`]s containing only full nodes.
610    ///
611    /// XXX: For performance reasons we probably want to have this in a borrowing
612    ///      version as well. But it might well go away again.
613    pub(crate) fn full_leaf_members(&self) -> impl Iterator<Item = Member> + '_ {
614        self.tree
615            .leaves()
616            // Filter out blank nodes
617            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|node| (index, node)))
618            // Map to `Member`
619            .map(|(index, leaf_node)| {
620                Member::new(
621                    index,
622                    leaf_node.encryption_key().as_slice().to_vec(),
623                    leaf_node.signature_key().as_slice().to_vec(),
624                    leaf_node.credential().clone(),
625                )
626            })
627    }
628
629    /// Returns the nodes in the tree ordered according to the
630    /// array-representation of the underlying binary tree.
631    pub fn export_ratchet_tree(&self) -> RatchetTree {
632        let mut nodes = Vec::new();
633
634        // Determine the index of the rightmost full leaf.
635        let max_length = self.rightmost_full_leaf();
636
637        // We take all the leaves including the rightmost full leaf, blank
638        // leaves beyond that are trimmed.
639        let mut leaves = self
640            .tree
641            .leaves()
642            .map(|(_, leaf)| leaf)
643            .take(max_length.usize() + 1);
644
645        // Get the first leaf.
646        if let Some(leaf) = leaves.next() {
647            nodes.push(leaf.node().clone().map(Node::leaf_node));
648        } else {
649            // The tree was empty.
650            return RatchetTree::trimmed(vec![]);
651        }
652
653        // Blank parent node used for padding
654        let default_parent = TreeSyncParentNode::default();
655
656        // Get the parents.
657        let parents = self
658            .tree
659            .parents()
660            // Drop the index
661            .map(|(_, parent)| parent)
662            // Take the parents up to the max length
663            .take(max_length.usize())
664            // Pad the parents with blank nodes if needed
665            .chain(
666                (self.tree.parents().count()..self.tree.leaves().count() - 1)
667                    .map(|_| &default_parent),
668            );
669
670        // Interleave the leaves and parents.
671        for (leaf, parent) in leaves.zip(parents) {
672            nodes.push(parent.node().clone().map(Node::parent_node));
673            nodes.push(leaf.node().clone().map(Node::leaf_node));
674        }
675
676        RatchetTree::trimmed(nodes)
677    }
678
679    /// Return a reference to the leaf at the given `LeafNodeIndex` or `None` if the
680    /// leaf is blank.
681    pub(crate) fn leaf(&self, leaf_index: LeafNodeIndex) -> Option<&LeafNode> {
682        let tsn = self.tree.leaf(leaf_index);
683        tsn.node().as_ref()
684    }
685
686    /// Returns a [`TreeSyncError`] if the `leaf_index` is not a leaf in this
687    /// tree or empty.
688    pub(crate) fn is_leaf_in_tree(&self, leaf_index: LeafNodeIndex) -> bool {
689        is_node_in_tree(leaf_index.into(), self.tree.tree_size())
690    }
691
692    /// Return a vector containing all [`EncryptionKey`]s for which the owner of
693    /// the given `leaf_index` should have private key material.
694    pub(crate) fn owned_encryption_keys(&self, leaf_index: LeafNodeIndex) -> Vec<EncryptionKey> {
695        self.empty_diff()
696            .encryption_keys(leaf_index)
697            .cloned()
698            .collect::<Vec<EncryptionKey>>()
699    }
700
701    /// Derives [`EncryptionKeyPair`]s for the nodes in the shared direct path
702    /// of the leaves with index `leaf_index` and `sender_index`.  This function
703    /// also checks that the derived public keys match the existing public keys.
704    ///
705    /// Returns the `CommitSecret` derived from the path secret of the root
706    /// node, as well as the derived [`EncryptionKeyPair`]s. Returns an error if
707    /// the target leaf is outside of the tree.
708    ///
709    /// Returns TreeSyncSetPathError::PublicKeyMismatch if the derived keys don't
710    /// match with the existing ones.
711    ///
712    /// Returns TreeSyncSetPathError::LibraryError if the sender_index is not
713    /// in the tree.
714    pub(crate) fn derive_path_secrets(
715        &self,
716        crypto: &impl OpenMlsCrypto,
717        ciphersuite: Ciphersuite,
718        mut path_secret: PathSecret,
719        sender_index: LeafNodeIndex,
720        leaf_index: LeafNodeIndex,
721    ) -> Result<(Vec<EncryptionKeyPair>, CommitSecret), DerivePathError> {
722        // We assume both nodes are in the tree, since the sender_index must be in the tree
723        // Skip the nodes in the subtree path for which we are an unmerged leaf.
724        let subtree_path = self.tree.subtree_path(leaf_index, sender_index);
725        let mut keypairs = Vec::new();
726        for parent_index in subtree_path {
727            // We know the node is in the tree, since it is in the subtree path
728            let tsn = self.tree.parent_by_index(parent_index);
729            // We only care about non-blank nodes.
730            if let Some(ref parent_node) = tsn.node() {
731                // If our own leaf index is not in the list of unmerged leaves
732                // then we should have the secret for this node.
733                if !parent_node.unmerged_leaves().contains(&leaf_index) {
734                    let keypair = path_secret.derive_key_pair(crypto, ciphersuite)?;
735                    // The derived public key should match the one in the node.
736                    // If not, the tree is corrupt.
737                    if parent_node.encryption_key() != keypair.public_key() {
738                        return Err(DerivePathError::PublicKeyMismatch);
739                    } else {
740                        // If everything is ok, set the private key and derive
741                        // the next path secret.
742                        keypairs.push(keypair);
743                        path_secret = path_secret.derive_path_secret(crypto, ciphersuite)?;
744                    }
745                };
746                // If the leaf is blank or our index is in the list of unmerged
747                // leaves, go to the next node.
748            }
749        }
750        Ok((keypairs, path_secret.into()))
751    }
752
753    /// Return a reference to the parent node at the given `ParentNodeIndex` or
754    /// `None` if the node is blank.
755    pub(crate) fn parent(&self, node_index: ParentNodeIndex) -> Option<&ParentNode> {
756        let tsn = self.tree.parent(node_index);
757        tsn.node().as_ref()
758    }
759}
760
#[cfg(test)]
impl TreeSync {
    /// Returns the leaf count reported by the underlying binary tree.
    pub(crate) fn leaf_count(&self) -> u32 {
        MlsBinaryTree::leaf_count(&self.tree)
    }
}
767
#[cfg(test)]
mod test {
    use super::*;

    /// `RatchetTree::trimmed` must not receive a vector that is empty after
    /// trimming. This is enforced only by a `debug_assert!`, so the call
    /// panics in debug builds and is accepted in release builds.
    #[test]
    #[cfg_attr(debug_assertions, should_panic)]
    fn test_ratchet_tree_internal_empty() {
        RatchetTree::trimmed(vec![]);
    }

    /// Like `test_ratchet_tree_internal_empty`, but the vector only becomes
    /// empty once its trailing blank node is removed.
    #[test]
    #[cfg_attr(debug_assertions, should_panic)]
    fn test_ratchet_tree_internal_empty_after_trim() {
        RatchetTree::trimmed(vec![None]);
    }

    // Checks ValSem300 (no trailing blank nodes) and the empty-tree case,
    // asserting the exact error variant returned for each invalid input.
    #[openmls_test::openmls_test]
    fn test_ratchet_tree_trailing_blank_nodes() {
        let provider = &Provider::default();
        let (key_package, _, _) = crate::key_packages::tests::key_package(ciphersuite, provider);
        let node_in = NodeIn::from(Node::leaf_node(LeafNode::from(key_package)));
        let tests = [
            (vec![], Err(RatchetTreeError::MissingNodes)),
            (vec![None], Err(RatchetTreeError::TrailingBlankNodes)),
            (vec![None, None], Err(RatchetTreeError::TrailingBlankNodes)),
            (
                vec![None, None, None],
                Err(RatchetTreeError::TrailingBlankNodes),
            ),
            (vec![Some(node_in.clone())], Ok(())),
            (
                vec![Some(node_in.clone()), None],
                Err(RatchetTreeError::TrailingBlankNodes),
            ),
            (
                vec![Some(node_in.clone()), None, Some(node_in.clone())],
                Ok(()),
            ),
            (
                vec![Some(node_in.clone()), None, Some(node_in), None],
                Err(RatchetTreeError::TrailingBlankNodes),
            ),
        ];

        for (nodes, expected) in tests.into_iter() {
            let got = RatchetTree::try_from_nodes(
                ciphersuite,
                provider.crypto(),
                nodes,
                &GroupId::random(provider.rand()),
            )
            .map(|_| ());
            assert_eq!(got, expected);
        }
    }
}