openmls/treesync/
mod.rs

1//! This module implements the ratchet tree component of MLS.
2//!
3//! It exposes the [`Node`] enum that can contain either a [`LeafNode`] or a [`ParentNode`].
4
5// # Internal documentation
6//
7// This module provides the [`TreeSync`] struct, which contains the state
8// shared between a group of MLS clients in the shape of a tree, where each
9// non-blank leaf corresponds to one group member. The functions provided by
10// its implementation allow the creation of a [`TreeSyncDiff`] instance, which
11// in turn can be mutably operated on and merged back into the original
12// [`TreeSync`] instance.
13//
// The submodules of this module define the nodes of the tree (`node`),
// helper functions and structs for the algorithms used to sync the tree across
// the group (`hashes`) and the diff functionality (`diff`).
17//
18// Finally, this module contains the [`treekem`] module, which allows the
19// encryption and decryption of updates to the tree.
20
21#[cfg(any(feature = "test-utils", test))]
22use std::fmt;
23
24use openmls_traits::{
25    crypto::OpenMlsCrypto,
26    signatures::Signer,
27    types::{Ciphersuite, CryptoError},
28};
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
32
33use self::{
34    diff::{StagedTreeSyncDiff, TreeSyncDiff},
35    node::{
36        leaf_node::{
37            Capabilities, NewLeafNodeParams, TreeInfoTbs, TreePosition, VerifiableLeafNode,
38        },
39        NodeIn,
40    },
41    treesync_node::{TreeSyncLeafNode, TreeSyncNode, TreeSyncParentNode},
42};
43use crate::binary_tree::array_representation::ParentNodeIndex;
44#[cfg(any(feature = "test-utils", test))]
45use crate::{binary_tree::array_representation::level, test_utils::bytes_to_hex};
46use crate::{
47    binary_tree::{
48        array_representation::{is_node_in_tree, LeafNodeIndex, TreeSize},
49        MlsBinaryTree, MlsBinaryTreeError,
50    },
51    ciphersuite::{signable::Verifiable, Secret},
52    credentials::CredentialWithKey,
53    error::LibraryError,
54    extensions::Extensions,
55    group::{GroupId, Member},
56    key_packages::Lifetime,
57    messages::{PathSecret, PathSecretError},
58    schedule::CommitSecret,
59    storage::OpenMlsProvider,
60};
61
62// Private
63mod hashes;
64use errors::*;
65
66// Crate
67pub(crate) mod diff;
68pub(crate) mod node;
69pub(crate) mod treekem;
70pub(crate) mod treesync_node;
71
72use node::encryption_keys::EncryptionKeyPair;
73
74// Public
75pub mod errors;
76#[cfg(feature = "test-utils")]
77pub use node::encryption_keys::test_utils;
78pub use node::encryption_keys::EncryptionKey;
79
80// Public re-exports
81pub use node::{
82    leaf_node::{
83        LeafNode, LeafNodeParameters, LeafNodeParametersBuilder, LeafNodeSource,
84        LeafNodeUpdateError,
85    },
86    parent_node::ParentNode,
87    Node,
88};
89
90// Tests
91#[cfg(any(feature = "test-utils", test))]
92pub mod tests_and_kats;
93
/// An exported ratchet tree as used in, e.g., [`GroupInfo`](crate::messages::group_info::GroupInfo).
///
/// The tree is stored as a flat vector of optional [`Node`]s, where `None`
/// represents a blank node.
// Within this module, even positions of the vector are treated as leaf nodes
// and odd positions as parent nodes.
#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct RatchetTree(Vec<Option<Node>>);
97
/// An error during processing of an incoming ratchet tree.
#[derive(Error, Debug, PartialEq, Clone)]
pub enum RatchetTreeError {
    /// The ratchet tree is empty, i.e., it contains no nodes at all.
    #[error("The ratchet tree has no nodes.")]
    MissingNodes,
    /// The last node of the ratchet tree is blank.
    #[error("The ratchet tree has trailing blank nodes.")]
    TrailingBlankNodes,
    /// The signature of a leaf node could not be verified.
    #[error("Invalid node signature.")]
    InvalidNodeSignature,
    /// A leaf node was found at an odd position or a parent node at an even
    /// position.
    #[error("Wrong node type.")]
    WrongNodeType,
}
114
115impl RatchetTree {
116    /// Create a [`RatchetTree`] from a vector of nodes stripping all trailing blank nodes.
117    ///
118    /// Note: The caller must ensure to call this with a vector that is *not* empty after removing all trailing blank nodes.
119    fn trimmed(mut nodes: Vec<Option<Node>>) -> Self {
120        // Remove all trailing blank nodes.
121        match nodes.iter().enumerate().rfind(|(_, node)| node.is_some()) {
122            Some((rightmost_nonempty_position, _)) => {
123                // We need to add 1 to `rightmost_nonempty_position` to keep the rightmost node.
124                nodes.resize(rightmost_nonempty_position + 1, None);
125            }
126            None => {
127                // If there is no rightmost non-blank node, the vector consist of blank nodes only.
128                nodes.clear();
129            }
130        }
131
132        debug_assert!(!nodes.is_empty(), "Caller should have ensured that `RatchetTree::trimmed` is not called with a vector that is empty after removing all trailing blank nodes.");
133        Self(nodes)
134    }
135
136    /// Create a new [`RatchetTree`] from a vector of nodes.
137    pub(crate) fn try_from_nodes(
138        ciphersuite: Ciphersuite,
139        crypto: &impl OpenMlsCrypto,
140        nodes: Vec<Option<NodeIn>>,
141        group_id: &GroupId,
142    ) -> Result<Self, RatchetTreeError> {
143        // ValSem300: "Exported ratchet trees must not have trailing blank nodes."
144        //
145        // We can check this by only looking at the last node (if any).
146        match nodes.last() {
147            Some(None) => {
148                // The ratchet tree is not empty, i.e., has a last node, *but* the last node *is* blank.
149                Err(RatchetTreeError::TrailingBlankNodes)
150            }
151            None => {
152                // The ratchet tree is empty.
153                Err(RatchetTreeError::MissingNodes)
154            }
155            Some(Some(_)) => {
156                // The ratchet tree is not empty, i.e., has a last node, and the last node is not blank.
157
158                // Verify the nodes.
159                // https://validation.openmls.tech/#valn1407
160                let mut verified_nodes = Vec::new();
161                for (index, node) in nodes.into_iter().enumerate() {
162                    let verified_node = match (index % 2, node) {
163                        // Even indices must be leaf nodes.
164                        (0, Some(NodeIn::LeafNode(leaf_node))) => {
165                            let tree_position = TreePosition::new(
166                                group_id.clone(),
167                                LeafNodeIndex::new((index / 2) as u32),
168                            );
169                            let verifiable_leaf_node = leaf_node.into_verifiable_leaf_node();
170                            let signature_key = verifiable_leaf_node
171                                .signature_key()
172                                .clone()
173                                .into_signature_public_key_enriched(
174                                    ciphersuite.signature_algorithm(),
175                                );
176                            Some(Node::leaf_node(match verifiable_leaf_node {
177                                VerifiableLeafNode::KeyPackage(leaf_node) => leaf_node
178                                    .verify(crypto, &signature_key)
179                                    .map_err(|_| RatchetTreeError::InvalidNodeSignature)?,
180                                VerifiableLeafNode::Update(mut leaf_node) => {
181                                    leaf_node.add_tree_position(tree_position);
182                                    leaf_node
183                                        .verify(crypto, &signature_key)
184                                        .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
185                                }
186                                VerifiableLeafNode::Commit(mut leaf_node) => {
187                                    leaf_node.add_tree_position(tree_position);
188                                    leaf_node
189                                        .verify(crypto, &signature_key)
190                                        .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
191                                }
192                            }))
193                        }
194                        // Odd indices must be parent nodes.
195                        (1, Some(NodeIn::ParentNode(parent_node))) => {
196                            Some(Node::ParentNode(parent_node))
197                        }
198                        // Blank nodes.
199                        (_, None) => None,
200                        // All other cases are invalid.
201                        _ => {
202                            return Err(RatchetTreeError::WrongNodeType);
203                        }
204                    };
205                    verified_nodes.push(verified_node);
206                }
207                Ok(Self::trimmed(verified_nodes))
208            }
209        }
210    }
211}
212
/// A ratchet tree made of unverified nodes. This is used for deserialization
/// and verification.
///
/// Use [`RatchetTreeIn::into_verified`] to turn it into a [`RatchetTree`].
#[derive(
    PartialEq,
    Eq,
    Clone,
    Debug,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct RatchetTreeIn(Vec<Option<NodeIn>>);
