// openmls/treesync/mod.rs

//! This module implements the ratchet tree component of MLS.
//!
//! It exposes the [`Node`] enum that can contain either a [`LeafNode`] or a [`ParentNode`].

// # Internal documentation
//
// This module provides the [`TreeSync`] struct, which contains the state
// shared between a group of MLS clients in the shape of a tree, where each
// non-blank leaf corresponds to one group member. The functions provided by
// its implementation allow the creation of a [`TreeSyncDiff`] instance, which
// in turn can be mutably operated on and merged back into the original
// [`TreeSync`] instance.
//
// The submodules of this module define the nodes of the tree (`node`),
// helper functions and structs for the algorithms used to sync the tree across
// the group ([`hashes`]) and the diff functionality ([`diff`]).
//
// Finally, this module contains the [`treekem`] module, which allows the
// encryption and decryption of updates to the tree.
20
21#[cfg(any(feature = "test-utils", test))]
22use std::fmt;
23
24use openmls_traits::{
25    crypto::OpenMlsCrypto,
26    signatures::Signer,
27    types::{Ciphersuite, CryptoError},
28};
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
32
33use self::{
34    diff::{StagedTreeSyncDiff, TreeSyncDiff},
35    node::{
36        leaf_node::{
37            Capabilities, NewLeafNodeParams, TreeInfoTbs, TreePosition, VerifiableLeafNode,
38        },
39        NodeIn,
40    },
41    treesync_node::{TreeSyncLeafNode, TreeSyncNode, TreeSyncParentNode},
42};
43use crate::binary_tree::array_representation::ParentNodeIndex;
44#[cfg(any(feature = "test-utils", test))]
45use crate::{binary_tree::array_representation::level, test_utils::bytes_to_hex};
46use crate::{
47    binary_tree::{
48        array_representation::{is_node_in_tree, tree::TreeNode, LeafNodeIndex, TreeSize},
49        MlsBinaryTree, MlsBinaryTreeError,
50    },
51    ciphersuite::{signable::Verifiable, Secret},
52    credentials::CredentialWithKey,
53    error::LibraryError,
54    extensions::Extensions,
55    group::{GroupId, Member},
56    key_packages::Lifetime,
57    messages::{PathSecret, PathSecretError},
58    schedule::CommitSecret,
59    storage::OpenMlsProvider,
60};
61
62// Private
63mod hashes;
64use errors::*;
65
66// Crate
67pub(crate) mod diff;
68pub(crate) mod node;
69pub(crate) mod treekem;
70pub(crate) mod treesync_node;
71
72use node::encryption_keys::EncryptionKeyPair;
73
74// Public
75pub mod errors;
76#[cfg(feature = "test-utils")]
77pub use node::encryption_keys::test_utils;
78pub use node::encryption_keys::EncryptionKey;
79
80// Public re-exports
81pub use node::{
82    leaf_node::{
83        LeafNode, LeafNodeParameters, LeafNodeParametersBuilder, LeafNodeSource,
84        LeafNodeUpdateError,
85    },
86    parent_node::ParentNode,
87    Node,
88};
89
90// Tests
91#[cfg(any(feature = "test-utils", test))]
92pub mod tests_and_kats;
93
94/// An exported ratchet tree as used in, e.g., [`GroupInfo`](crate::messages::group_info::GroupInfo).
95#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
96pub struct RatchetTree(Vec<Option<Node>>);
97
98/// An error during processing of an incoming ratchet tree.
99#[derive(Error, Debug, PartialEq, Clone)]
100pub enum RatchetTreeError {
101    /// The ratchet tree is empty.
102    #[error("The ratchet tree has no nodes.")]
103    MissingNodes,
104    /// The ratchet tree has a trailing blank node.
105    #[error("The ratchet tree has trailing blank nodes.")]
106    TrailingBlankNodes,
107    /// Invalid node signature.
108    #[error("Invalid node signature.")]
109    InvalidNodeSignature,
110    /// Wrong node type.
111    #[error("Wrong node type.")]
112    WrongNodeType,
113}
114
115impl RatchetTree {
116    /// Create a [`RatchetTree`] from a vector of nodes stripping all trailing blank nodes.
117    ///
118    /// Note: The caller must ensure to call this with a vector that is *not* empty after removing all trailing blank nodes.
119    fn trimmed(mut nodes: Vec<Option<Node>>) -> Self {
120        // Remove all trailing blank nodes.
121        match nodes.iter().enumerate().rfind(|(_, node)| node.is_some()) {
122            Some((rightmost_nonempty_position, _)) => {
123                // We need to add 1 to `rightmost_nonempty_position` to keep the rightmost node.
124                nodes.resize(rightmost_nonempty_position + 1, None);
125            }
126            None => {
127                // If there is no rightmost non-blank node, the vector consist of blank nodes only.
128                nodes.clear();
129            }
130        }
131
132        debug_assert!(!nodes.is_empty(), "Caller should have ensured that `RatchetTree::trimmed` is not called with a vector that is empty after removing all trailing blank nodes.");
133        Self(nodes)
134    }
135
136    /// Create a new [`RatchetTree`] from a vector of nodes.
137    pub(crate) fn try_from_nodes(
138        ciphersuite: Ciphersuite,
139        crypto: &impl OpenMlsCrypto,
140        nodes: Vec<Option<NodeIn>>,
141        group_id: &GroupId,
142    ) -> Result<Self, RatchetTreeError> {
143        // ValSem300: "Exported ratchet trees must not have trailing blank nodes."
144        //
145        // We can check this by only looking at the last node (if any).
146        match nodes.last() {
147            Some(None) => {
148                // The ratchet tree is not empty, i.e., has a last node, *but* the last node *is* blank.
149                Err(RatchetTreeError::TrailingBlankNodes)
150            }
151            None => {
152                // The ratchet tree is empty.
153                Err(RatchetTreeError::MissingNodes)
154            }
155            Some(Some(_)) => {
156                // The ratchet tree is not empty, i.e., has a last node, and the last node is not blank.
157
158                // Verify the nodes.
159                // https://validation.openmls.tech/#valn1407
160                let mut verified_nodes = Vec::new();
161                for (index, node) in nodes.into_iter().enumerate() {
162                    let verified_node = match (index % 2, node) {
163                        // Even indices must be leaf nodes.
164                        (0, Some(NodeIn::LeafNode(leaf_node))) => {
165                            let tree_position = TreePosition::new(
166                                group_id.clone(),
167                                LeafNodeIndex::new((index / 2) as u32),
168                            );
169                            let verifiable_leaf_node = leaf_node.into_verifiable_leaf_node();
170                            let signature_key = verifiable_leaf_node
171                                .signature_key()
172                                .clone()
173                                .into_signature_public_key_enriched(
174                                    ciphersuite.signature_algorithm(),
175                                );
176                            Some(Node::LeafNode(match verifiable_leaf_node {
177                                VerifiableLeafNode::KeyPackage(leaf_node) => leaf_node
178                                    .verify(crypto, &signature_key)
179                                    .map_err(|_| RatchetTreeError::InvalidNodeSignature)?,
180                                VerifiableLeafNode::Update(mut leaf_node) => {
181                                    leaf_node.add_tree_position(tree_position);
182                                    leaf_node
183                                        .verify(crypto, &signature_key)
184                                        .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
185                                }
186                                VerifiableLeafNode::Commit(mut leaf_node) => {
187                                    leaf_node.add_tree_position(tree_position);
188                                    leaf_node
189                                        .verify(crypto, &signature_key)
190                                        .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
191                                }
192                            }))
193                        }
194                        // Odd indices must be parent nodes.
195                        (1, Some(NodeIn::ParentNode(parent_node))) => {
196                            Some(Node::ParentNode(parent_node))
197                        }
198                        // Blank nodes.
199                        (_, None) => None,
200                        // All other cases are invalid.
201                        _ => {
202                            return Err(RatchetTreeError::WrongNodeType);
203                        }
204                    };
205                    verified_nodes.push(verified_node);
206                }
207                Ok(Self::trimmed(verified_nodes))
208            }
209        }
210    }
211}
212
213/// A ratchet tree made of unverified nodes. This is used for deserialization
214/// and verification.
215#[derive(
216    PartialEq,
217    Eq,
218    Clone,
219    Debug,
220    Serialize,
221    Deserialize,
222    TlsDeserialize,
223    TlsDeserializeBytes,
224    TlsSerialize,
225    TlsSize,
226)]
227pub struct RatchetTreeIn(Vec<Option<NodeIn>>);
228
229impl RatchetTreeIn {
230    /// Create a new [`RatchetTreeIn`] from a vector of nodes after verifying
231    /// the nodes.
232    pub fn into_verified(
233        self,
234        ciphersuite: Ciphersuite,
235        crypto: &impl OpenMlsCrypto,
236        group_id: &GroupId,
237    ) -> Result<RatchetTree, RatchetTreeError> {
238        RatchetTree::try_from_nodes(ciphersuite, crypto, self.0, group_id)
239    }
240
241    fn from_ratchet_tree(ratchet_tree: RatchetTree) -> Self {
242        let nodes = ratchet_tree
243            .0
244            .into_iter()
245            .map(|node| node.map(NodeIn::from))
246            .collect();
247        Self(nodes)
248    }
249
250    #[cfg(test)]
251    pub(crate) fn from_nodes(nodes: Vec<Option<NodeIn>>) -> Self {
252        Self(nodes)
253    }
254}
255
256impl From<RatchetTree> for RatchetTreeIn {
257    fn from(ratchet_tree: RatchetTree) -> Self {
258        RatchetTreeIn::from_ratchet_tree(ratchet_tree)
259    }
260}
261
// The following `From` implementation breaks abstraction layers and MUST
// NOT be made available outside of tests or "test-utils": it turns
// unverified nodes into `Node`s without verifying any signature.
#[cfg(any(feature = "test-utils", test))]
impl From<RatchetTreeIn> for RatchetTree {
    fn from(ratchet_tree_in: RatchetTreeIn) -> Self {
        let RatchetTreeIn(unverified) = ratchet_tree_in;
        let nodes: Vec<Option<Node>> = unverified
            .into_iter()
            .map(|slot| slot.map(Node::from))
            .collect();
        Self(nodes)
    }
}
276
/// Returns `floor(log2(x))`, using the convention `log2(0) == 0`.
#[cfg(any(feature = "test-utils", test))]
fn log2(x: u32) -> usize {
    match x {
        0 => 0,
        // Position of the highest set bit.
        _ => (u32::BITS - 1 - x.leading_zeros()) as usize,
    }
}