228
229impl RatchetTreeIn {
230    /// Create a new [`RatchetTreeIn`] from a vector of nodes after verifying
231    /// the nodes.
232    pub fn into_verified(
233        self,
234        ciphersuite: Ciphersuite,
235        crypto: &impl OpenMlsCrypto,
236        group_id: &GroupId,
237    ) -> Result<RatchetTree, RatchetTreeError> {
238        RatchetTree::try_from_nodes(ciphersuite, crypto, self.0, group_id)
239    }
240
241    fn from_ratchet_tree(ratchet_tree: RatchetTree) -> Self {
242        let nodes = ratchet_tree
243            .0
244            .into_iter()
245            .map(|node| node.map(NodeIn::from))
246            .collect();
247        Self(nodes)
248    }
249
    /// Wrap the given nodes in a [`RatchetTreeIn`] without performing any
    /// checks. Only compiled for tests.
    #[cfg(test)]
    pub(crate) fn from_nodes(nodes: Vec<Option<NodeIn>>) -> Self {
        Self(nodes)
    }
254}
255
256impl From<RatchetTree> for RatchetTreeIn {
257    fn from(ratchet_tree: RatchetTree) -> Self {
258        RatchetTreeIn::from_ratchet_tree(ratchet_tree)
259    }
260}
261
// This conversion skips verification of the nodes. It breaks abstraction
// layers and MUST NOT be made available outside of tests or "test-utils".
#[cfg(any(feature = "test-utils", test))]
impl From<RatchetTreeIn> for RatchetTree {
    fn from(ratchet_tree_in: RatchetTreeIn) -> Self {
        let RatchetTreeIn(unverified_nodes) = ratchet_tree_in;
        let nodes = unverified_nodes
            .into_iter()
            .map(|node| node.map(Node::from))
            .collect();
        Self(nodes)
    }
}
276
/// Integer base-2 logarithm of `x`, rounded down. Returns `0` for `x == 0`.
#[cfg(any(feature = "test-utils", test))]
fn log2(x: u32) -> usize {
    // `checked_ilog2` is `None` exactly for zero, which maps to 0 here.
    x.checked_ilog2().map_or(0, |exponent| exponent as usize)
}
284
/// Index of the root node for a tree with `size` nodes in the array
/// representation: `2^log2(size) - 1`.
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn root(size: u32) -> u32 {
    let exponent = log2(size);
    (1u32 << exponent) - 1
}
289
/// Renders the tree one node per line: the node index, a tag (`L` leaf, `P`
/// parent, `_` blank, with `(*)` marking the root position), the hex encoded
/// public key and parent hash, and a marker indented by the node's level.
#[cfg(any(feature = "test-utils", test))]
impl fmt::Display for RatchetTree {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Number of spaces of indentation per tree level.
        const INDENT_PER_LEVEL: usize = 3;