/// Computes `2^floor(log2(size)) - 1`.
///
/// When `size` is the number of nodes of an array-represented tree (as in the
/// `Display` impl of `RatchetTree`), this is the index of the root node.
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn root(size: u32) -> u32 {
    let highest_power_of_two: u32 = 1 << log2(size);
    highest_power_of_two - 1
}
289
#[cfg(any(feature = "test-utils", test))]
impl fmt::Display for RatchetTree {
    /// Renders one line per node.
    ///
    /// Each line contains the node index and a type marker: `L` for leaves,
    /// `P` for parents, `_` for blank nodes, with `(*)` flagging the root. It
    /// then shows the hex-encoded public key and parent hash, followed by a
    /// glyph indented according to the node's level.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const INDENT_PER_LEVEL: usize = 3;
        let root_index = root(self.0.len() as u32);

        for (i, slot) in self.0.iter().enumerate() {
            let is_root = root_index == i as u32;
            let indent = " ".repeat(level(i as u32) * INDENT_PER_LEVEL);

            write!(f, "{i:04}")?;

            match slot {
                Some(node) => {
                    let (public_key, parent_hash) = match node {
                        Node::LeafNode(leaf_node) => {
                            write!(f, "\tL      ")?;
                            (
                                leaf_node.encryption_key().as_slice(),
                                leaf_node
                                    .parent_hash()
                                    .map(bytes_to_hex)
                                    .unwrap_or_default(),
                            )
                        }
                        Node::ParentNode(parent_node) => {
                            if is_root {
                                write!(f, "\tP (*)  ")?;
                            } else {
                                write!(f, "\tP      ")?;
                            }
                            (
                                parent_node.public_key().as_slice(),
                                bytes_to_hex(parent_node.parent_hash()),
                            )
                        }
                    };
                    // Pad a missing parent hash so that the columns stay aligned.
                    let shown_parent_hash = if parent_hash.is_empty() {
                        "  ".repeat(32)
                    } else {
                        parent_hash
                    };
                    write!(
                        f,
                        "PK: {}  PH: {} | ",
                        bytes_to_hex(public_key),
                        shown_parent_hash
                    )?;
                    write!(f, "{}◼︎", indent)?;
                }
                None => {
                    let placeholder = "__".repeat(32);
                    if is_root {
                        write!(
                            f,
                            "\t_ (*)  PK: {}  PH: {} | ",
                            placeholder, placeholder
                        )?;
                    } else {
                        write!(
                            f,
                            "\t_      PK: {}  PH: {} | ",
                            placeholder, placeholder
                        )?;
                    }
                    write!(f, "{}❑", indent)?;
                }
            }

            writeln!(f)?;
        }