        let Self(nodes) = self;
        // The root position only depends on the number of nodes.
        let root_index = root(nodes.len() as u32);

        for (position, entry) in nodes.iter().enumerate() {
            let is_root = root_index == position as u32;
            let indentation = " ".repeat(level(position as u32) * INDENT_PER_LEVEL);

            write!(f, "{position:04}")?;
            match entry {
                Some(node) => {
                    let (tag, key_bytes, parent_hash_hex) = match node {
                        Node::LeafNode(leaf_node) => (
                            "\tL      ",
                            leaf_node.encryption_key().as_slice(),
                            leaf_node
                                .parent_hash()
                                .map(bytes_to_hex)
                                .unwrap_or_default(),
                        ),
                        Node::ParentNode(parent_node) => (
                            if is_root { "\tP (*)  " } else { "\tP      " },
                            parent_node.public_key().as_slice(),
                            bytes_to_hex(parent_node.parent_hash()),
                        ),
                    };
                    // A missing parent hash is rendered as blanks of the same
                    // width as 32 hex encoded bytes.
                    let parent_hash_column = if parent_hash_hex.is_empty() {
                        "  ".repeat(32)
                    } else {
                        parent_hash_hex
                    };

                    f.write_str(tag)?;
                    write!(
                        f,
                        "PK: {}  PH: {} | ",
                        bytes_to_hex(key_bytes),
                        parent_hash_column
                    )?;
                    write!(f, "{}◼︎", indentation)?;
                }
                None => {
                    let placeholder = "__".repeat(32);
                    if is_root {
                        write!(
                            f,
                            "\t_ (*)  PK: {}  PH: {} | ",
                            placeholder, placeholder
                        )?;
                    } else {
                        write!(
                            f,
                            "\t_      PK: {}  PH: {} | ",
                            placeholder, placeholder
                        )?;
                    }
                    write!(f, "{}❑", indentation)?;
                }
            }
            writeln!(f)?;
        }

        Ok(())
    }
}
359
/// The [`TreeSync`] struct holds an `MlsBinaryTree` instance, which contains
/// the state that is synced across the group, as well as the current hash of
/// the tree.
///
/// It follows the same pattern of tree and diff as the underlying
/// `MlsBinaryTree`, where the [`TreeSync`] instance is immutable save for
/// merging a `TreeSyncDiff`, which can be created, staged and merged (see
/// `TreeSyncDiff`).
///
/// [`TreeSync`] instances guarantee a few invariants that are checked upon
/// creating a new instance from an imported set of nodes, as well as when
/// merging a diff.
// NOTE(review): earlier documentation also mentioned a `LeafNodeIndex`
// pointing to the leaf of this group member; the struct has no such field.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq, Clone))]
pub(crate) struct TreeSync {
    // The binary tree holding the leaf and parent nodes.
    tree: MlsBinaryTree<TreeSyncLeafNode, TreeSyncParentNode>,
    // The tree hash; replaced whenever a staged diff is merged.
    tree_hash: Vec<u8>,
}
378
379impl TreeSync {
380    /// Create a new tree with an own leaf for the given credential.
381    ///
382    /// Returns the resulting [`TreeSync`] instance, as well as the
383    /// corresponding [`CommitSecret`].
384    pub(crate) fn new(
385        provider: &impl OpenMlsProvider,
386        signer: &impl Signer,
387        ciphersuite: Ciphersuite,
388        credential_with_key: CredentialWithKey,
389        life_time: Lifetime,
390        capabilities: Capabilities,
391        extensions: Extensions,
392    ) -> Result<(Self, CommitSecret, EncryptionKeyPair), LibraryError> {
393        let new_leaf_node_params = NewLeafNodeParams {
394            ciphersuite,
395            credential_with_key,
396            // Creation of a group is considered to be from a key package.
397            leaf_node_source: LeafNodeSource::KeyPackage(life_time),
398            capabilities,
399            extensions,
400            tree_info_tbs: TreeInfoTbs::KeyPackage,
401        };
402        let (leaf, encryption_key_pair) = LeafNode::new(provider, signer, new_leaf_node_params)?;
403
404        let node = Node::leaf_node(leaf);
405        let path_secret: PathSecret = Secret::random(ciphersuite, provider.rand())
406            .map_err(LibraryError::unexpected_crypto_error)?
407            .into();
408        let commit_secret: CommitSecret = path_secret
409            .derive_path_secret(provider.crypto(), ciphersuite)?
410            .into();
411        let nodes = vec![TreeSyncNode::from(node).into()];
412        let tree = MlsBinaryTree::new(nodes)
413            .map_err(|_| LibraryError::custom("Unexpected error creating the binary tree."))?;
414        let mut tree_sync = Self {
415            tree,
416            tree_hash: vec![],
417        };
418        // Populate tree hash caches.
419        tree_sync.populate_parent_hashes(provider.crypto(), ciphersuite)?;
420
421        Ok((tree_sync, commit_secret, encryption_key_pair))
422    }
423
424    /// Return the tree hash of the root node of the tree.
425    pub(crate) fn tree_hash(&self) -> &[u8] {
426        self.tree_hash.as_slice()
427    }
428
429    /// Merge the given diff into this `TreeSync` instance, refreshing the
430    /// `tree_hash` value in the process.
431    pub(crate) fn merge_diff(&mut self, tree_sync_diff: StagedTreeSyncDiff) {
432        let (diff, new_tree_hash) = tree_sync_diff.into_parts();
433        self.tree_hash = new_tree_hash;
434        self.tree.merge_diff(diff);
435    }
436
437    /// Create an empty diff based on this [`TreeSync`] instance all operations
438    /// are created based on an initial, empty [`TreeSyncDiff`].
439    pub(crate) fn empty_diff(&self) -> TreeSyncDiff<'_> {
440        self.into()
441    }
442
443    /// A helper function that generates a [`TreeSync`] instance from the given
444    /// slice of nodes. It verifies that the provided encryption key is present
445    /// in the tree and that the invariants documented in [`TreeSync`] hold.
446    pub(crate) fn from_ratchet_tree(
447        crypto: &impl OpenMlsCrypto,
448        ciphersuite: Ciphersuite,
449        ratchet_tree: RatchetTree,
450    ) -> Result<Self, TreeSyncFromNodesError> {
451        // TODO #800: Unmerged leaves should be checked
452        let total_nodes = ratchet_tree.0.len();
453        let mut leaf_nodes = Vec::with_capacity(total_nodes.div_ceil(2));
454        let mut parent_nodes = Vec::with_capacity(total_nodes / 2);
455
456        // Set the leaf indices in all the leaves and convert the node types.
457        for (node_index, node_option) in ratchet_tree.0.into_iter().enumerate() {
458            if node_index % 2 == 0 {
459                let leaf = match node_option {
460                    Some(node) => match TreeSyncNode::from(node) {
461                        TreeSyncNode::Leaf(l) => *l,
462                        TreeSyncNode::Parent(_) => {
463                            return Err(TreeSyncFromNodesError::from(
464                                PublicTreeError::MalformedTree,
465                            ))
466                        }
467                    },
468                    None => TreeSyncLeafNode::blank(),
469                };
470                leaf_nodes.push(leaf);
471            } else {
472                let parent = match node_option {
473                    Some(node) => match TreeSyncNode::from(node) {
474                        TreeSyncNode::Parent(p) => *p,
475                        TreeSyncNode::Leaf(_) => {
476                            return Err(TreeSyncFromNodesError::from(
477                                PublicTreeError::MalformedTree,
478                            ))
479                        }
480                    },
481                    None => TreeSyncParentNode::blank(),
482                };
483                parent_nodes.push(parent);
484            }
485        }
486
487        let tree = MlsBinaryTree::from_components(leaf_nodes, parent_nodes)
488            .map_err(|_| PublicTreeError::MalformedTree)?;
489        let mut tree_sync = Self {
490            tree,
491            tree_hash: vec![],
492        };
493
494        // Verify all parent hashes.
495        tree_sync
496            .verify_parent_hashes(crypto, ciphersuite)
497            .map_err(|e| match e {
498                TreeSyncParentHashError::LibraryError(e) => e.into(),
499                TreeSyncParentHashError::InvalidParentHash => {
500                    TreeSyncFromNodesError::from(PublicTreeError::InvalidParentHash)
501                }
502            })?;
503
504        // Populate tree hash caches.
505        tree_sync.populate_parent_hashes(crypto, ciphersuite)?;
506        Ok(tree_sync)
507    }
508
509    /// Find the `LeafNodeIndex` which a new leaf would have if it were added to the
510    /// tree. This is either the left-most blank node or, if there are no blank
511    /// leaves, the leaf count, since adding a member would extend the tree by
512    /// one leaf.
513    pub(crate) fn free_leaf_index(&self) -> LeafNodeIndex {
514        let diff = self.empty_diff();
515        diff.free_leaf_index()
516    }
517
518    /// Populate the parent hash caches of all nodes in the tree.
519    fn populate_parent_hashes(
520        &mut self,
521        crypto: &impl OpenMlsCrypto,
522        ciphersuite: Ciphersuite,
523    ) -> Result<(), LibraryError> {
524        let diff = self.empty_diff();
525        // Make the diff into a staged diff. This implicitly computes the
526        // tree hashes and poulates the tree hash caches.
527        let staged_diff = diff.into_staged_diff(crypto, ciphersuite)?;
528        // Merge the diff.
529        self.merge_diff(staged_diff);
530        Ok(())
531    }
532
533    /// Verify the parent hashes of all parent nodes in the tree.
534    ///
535    /// Returns an error if one of the parent nodes in the tree has an invalid
536    /// parent hash.
537    fn verify_parent_hashes(
538        &self,
539        crypto: &impl OpenMlsCrypto,
540        ciphersuite: Ciphersuite,
541    ) -> Result<(), TreeSyncParentHashError> {
542        // The ability to verify parent hashes is required both for diffs and
543        // treesync instances. We choose the computationally slightly more
544        // expensive solution of implementing parent hash verification for the
545        // diff and creating an empty diff whenever we need to verify parent
546        // hashes for a `TreeSync` instance. At the time of writing, this
547        // happens only upon construction of a `TreeSync` instance from a vector
548        // of nodes. The alternative solution would be to create a `TreeLike`
549        // trait, which allows tree navigation and node access. We could then
550        // implement `TreeLike` for both `TreeSync` and `TreeSyncDiff` and
551        // finally implement parent hash verification for any struct that
552        // implements `TreeLike`. We choose the less complex version for now.
553        // Should this turn out to cause too much computational overhead, we
554        // should reconsider and choose the alternative sketched above
555        let diff = self.empty_diff();
556        // No need to merge the diff, since we didn't actually modify any state.
557        diff.verify_parent_hashes(crypto, ciphersuite)
558    }
559
    /// Returns the tree size as reported by the underlying binary tree.
    pub(crate) fn tree_size(&self) -> TreeSize {
        self.tree.tree_size()
    }