        Ok(())
    }
}
359
360/// The [`TreeSync`] struct holds an `MlsBinaryTree` instance, which contains
361/// the state that is synced across the group, as well as the [`LeafNodeIndex`]
362/// pointing to the leaf of this group member and the current hash of the tree.
363///
364/// It follows the same pattern of tree and diff as the underlying
365/// `MlsBinaryTree`, where the [`TreeSync`] instance is immutable safe for
366/// merging a `TreeSyncDiff`, which can be created, staged and merged (see
367/// `TreeSyncDiff`).
368///
369/// [`TreeSync`] instance guarantee a few invariants that are checked upon
370/// creating a new instance from an imported set of nodes, as well as when
371/// merging a diff.
372#[derive(Debug, Serialize, Deserialize)]
373#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq, Clone))]
374pub(crate) struct TreeSync {
375    tree: MlsBinaryTree<TreeSyncLeafNode, TreeSyncParentNode>,
376    tree_hash: Vec<u8>,
377}
378
379impl TreeSync {
380    /// Create a new tree with an own leaf for the given credential.
381    ///
382    /// Returns the resulting [`TreeSync`] instance, as well as the
383    /// corresponding [`CommitSecret`].
384    pub(crate) fn new(
385        provider: &impl OpenMlsProvider,
386        signer: &impl Signer,
387        ciphersuite: Ciphersuite,
388        credential_with_key: CredentialWithKey,
389        life_time: Lifetime,
390        capabilities: Capabilities,
391        extensions: Extensions,
392    ) -> Result<(Self, CommitSecret, EncryptionKeyPair), LibraryError> {
393        let new_leaf_node_params = NewLeafNodeParams {
394            ciphersuite,
395            credential_with_key,
396            // Creation of a group is considered to be from a key package.
397            leaf_node_source: LeafNodeSource::KeyPackage(life_time),
398            capabilities,
399            extensions,
400            tree_info_tbs: TreeInfoTbs::KeyPackage,
401        };
402        let (leaf, encryption_key_pair) = LeafNode::new(provider, signer, new_leaf_node_params)?;
403
404        let node = Node::LeafNode(leaf);
405        let path_secret: PathSecret = Secret::random(ciphersuite, provider.rand())
406            .map_err(LibraryError::unexpected_crypto_error)?
407            .into();
408        let commit_secret: CommitSecret = path_secret
409            .derive_path_secret(provider.crypto(), ciphersuite)?
410            .into();
411        let nodes = vec![TreeSyncNode::from(node).into()];
412        let tree = MlsBinaryTree::new(nodes)
413            .map_err(|_| LibraryError::custom("Unexpected error creating the binary tree."))?;
414        let mut tree_sync = Self {
415            tree,
416            tree_hash: vec![],
417        };
418        // Populate tree hash caches.
419        tree_sync.populate_parent_hashes(provider.crypto(), ciphersuite)?;
420
421        Ok((tree_sync, commit_secret, encryption_key_pair))
422    }
423
424    /// Return the tree hash of the root node of the tree.
425    pub(crate) fn tree_hash(&self) -> &[u8] {
426        self.tree_hash.as_slice()
427    }
428
429    /// Merge the given diff into this `TreeSync` instance, refreshing the
430    /// `tree_hash` value in the process.
431    pub(crate) fn merge_diff(&mut self, tree_sync_diff: StagedTreeSyncDiff) {
432        let (diff, new_tree_hash) = tree_sync_diff.into_parts();
433        self.tree_hash = new_tree_hash;
434        self.tree.merge_diff(diff);
435    }
436
437    /// Create an empty diff based on this [`TreeSync`] instance all operations
438    /// are created based on an initial, empty [`TreeSyncDiff`].
439    pub(crate) fn empty_diff(&self) -> TreeSyncDiff {
440        self.into()
441    }
442
443    /// A helper function that generates a [`TreeSync`] instance from the given
444    /// slice of nodes. It verifies that the provided encryption key is present
445    /// in the tree and that the invariants documented in [`TreeSync`] hold.
446    pub(crate) fn from_ratchet_tree(
447        crypto: &impl OpenMlsCrypto,
448        ciphersuite: Ciphersuite,
449        ratchet_tree: RatchetTree,
450    ) -> Result<Self, TreeSyncFromNodesError> {
451        // TODO #800: Unmerged leaves should be checked
452        let mut ts_nodes: Vec<TreeNode<TreeSyncLeafNode, TreeSyncParentNode>> =
453            Vec::with_capacity(ratchet_tree.0.len());
454
455        // Set the leaf indices in all the leaves and convert the node types.
456        for (node_index, node_option) in ratchet_tree.0.into_iter().enumerate() {
457            let ts_node_option: TreeNode<TreeSyncLeafNode, TreeSyncParentNode> = match node_option {
458                Some(node) => TreeSyncNode::from(node).into(),
459                None => {
460                    if node_index % 2 == 0 {
461                        TreeNode::Leaf(TreeSyncLeafNode::blank())
462                    } else {
463                        TreeNode::Parent(TreeSyncParentNode::blank())
464                    }
465                }
466            };
467            ts_nodes.push(ts_node_option);
468        }
469
470        let tree = MlsBinaryTree::new(ts_nodes).map_err(|_| PublicTreeError::MalformedTree)?;
471        let mut tree_sync = Self {
472            tree,
473            tree_hash: vec![],
474        };
475
476        // Verify all parent hashes.
477        tree_sync
478            .verify_parent_hashes(crypto, ciphersuite)
479            .map_err(|e| match e {
480                TreeSyncParentHashError::LibraryError(e) => e.into(),
481                TreeSyncParentHashError::InvalidParentHash => {
482                    TreeSyncFromNodesError::from(PublicTreeError::InvalidParentHash)
483                }
484            })?;
485
486        // Populate tree hash caches.
487        tree_sync.populate_parent_hashes(crypto, ciphersuite)?;
488        Ok(tree_sync)
489    }
490
491    /// Find the `LeafNodeIndex` which a new leaf would have if it were added to the
492    /// tree. This is either the left-most blank node or, if there are no blank
493    /// leaves, the leaf count, since adding a member would extend the tree by
494    /// one leaf.
495    pub(crate) fn free_leaf_index(&self) -> LeafNodeIndex {
496        let diff = self.empty_diff();
497        diff.free_leaf_index()
498    }
499
500    /// Populate the parent hash caches of all nodes in the tree.
501    fn populate_parent_hashes(
502        &mut self,
503        crypto: &impl OpenMlsCrypto,
504        ciphersuite: Ciphersuite,
505    ) -> Result<(), LibraryError> {
506        let diff = self.empty_diff();
507        // Make the diff into a staged diff. This implicitly computes the
508        // tree hashes and poulates the tree hash caches.
509        let staged_diff = diff.into_staged_diff(crypto, ciphersuite)?;
510        // Merge the diff.
511        self.merge_diff(staged_diff);
512        Ok(())
513    }
514
515    /// Verify the parent hashes of all parent nodes in the tree.
516    ///
517    /// Returns an error if one of the parent nodes in the tree has an invalid
518    /// parent hash.
519    fn verify_parent_hashes(
520        &self,
521        crypto: &impl OpenMlsCrypto,
522        ciphersuite: Ciphersuite,
523    ) -> Result<(), TreeSyncParentHashError> {
524        // The ability to verify parent hashes is required both for diffs and
525        // treesync instances. We choose the computationally slightly more
526        // expensive solution of implementing parent hash verification for the
527        // diff and creating an empty diff whenever we need to verify parent
528        // hashes for a `TreeSync` instance. At the time of writing, this
529        // happens only upon construction of a `TreeSync` instance from a vector
530        // of nodes. The alternative solution would be to create a `TreeLike`
531        // trait, which allows tree navigation and node access. We could then
532        // implement `TreeLike` for both `TreeSync` and `TreeSyncDiff` and
533        // finally implement parent hash verification for any struct that
534        // implements `TreeLike`. We choose the less complex version for now.
535        // Should this turn out to cause too much computational overhead, we
536        // should reconsider and choose the alternative sketched above
537        let diff = self.empty_diff();
538        // No need to merge the diff, since we didn't actually modify any state.
539        diff.verify_parent_hashes(crypto, ciphersuite)
540    }
541
542    /// Returns the tree size
543    pub(crate) fn tree_size(&self) -> TreeSize {
544        self.tree.tree_size()
545    }
546
547    /// Returns an iterator over the (non-blank) [`LeafNode`]s in the tree.
548    pub(crate) fn full_leaves(&self) -> impl Iterator<Item = &LeafNode> {
549        self.tree
550            .leaves()
551            .filter_map(|(_, tsn)| tsn.node().as_ref())
552    }
553
554    /// Returns an iterator over the (non-blank) [`ParentNode`]s in the tree.
555    pub(crate) fn full_parents(&self) -> impl Iterator<Item = (ParentNodeIndex, &ParentNode)> {
556        self.tree
557            .parents()
558            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|pn| (index, pn)))
559    }
560
561    /// Returns the index of the last full leaf in the tree.
562    fn rightmost_full_leaf(&self) -> LeafNodeIndex {
563        let mut index = LeafNodeIndex::new(0);
564        for (leaf_index, leaf) in self.tree.leaves() {
565            if leaf.node().as_ref().is_some() {
566                index = leaf_index;
567            }
568        }
569        index
570    }
571
572    /// Returns a list of [`Member`]s containing only full nodes.
573    ///
574    /// XXX: For performance reasons we probably want to have this in a borrowing
575    ///      version as well. But it might well go away again.
576    pub(crate) fn full_leave_members(&self) -> impl Iterator<Item = Member> + '_ {
577        self.tree
578            .leaves()
579            // Filter out blank nodes
580            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|node| (index, node)))
581            // Map to `Member`
582            .map(|(index, leaf_node)| {
583                Member::new(
584                    index,
585                    leaf_node.encryption_key().as_slice().to_vec(),
586                    leaf_node.signature_key().as_slice().to_vec(),
587                    leaf_node.credential().clone(),
588                )
589            })
590    }
591
592    /// Returns the nodes in the tree ordered according to the
593    /// array-representation of the underlying binary tree.
594    pub fn export_ratchet_tree(&self) -> RatchetTree {
595        let mut nodes = Vec::new();
596
597        // Determine the index of the rightmost full leaf.
598        let max_length = self.rightmost_full_leaf();
599
600        // We take all the leaves including the rightmost full leaf, blank
601        // leaves beyond that are trimmed.
602        let mut leaves = self
603            .tree
604            .leaves()
605            .map(|(_, leaf)| leaf)
606            .take(max_length.usize() + 1);
607
608        // Get the first leaf.
609        if let Some(leaf) = leaves.next() {
610            nodes.push(leaf.node().clone().map(Node::LeafNode));
611        } else {
612            // The tree was empty.
613            return RatchetTree::trimmed(vec![]);
614        }
615
616        // Blank parent node used for padding
617        let default_parent = TreeSyncParentNode::default();
618
619        // Get the parents.
620        let parents = self
621            .tree
622            .parents()
623            // Drop the index
624            .map(|(_, parent)| parent)
625            // Take the parents up to the max length
626            .take(max_length.usize())
627            // Pad the parents with blank nodes if needed
628            .chain(
629                (self.tree.parents().count()..self.tree.leaves().count() - 1)
630                    .map(|_| &default_parent),
631            );
632
633        // Interleave the leaves and parents.
634        for (leaf, parent) in leaves.zip(parents) {
635            nodes.push(parent.node().clone().map(Node::ParentNode));
636            nodes.push(leaf.node().clone().map(Node::LeafNode));
637        }
638
639        RatchetTree::trimmed(nodes)
640    }
641
642    /// Return a reference to the leaf at the given `LeafNodeIndex` or `None` if the
643    /// leaf is blank.
644    pub(crate) fn leaf(&self, leaf_index: LeafNodeIndex) -> Option<&LeafNode> {
645        let tsn = self.tree.leaf(leaf_index);
646        tsn.node().as_ref()
647    }
648
649    /// Returns a [`TreeSyncError`] if the `leaf_index` is not a leaf in this
650    /// tree or empty.
651    pub(crate) fn is_leaf_in_tree(&self, leaf_index: LeafNodeIndex) -> bool {
652        is_node_in_tree(leaf_index.into(), self.tree.tree_size())
653    }
654
655    /// Return a vector containing all [`EncryptionKey`]s for which the owner of
656    /// the given `leaf_index` should have private key material.
657    pub(crate) fn owned_encryption_keys(&self, leaf_index: LeafNodeIndex) -> Vec<EncryptionKey> {
658        self.empty_diff()
659            .encryption_keys(leaf_index)
660            .cloned()
661            .collect::<Vec<EncryptionKey>>()
662    }
663
664    /// Derives [`EncryptionKeyPair`]s for the nodes in the shared direct path
665    /// of the leaves with index `leaf_index` and `sender_index`.  This function
666    /// also checks that the derived public keys match the existing public keys.
667    ///
668    /// Returns the `CommitSecret` derived from the path secret of the root
669    /// node, as well as the derived [`EncryptionKeyPair`]s. Returns an error if
670    /// the target leaf is outside of the tree.
671    ///
672    /// Returns TreeSyncSetPathError::PublicKeyMismatch if the derived keys don't
673    /// match with the existing ones.
674    ///
675    /// Returns TreeSyncSetPathError::LibraryError if the sender_index is not
676    /// in the tree.
677    pub(crate) fn derive_path_secrets(
678        &self,
679        crypto: &impl OpenMlsCrypto,
680        ciphersuite: Ciphersuite,
681        mut path_secret: PathSecret,
682        sender_index: LeafNodeIndex,
683        leaf_index: LeafNodeIndex,
684    ) -> Result<(Vec<EncryptionKeyPair>, CommitSecret), DerivePathError> {
685        // We assume both nodes are in the tree, since the sender_index must be in the tree
686        // Skip the nodes in the subtree path for which we are an unmerged leaf.
687        let subtree_path = self.tree.subtree_path(leaf_index, sender_index);
688        let mut keypairs = Vec::new();
689        for parent_index in subtree_path {
690            // We know the node is in the tree, since it is in the subtree path
691            let tsn = self.tree.parent_by_index(parent_index);
692            // We only care about non-blank nodes.
693            if let Some(ref parent_node) = tsn.node() {
694                // If our own leaf index is not in the list of unmerged leaves
695                // then we should have the secret for this node.
696                if !parent_node.unmerged_leaves().contains(&leaf_index) {
697                    let keypair = path_secret.derive_key_pair(crypto, ciphersuite)?;
698                    // The derived public key should match the one in the node.
699                    // If not, the tree is corrupt.
700                    if parent_node.encryption_key() != keypair.public_key() {
701                        return Err(DerivePathError::PublicKeyMismatch);
702                    } else {
703                        // If everything is ok, set the private key and derive
704                        // the next path secret.
705                        keypairs.push(keypair);
706                        path_secret = path_secret.derive_path_secret(crypto, ciphersuite)?;
707                    }
708                };
709                // If the leaf is blank or our index is in the list of unmerged
710                // leaves, go to the next node.
711            }
712        }
713        Ok((keypairs, path_secret.into()))
714    }
715
716    /// Return a reference to the parent node at the given `ParentNodeIndex` or
717    /// `None` if the node is blank.
718    pub(crate) fn parent(&self, node_index: ParentNodeIndex) -> Option<&ParentNode> {
719        let tsn = self.tree.parent(node_index);
720        tsn.node().as_ref()
721    }
722}
723
#[cfg(test)]
impl TreeSync {
    /// Test helper: the leaf count reported by the underlying binary tree.
    pub(crate) fn leaf_count(&self) -> u32 {
        let Self { tree, .. } = self;
        tree.leaf_count()
    }
}
730
#[cfg(test)]
mod test {
    use super::*;