564
565    /// Returns an iterator over the (non-blank) [`LeafNode`]s in the tree.
566    pub(crate) fn full_leaves(&self) -> impl Iterator<Item = &LeafNode> {
567        self.tree
568            .leaves()
569            .filter_map(|(_, tsn)| tsn.node().as_ref())
570    }
571
572    /// Returns an iterator over the (non-blank) [`ParentNode`]s in the tree.
573    pub(crate) fn full_parents(&self) -> impl Iterator<Item = (ParentNodeIndex, &ParentNode)> {
574        self.tree
575            .parents()
576            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|pn| (index, pn)))
577    }
578
579    /// Returns the index of the last full leaf in the tree.
580    fn rightmost_full_leaf(&self) -> LeafNodeIndex {
581        let mut index = LeafNodeIndex::new(0);
582        for (leaf_index, leaf) in self.tree.leaves() {
583            if leaf.node().as_ref().is_some() {
584                index = leaf_index;
585            }
586        }
587        index
588    }
589
590    /// Returns a list of [`Member`]s containing only full nodes.
591    ///
592    /// XXX: For performance reasons we probably want to have this in a borrowing
593    ///      version as well. But it might well go away again.
594    pub(crate) fn full_leave_members(&self) -> impl Iterator<Item = Member> + '_ {
595        self.tree
596            .leaves()
597            // Filter out blank nodes
598            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|node| (index, node)))
599            // Map to `Member`
600            .map(|(index, leaf_node)| {
601                Member::new(
602                    index,
603                    leaf_node.encryption_key().as_slice().to_vec(),
604                    leaf_node.signature_key().as_slice().to_vec(),
605                    leaf_node.credential().clone(),
606                )
607            })
608    }
609
610    /// Returns the nodes in the tree ordered according to the
611    /// array-representation of the underlying binary tree.
612    pub fn export_ratchet_tree(&self) -> RatchetTree {
613        let mut nodes = Vec::new();
614
615        // Determine the index of the rightmost full leaf.
616        let max_length = self.rightmost_full_leaf();
617
618        // We take all the leaves including the rightmost full leaf, blank
619        // leaves beyond that are trimmed.
620        let mut leaves = self
621            .tree
622            .leaves()
623            .map(|(_, leaf)| leaf)
624            .take(max_length.usize() + 1);
625
626        // Get the first leaf.
627        if let Some(leaf) = leaves.next() {
628            nodes.push(leaf.node().clone().map(Node::leaf_node));
629        } else {
630            // The tree was empty.
631            return RatchetTree::trimmed(vec![]);
632        }
633
634        // Blank parent node used for padding
635        let default_parent = TreeSyncParentNode::default();
636
637        // Get the parents.
638        let parents = self
639            .tree
640            .parents()
641            // Drop the index
642            .map(|(_, parent)| parent)
643            // Take the parents up to the max length
644            .take(max_length.usize())
645            // Pad the parents with blank nodes if needed
646            .chain(
647                (self.tree.parents().count()..self.tree.leaves().count() - 1)
648                    .map(|_| &default_parent),
649            );
650
651        // Interleave the leaves and parents.
652        for (leaf, parent) in leaves.zip(parents) {
653            nodes.push(parent.node().clone().map(Node::parent_node));
654            nodes.push(leaf.node().clone().map(Node::leaf_node));
655        }
656
657        RatchetTree::trimmed(nodes)
658    }
659
660    /// Return a reference to the leaf at the given `LeafNodeIndex` or `None` if the
661    /// leaf is blank.
662    pub(crate) fn leaf(&self, leaf_index: LeafNodeIndex) -> Option<&LeafNode> {
663        let tsn = self.tree.leaf(leaf_index);
664        tsn.node().as_ref()
665    }
666
667    /// Returns a [`TreeSyncError`] if the `leaf_index` is not a leaf in this
668    /// tree or empty.
669    pub(crate) fn is_leaf_in_tree(&self, leaf_index: LeafNodeIndex) -> bool {
670        is_node_in_tree(leaf_index.into(), self.tree.tree_size())
671    }
672
673    /// Return a vector containing all [`EncryptionKey`]s for which the owner of
674    /// the given `leaf_index` should have private key material.
675    pub(crate) fn owned_encryption_keys(&self, leaf_index: LeafNodeIndex) -> Vec<EncryptionKey> {
676        self.empty_diff()
677            .encryption_keys(leaf_index)
678            .cloned()
679            .collect::<Vec<EncryptionKey>>()
680    }
681
682    /// Derives [`EncryptionKeyPair`]s for the nodes in the shared direct path
683    /// of the leaves with index `leaf_index` and `sender_index`.  This function
684    /// also checks that the derived public keys match the existing public keys.
685    ///
686    /// Returns the `CommitSecret` derived from the path secret of the root
687    /// node, as well as the derived [`EncryptionKeyPair`]s. Returns an error if
688    /// the target leaf is outside of the tree.
689    ///
690    /// Returns TreeSyncSetPathError::PublicKeyMismatch if the derived keys don't
691    /// match with the existing ones.
692    ///
693    /// Returns TreeSyncSetPathError::LibraryError if the sender_index is not
694    /// in the tree.
695    pub(crate) fn derive_path_secrets(
696        &self,
697        crypto: &impl OpenMlsCrypto,
698        ciphersuite: Ciphersuite,
699        mut path_secret: PathSecret,
700        sender_index: LeafNodeIndex,
701        leaf_index: LeafNodeIndex,
702    ) -> Result<(Vec<EncryptionKeyPair>, CommitSecret), DerivePathError> {
703        // We assume both nodes are in the tree, since the sender_index must be in the tree
704        // Skip the nodes in the subtree path for which we are an unmerged leaf.
705        let subtree_path = self.tree.subtree_path(leaf_index, sender_index);
706        let mut keypairs = Vec::new();
707        for parent_index in subtree_path {
708            // We know the node is in the tree, since it is in the subtree path
709            let tsn = self.tree.parent_by_index(parent_index);
710            // We only care about non-blank nodes.
711            if let Some(ref parent_node) = tsn.node() {
712                // If our own leaf index is not in the list of unmerged leaves
713                // then we should have the secret for this node.
714                if !parent_node.unmerged_leaves().contains(&leaf_index) {
715                    let keypair = path_secret.derive_key_pair(crypto, ciphersuite)?;
716                    // The derived public key should match the one in the node.
717                    // If not, the tree is corrupt.
718                    if parent_node.encryption_key() != keypair.public_key() {
719                        return Err(DerivePathError::PublicKeyMismatch);
720                    } else {
721                        // If everything is ok, set the private key and derive
722                        // the next path secret.
723                        keypairs.push(keypair);
724                        path_secret = path_secret.derive_path_secret(crypto, ciphersuite)?;
725                    }
726                };
727                // If the leaf is blank or our index is in the list of unmerged
728                // leaves, go to the next node.
729            }
730        }
731        Ok((keypairs, path_secret.into()))
732    }
733
734    /// Return a reference to the parent node at the given `ParentNodeIndex` or
735    /// `None` if the node is blank.
736    pub(crate) fn parent(&self, node_index: ParentNodeIndex) -> Option<&ParentNode> {
737        let tsn = self.tree.parent(node_index);
738        tsn.node().as_ref()
739    }
740}
741
// Helpers that are only compiled for tests.
#[cfg(test)]
impl TreeSync {
    /// Returns the leaf count as reported by the underlying binary tree.
    pub(crate) fn leaf_count(&self) -> u32 {
        self.tree.leaf_count()
    }
}
748
#[cfg(test)]
mod test {
    use super::*;