    // `RatchetTree::trimmed` guards its "non-empty after trimming" contract
    // with a `debug_assert!`: violating it must panic in debug builds and
    // must pass silently in release builds.

    /// This should only panic in debug-builds.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_ratchet_tree_internal_empty() {
        RatchetTree::trimmed(vec![]);
    }

    /// This should not panic in release-builds.
    #[cfg(not(debug_assertions))]
    #[test]
    fn test_ratchet_tree_internal_empty() {
        RatchetTree::trimmed(vec![]);
    }

    /// This should only panic in debug-builds.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_ratchet_tree_internal_empty_after_trim() {
        RatchetTree::trimmed(vec![None]);
    }

    /// This should not panic in release-builds.
    #[cfg(not(debug_assertions))]
    #[test]
    fn test_ratchet_tree_internal_empty_after_trim() {
        RatchetTree::trimmed(vec![None]);
    }

    /// Only non-empty node lists whose last node is non-blank are accepted.
    #[openmls_test::openmls_test]
    fn test_ratchet_tree_trailing_blank_nodes(
        ciphersuite: Ciphersuite,
        provider: &impl OpenMlsProvider,
    ) {
        let (key_package, _, _) = crate::key_packages::tests::key_package(ciphersuite, provider);
        let leaf = NodeIn::from(Node::LeafNode(LeafNode::from(key_package)));

        // (nodes, whether conversion should succeed)
        let cases: [(Vec<Option<NodeIn>>, bool); 8] = [
            (vec![], false),
            (vec![None], false),
            (vec![None, None], false),
            (vec![None, None, None], false),
            (vec![Some(leaf.clone())], true),
            (vec![Some(leaf.clone()), None], false),
            (vec![Some(leaf.clone()), None, Some(leaf.clone())], true),
            (vec![Some(leaf.clone()), None, Some(leaf), None], false),
        ];

        for (nodes, should_succeed) in cases {
            let group_id = GroupId::random(provider.rand());
            let result =
                RatchetTree::try_from_nodes(ciphersuite, provider.crypto(), nodes, &group_id);
            assert_eq!(result.is_ok(), should_succeed);
        }
    }
}