    // `RatchetTree::trimmed` guards against an empty result with a
    // `debug_assert!`. The next four tests pin that behavior: a panic with
    // debug assertions enabled, no panic without them.

    /// With debug assertions, trimming an empty vector panics.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_ratchet_tree_internal_empty() {
        let _ = RatchetTree::trimmed(Vec::new());
    }

    /// Without debug assertions, trimming an empty vector does not panic.
    #[cfg(not(debug_assertions))]
    #[test]
    fn test_ratchet_tree_internal_empty() {
        let _ = RatchetTree::trimmed(Vec::new());
    }

    /// With debug assertions, trimming a vector of blank nodes only panics.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_ratchet_tree_internal_empty_after_trim() {
        let _ = RatchetTree::trimmed(vec![None]);
    }

    /// Without debug assertions, trimming a vector of blank nodes only does
    /// not panic.
    #[cfg(not(debug_assertions))]
    #[test]
    fn test_ratchet_tree_internal_empty_after_trim() {
        let _ = RatchetTree::trimmed(vec![None]);
    }

    /// `RatchetTree::try_from_nodes` must reject empty trees and trees whose
    /// last node is blank, and accept trees ending in a non-blank node.
    #[openmls_test::openmls_test]
    fn test_ratchet_tree_trailing_blank_nodes() {
        let provider = &Provider::default();
        let (key_package, _, _) = crate::key_packages::tests::key_package(ciphersuite, provider);
        let node_in = NodeIn::from(Node::leaf_node(LeafNode::from(key_package)));
        // A fresh copy of the same non-blank leaf for every use.
        let full = || Some(node_in.clone());

        // Pairs of input nodes and whether the conversion is expected to
        // succeed.
        let cases = [
            (vec![], false),
            (vec![None], false),
            (vec![None, None], false),
            (vec![None, None, None], false),
            (vec![full()], true),
            (vec![full(), None], false),
            (vec![full(), None, full()], true),
            (vec![full(), None, full(), None], false),
        ];

        for (nodes, expected) in cases {
            let group_id = GroupId::random(provider.rand());
            let result =
                RatchetTree::try_from_nodes(ciphersuite, provider.crypto(), nodes, &group_id);
            let got = result.is_ok();
            assert_eq!(got, expected);
        }
    }
}