Skip to main content

openmls/treesync/
mod.rs

1//! This module implements the ratchet tree component of MLS.
2//!
3//! It exposes the [`Node`] enum that can contain either a [`LeafNode`] or a [`ParentNode`].
4
5// # Internal documentation
6//
7// This module provides the [`TreeSync`] struct, which contains the state
8// shared between a group of MLS clients in the shape of a tree, where each
9// non-blank leaf corresponds to one group member. The functions provided by
10// its implementation allow the creation of a [`TreeSyncDiff`] instance, which
11// in turn can be mutably operated on and merged back into the original
12// [`TreeSync`] instance.
13//
14// The submodules of this module define the nodes of the tree (`node`),
15// helper functions and structs for the algorithms used to sync the tree across
16// the group ([`hashes`]) and the diff functionality ([`diff`]).
17//
18// Finally, this module contains the [`treekem`] module, which allows the
19// encryption and decryption of updates to the tree.
20
21#[cfg(any(feature = "test-utils", test))]
22use std::fmt;
23
24use openmls_traits::{
25    crypto::OpenMlsCrypto,
26    signatures::Signer,
27    types::{Ciphersuite, CryptoError},
28};
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
32
33use self::{
34    diff::{StagedTreeSyncDiff, TreeSyncDiff},
35    node::{
36        leaf_node::{
37            Capabilities, NewLeafNodeParams, TreeInfoTbs, TreePosition, VerifiableLeafNode,
38        },
39        NodeIn,
40    },
41    treesync_node::{TreeSyncLeafNode, TreeSyncNode, TreeSyncParentNode},
42};
43#[cfg(any(feature = "test-utils", test))]
44use crate::{binary_tree::array_representation::level, test_utils::bytes_to_hex};
45use crate::{
46    binary_tree::array_representation::ParentNodeIndex, treesync::node::leaf_node::LeafNodeIn,
47};
48use crate::{
49    binary_tree::{
50        array_representation::{is_node_in_tree, LeafNodeIndex, TreeSize},
51        MlsBinaryTree, MlsBinaryTreeError,
52    },
53    ciphersuite::{signable::Verifiable, Secret},
54    credentials::CredentialWithKey,
55    error::LibraryError,
56    extensions::Extensions,
57    group::{GroupId, Member},
58    key_packages::Lifetime,
59    messages::{PathSecret, PathSecretError},
60    schedule::CommitSecret,
61    storage::OpenMlsProvider,
62};
63
64// Private
65mod hashes;
66use errors::*;
67
68// Crate
69pub(crate) mod diff;
70pub(crate) mod node;
71pub(crate) mod treekem;
72pub(crate) mod treesync_node;
73
74use node::encryption_keys::EncryptionKeyPair;
75
76// Public
77pub mod errors;
78#[cfg(feature = "test-utils")]
79pub use node::encryption_keys::test_utils;
80pub use node::encryption_keys::EncryptionKey;
81
82// Public re-exports
83pub use node::{
84    leaf_node::{
85        LeafNode, LeafNodeParameters, LeafNodeParametersBuilder, LeafNodeSource,
86        LeafNodeUpdateError,
87    },
88    parent_node::ParentNode,
89    Node,
90};
91
92// Tests
93#[cfg(any(feature = "test-utils", test))]
94pub mod tests_and_kats;
95
96/// An exported ratchet tree as used in, e.g., [`GroupInfo`](crate::messages::group_info::GroupInfo).
97#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
98pub struct RatchetTree(Vec<Option<Node>>);
99
100/// An error during processing of an incoming ratchet tree.
101#[derive(Error, Debug, PartialEq, Clone)]
102pub enum RatchetTreeError {
103    /// The ratchet tree is empty.
104    #[error("The ratchet tree has no nodes.")]
105    MissingNodes,
106    /// The ratchet tree has a trailing blank node.
107    #[error("The ratchet tree has trailing blank nodes.")]
108    TrailingBlankNodes,
109    /// Invalid node signature.
110    #[error("Invalid node signature.")]
111    InvalidNodeSignature,
112    /// Wrong node type.
113    #[error("Wrong node type.")]
114    WrongNodeType,
115}
116
117impl RatchetTree {
118    /// Create a [`RatchetTree`] from a vector of nodes stripping all trailing blank nodes.
119    ///
120    /// Note: The caller must ensure to call this with a vector that is *not* empty after removing all trailing blank nodes.
121    fn trimmed(mut nodes: Vec<Option<Node>>) -> Self {
122        // Remove all trailing blank nodes.
123        match nodes.iter().enumerate().rfind(|(_, node)| node.is_some()) {
124            Some((rightmost_nonempty_position, _)) => {
125                // We need to add 1 to `rightmost_nonempty_position` to keep the rightmost node.
126                nodes.resize(rightmost_nonempty_position + 1, None);
127            }
128            None => {
129                // If there is no rightmost non-blank node, the vector consist of blank nodes only.
130                nodes.clear();
131            }
132        }
133
134        debug_assert!(!nodes.is_empty(), "Caller should have ensured that `RatchetTree::trimmed` is not called with a vector that is empty after removing all trailing blank nodes.");
135        Self(nodes)
136    }
137
138    /// Create a new [`RatchetTree`] from a vector of nodes.
139    pub(crate) fn try_from_nodes(
140        ciphersuite: Ciphersuite,
141        crypto: &impl OpenMlsCrypto,
142        nodes: Vec<Option<NodeIn>>,
143        group_id: &GroupId,
144    ) -> Result<Self, RatchetTreeError> {
145        // ValSem300: "Exported ratchet trees must not have trailing blank nodes."
146        //
147        // We can check this by only looking at the last node (if any).
148        match nodes.last() {
149            Some(None) => {
150                // The ratchet tree is not empty, i.e., has a last node, *but* the last node *is* blank.
151                Err(RatchetTreeError::TrailingBlankNodes)
152            }
153            None => {
154                // The ratchet tree is empty.
155                Err(RatchetTreeError::MissingNodes)
156            }
157            Some(Some(_)) => {
158                // The ratchet tree is not empty, i.e., has a last node, and the last node is not blank.
159
160                // Verify the nodes.
161                // https://validation.openmls.tech/#valn1407
162                let mut verified_nodes = Vec::new();
163                for (index, node) in nodes.into_iter().enumerate() {
164                    let verified_node = match (index % 2, node) {
165                        // Even indices must be leaf nodes.
166                        (0, Some(NodeIn::LeafNode(leaf_node))) => {
167                            let tree_position = TreePosition::new(
168                                group_id.clone(),
169                                LeafNodeIndex::new((index / 2) as u32),
170                            );
171                            let verifiable_leaf_node = leaf_node.into_verifiable_leaf_node();
172                            let signature_key = verifiable_leaf_node
173                                .signature_key()
174                                .clone()
175                                .into_signature_public_key_enriched(
176                                    ciphersuite.signature_algorithm(),
177                                );
178                            Some(Node::leaf_node(match verifiable_leaf_node {
179                                VerifiableLeafNode::KeyPackage(leaf_node) => leaf_node
180                                    .verify(crypto, &signature_key)
181                                    .map_err(|_| RatchetTreeError::InvalidNodeSignature)?,
182                                VerifiableLeafNode::Update(mut leaf_node) => {
183                                    leaf_node.add_tree_position(tree_position);
184                                    leaf_node
185                                        .verify(crypto, &signature_key)
186                                        .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
187                                }
188                                VerifiableLeafNode::Commit(mut leaf_node) => {
189                                    leaf_node.add_tree_position(tree_position);
190                                    leaf_node
191                                        .verify(crypto, &signature_key)
192                                        .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
193                                }
194                            }))
195                        }
196                        // Odd indices must be parent nodes.
197                        (1, Some(NodeIn::ParentNode(parent_node))) => {
198                            Some(Node::ParentNode(parent_node))
199                        }
200                        // Blank nodes.
201                        (_, None) => None,
202                        // All other cases are invalid.
203                        _ => {
204                            return Err(RatchetTreeError::WrongNodeType);
205                        }
206                    };
207                    verified_nodes.push(verified_node);
208                }
209                Ok(Self::trimmed(verified_nodes))
210            }
211        }
212    }
213
214    /// Returns an iterator over all nodes in the ratchet tree.
215    pub fn nodes(&self) -> impl Iterator<Item = &Node> {
216        self.0.iter().flatten()
217    }
218
219    /// Returns an iterator over all leaf nodes in the ratchet tree.
220    pub fn leaves(&self) -> impl Iterator<Item = &LeafNode> {
221        self.nodes().filter_map(|node| match node {
222            Node::LeafNode(leaf_node) => Some(&**leaf_node),
223            Node::ParentNode(_parent_node) => None,
224        })
225    }
226
227    /// Returns an iterator over all parent nodes in the ratchet tree.
228    pub fn parents(&self) -> impl Iterator<Item = &ParentNode> {
229        self.nodes().filter_map(|node| match node {
230            Node::ParentNode(parent_node) => Some(&**parent_node),
231            Node::LeafNode(_leaf_node) => None,
232        })
233    }
234}
235
236/// A ratchet tree made of unverified nodes. This is used for deserialization
237/// and verification.
238#[derive(
239    PartialEq,
240    Eq,
241    Clone,
242    Debug,
243    Serialize,
244    Deserialize,
245    TlsDeserialize,
246    TlsDeserializeBytes,
247    TlsSerialize,
248    TlsSize,
249)]
250pub struct RatchetTreeIn(Vec<Option<NodeIn>>);
251
252impl RatchetTreeIn {
253    /// Create a new [`RatchetTreeIn`] from a vector of nodes after verifying
254    /// the nodes.
255    pub fn into_verified(
256        self,
257        ciphersuite: Ciphersuite,
258        crypto: &impl OpenMlsCrypto,
259        group_id: &GroupId,
260    ) -> Result<RatchetTree, RatchetTreeError> {
261        RatchetTree::try_from_nodes(ciphersuite, crypto, self.0, group_id)
262    }
263
264    /// Returns an iterator over all nodes in the ratchet tree.
265    pub fn nodes(&self) -> impl Iterator<Item = &NodeIn> {
266        self.0.iter().flatten()
267    }
268
269    /// Returns an iterator over all leaf nodes in the ratchet tree.
270    pub fn leaves(&self) -> impl Iterator<Item = &LeafNodeIn> {
271        self.nodes().filter_map(|node| match node {
272            NodeIn::LeafNode(leaf_node) => Some(&**leaf_node),
273            NodeIn::ParentNode(_parent_node) => None,
274        })
275    }
276
277    /// Returns an iterator over all parent nodes in the ratchet tree.
278    pub fn parents(&self) -> impl Iterator<Item = &ParentNode> {
279        self.nodes().filter_map(|node| match node {
280            NodeIn::ParentNode(parent_node) => Some(&**parent_node),
281            NodeIn::LeafNode(_leaf_node) => None,
282        })
283    }
284
285    fn from_ratchet_tree(ratchet_tree: RatchetTree) -> Self {
286        let nodes = ratchet_tree
287            .0
288            .into_iter()
289            .map(|node| node.map(NodeIn::from))
290            .collect();
291        Self(nodes)
292    }
293
294    #[cfg(test)]
295    pub(crate) fn from_nodes(nodes: Vec<Option<NodeIn>>) -> Self {
296        Self(nodes)
297    }
298}
299
300impl From<RatchetTree> for RatchetTreeIn {
301    fn from(ratchet_tree: RatchetTree) -> Self {
302        RatchetTreeIn::from_ratchet_tree(ratchet_tree)
303    }
304}
305
// The following `From` implementation breaks abstraction layers and MUST
// NOT be made available outside of tests or "test-utils".
#[cfg(any(feature = "test-utils", test))]
impl From<RatchetTreeIn> for RatchetTree {
    /// Converts every node with `Node::from` *without* verifying it.
    fn from(unverified: RatchetTreeIn) -> Self {
        let converted: Vec<Option<Node>> = unverified
            .0
            .into_iter()
            .map(|slot| slot.map(Node::from))
            .collect();
        Self(converted)
    }
}
320
/// Floor of the base-2 logarithm of `x`, with `log2(0)` defined as `0`.
#[cfg(any(feature = "test-utils", test))]
fn log2(x: u32) -> usize {
    // `checked_ilog2` yields `None` only for zero, which maps to 0.
    x.checked_ilog2().map_or(0, |exponent| exponent as usize)
}

/// Index of the root node of the smallest full binary tree that has at
/// least `size` nodes in array representation, i.e. `2^floor(log2(size)) - 1`
/// (and `0` for `size == 0`).
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn root(size: u32) -> u32 {
    let exponent = log2(size) as u32;
    2u32.pow(exponent) - 1
}
333
#[cfg(any(feature = "test-utils", test))]
impl fmt::Display for RatchetTree {
    /// Renders the tree one node per line, in array order.
    ///
    /// Each line starts with the zero-padded node index and the node kind
    /// (`L` leaf, `P` parent, `_` blank). `(*)` marks the root position for
    /// parent and blank nodes. Next come the encryption key and parent hash
    /// in hex. Last is a glyph (`◼︎` full, `❑` blank) indented by three
    /// spaces per tree level.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let factor = 3;
        let slots = &self.0;
        let root_index = root(slots.len() as u32);

        for (i, slot) in slots.iter().enumerate() {
            let is_root = root_index == i as u32;
            let indentation = str::repeat(" ", level(i as u32) * factor);
            write!(f, "{i:04}")?;

            match slot {
                Some(node) => {
                    let (key_bytes, parent_hash_hex) = match node {
                        Node::LeafNode(leaf_node) => {
                            f.write_str("\tL      ")?;
                            let parent_hash_hex = leaf_node
                                .parent_hash()
                                .map(bytes_to_hex)
                                .unwrap_or_default();
                            (leaf_node.encryption_key().as_slice(), parent_hash_hex)
                        }
                        Node::ParentNode(parent_node) => {
                            f.write_str(if is_root { "\tP (*)  " } else { "\tP      " })?;
                            (
                                parent_node.public_key().as_slice(),
                                bytes_to_hex(parent_node.parent_hash()),
                            )
                        }
                    };
                    // An empty parent hash is rendered as blank padding.
                    let shown_parent_hash = if parent_hash_hex.is_empty() {
                        str::repeat("  ", 32)
                    } else {
                        parent_hash_hex
                    };
                    write!(
                        f,
                        "PK: {}  PH: {} | ",
                        bytes_to_hex(key_bytes),
                        shown_parent_hash
                    )?;
                    write!(f, "{indentation}◼︎")?;
                }
                None => {
                    let placeholder = str::repeat("__", 32);
                    if is_root {
                        write!(f, "\t_ (*)  PK: {placeholder}  PH: {placeholder} | ")?;
                    } else {
                        write!(f, "\t_      PK: {placeholder}  PH: {placeholder} | ")?;
                    }
                    write!(f, "{indentation}❑")?;
                }
            }
            writeln!(f)?;
        }

        Ok(())
    }
}
403
404/// The [`TreeSync`] struct holds an `MlsBinaryTree` instance, which contains
405/// the state that is synced across the group, as well as the [`LeafNodeIndex`]
406/// pointing to the leaf of this group member and the current hash of the tree.
407///
408/// It follows the same pattern of tree and diff as the underlying
409/// `MlsBinaryTree`, where the [`TreeSync`] instance is immutable safe for
410/// merging a `TreeSyncDiff`, which can be created, staged and merged (see
411/// `TreeSyncDiff`).
412///
413/// [`TreeSync`] instance guarantee a few invariants that are checked upon
414/// creating a new instance from an imported set of nodes, as well as when
415/// merging a diff.
416#[derive(Debug, Serialize, Deserialize)]
417#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq, Clone))]
418pub struct TreeSync {
419    tree: MlsBinaryTree<TreeSyncLeafNode, TreeSyncParentNode>,
420    tree_hash: Vec<u8>,
421}
422
423impl TreeSync {
424    /// Create a new tree with an own leaf for the given credential.
425    ///
426    /// Returns the resulting [`TreeSync`] instance, as well as the
427    /// corresponding [`CommitSecret`].
428    pub(crate) fn new(
429        provider: &impl OpenMlsProvider,
430        signer: &impl Signer,
431        ciphersuite: Ciphersuite,
432        credential_with_key: CredentialWithKey,
433        life_time: Lifetime,
434        capabilities: Capabilities,
435        extensions: Extensions<LeafNode>,
436    ) -> Result<(Self, CommitSecret, EncryptionKeyPair), LibraryError> {
437        let new_leaf_node_params = NewLeafNodeParams {
438            ciphersuite,
439            credential_with_key,
440            // Creation of a group is considered to be from a key package.
441            leaf_node_source: LeafNodeSource::KeyPackage(life_time),
442            capabilities,
443            extensions,
444            tree_info_tbs: TreeInfoTbs::KeyPackage,
445        };
446        let (leaf, encryption_key_pair) = LeafNode::new(provider, signer, new_leaf_node_params)?;
447
448        let node = Node::leaf_node(leaf);
449        let path_secret: PathSecret = Secret::random(ciphersuite, provider.rand())
450            .map_err(LibraryError::unexpected_crypto_error)?
451            .into();
452        let commit_secret: CommitSecret = path_secret
453            .derive_path_secret(provider.crypto(), ciphersuite)?
454            .into();
455        let nodes = vec![TreeSyncNode::from(node).into()];
456        let tree = MlsBinaryTree::new(nodes)
457            .map_err(|_| LibraryError::custom("Unexpected error creating the binary tree."))?;
458        let mut tree_sync = Self {
459            tree,
460            tree_hash: vec![],
461        };
462        // Populate tree hash caches.
463        tree_sync.populate_parent_hashes(provider.crypto(), ciphersuite)?;
464
465        Ok((tree_sync, commit_secret, encryption_key_pair))
466    }
467
468    /// Return the full tree
469    pub(crate) fn tree(&self) -> &MlsBinaryTree<TreeSyncLeafNode, TreeSyncParentNode> {
470        &self.tree
471    }
472
473    /// Return the tree hash of the root node of the tree.
474    pub(crate) fn tree_hash(&self) -> &[u8] {
475        self.tree_hash.as_slice()
476    }
477
478    /// Merge the given diff into this `TreeSync` instance, refreshing the
479    /// `tree_hash` value in the process.
480    pub(crate) fn merge_diff(&mut self, tree_sync_diff: StagedTreeSyncDiff) {
481        let (diff, new_tree_hash) = tree_sync_diff.into_parts();
482        self.tree_hash = new_tree_hash;
483        self.tree.merge_diff(diff);
484    }
485
486    /// Create an empty diff based on this [`TreeSync`] instance all operations
487    /// are created based on an initial, empty [`TreeSyncDiff`].
488    pub(crate) fn empty_diff(&self) -> TreeSyncDiff<'_> {
489        self.into()
490    }
491
492    /// A helper function that generates a [`TreeSync`] instance from the given
493    /// slice of nodes. It verifies that the provided encryption key is present
494    /// in the tree and that the invariants documented in [`TreeSync`] hold.
495    pub(crate) fn from_ratchet_tree(
496        crypto: &impl OpenMlsCrypto,
497        ciphersuite: Ciphersuite,
498        ratchet_tree: RatchetTree,
499    ) -> Result<Self, TreeSyncFromNodesError> {
500        // TODO #800: Unmerged leaves should be checked
501        let total_nodes = ratchet_tree.0.len();
502        let mut leaf_nodes = Vec::with_capacity(total_nodes.div_ceil(2));
503        let mut parent_nodes = Vec::with_capacity(total_nodes / 2);
504
505        // Set the leaf indices in all the leaves and convert the node types.
506        for (node_index, node_option) in ratchet_tree.0.into_iter().enumerate() {
507            if node_index % 2 == 0 {
508                let leaf = match node_option {
509                    Some(node) => match TreeSyncNode::from(node) {
510                        TreeSyncNode::Leaf(l) => *l,
511                        TreeSyncNode::Parent(_) => {
512                            return Err(TreeSyncFromNodesError::from(
513                                PublicTreeError::MalformedTree,
514                            ))
515                        }
516                    },
517                    None => TreeSyncLeafNode::blank(),
518                };
519                leaf_nodes.push(leaf);
520            } else {
521                let parent = match node_option {
522                    Some(node) => match TreeSyncNode::from(node) {
523                        TreeSyncNode::Parent(p) => *p,
524                        TreeSyncNode::Leaf(_) => {
525                            return Err(TreeSyncFromNodesError::from(
526                                PublicTreeError::MalformedTree,
527                            ))
528                        }
529                    },
530                    None => TreeSyncParentNode::blank(),
531                };
532                parent_nodes.push(parent);
533            }
534        }
535
536        let tree = MlsBinaryTree::from_components(leaf_nodes, parent_nodes)
537            .map_err(|_| PublicTreeError::MalformedTree)?;
538        let mut tree_sync = Self {
539            tree,
540            tree_hash: vec![],
541        };
542
543        // Verify all parent hashes.
544        tree_sync
545            .verify_parent_hashes(crypto, ciphersuite)
546            .map_err(|e| match e {
547                TreeSyncParentHashError::LibraryError(e) => e.into(),
548                TreeSyncParentHashError::InvalidParentHash => {
549                    TreeSyncFromNodesError::from(PublicTreeError::InvalidParentHash)
550                }
551            })?;
552
553        // Populate tree hash caches.
554        tree_sync.populate_parent_hashes(crypto, ciphersuite)?;
555        Ok(tree_sync)
556    }
557
558    /// Find the `LeafNodeIndex` which a new leaf would have if it were added to the
559    /// tree. This is either the left-most blank node or, if there are no blank
560    /// leaves, the leaf count, since adding a member would extend the tree by
561    /// one leaf.
562    pub(crate) fn free_leaf_index(&self) -> LeafNodeIndex {
563        let diff = self.empty_diff();
564        diff.free_leaf_index()
565    }
566
567    /// Populate the parent hash caches of all nodes in the tree.
568    fn populate_parent_hashes(
569        &mut self,
570        crypto: &impl OpenMlsCrypto,
571        ciphersuite: Ciphersuite,
572    ) -> Result<(), LibraryError> {
573        let diff = self.empty_diff();
574        // Make the diff into a staged diff. This implicitly computes the
575        // tree hashes and poulates the tree hash caches.
576        let staged_diff = diff.into_staged_diff(crypto, ciphersuite)?;
577        // Merge the diff.
578        self.merge_diff(staged_diff);
579        Ok(())
580    }
581
582    /// Verify the parent hashes of all parent nodes in the tree.
583    ///
584    /// Returns an error if one of the parent nodes in the tree has an invalid
585    /// parent hash.
586    fn verify_parent_hashes(
587        &self,
588        crypto: &impl OpenMlsCrypto,
589        ciphersuite: Ciphersuite,
590    ) -> Result<(), TreeSyncParentHashError> {
591        // The ability to verify parent hashes is required both for diffs and
592        // treesync instances. We choose the computationally slightly more
593        // expensive solution of implementing parent hash verification for the
594        // diff and creating an empty diff whenever we need to verify parent
595        // hashes for a `TreeSync` instance. At the time of writing, this
596        // happens only upon construction of a `TreeSync` instance from a vector
597        // of nodes. The alternative solution would be to create a `TreeLike`
598        // trait, which allows tree navigation and node access. We could then
599        // implement `TreeLike` for both `TreeSync` and `TreeSyncDiff` and
600        // finally implement parent hash verification for any struct that
601        // implements `TreeLike`. We choose the less complex version for now.
602        // Should this turn out to cause too much computational overhead, we
603        // should reconsider and choose the alternative sketched above
604        let diff = self.empty_diff();
605        // No need to merge the diff, since we didn't actually modify any state.
606        diff.verify_parent_hashes(crypto, ciphersuite)
607    }
608
609    /// Returns the tree size
610    pub(crate) fn tree_size(&self) -> TreeSize {
611        self.tree.tree_size()
612    }
613
614    /// Returns an iterator over the (non-blank) [`LeafNode`]s in the tree.
615    pub fn full_leaves(&self) -> impl Iterator<Item = (LeafNodeIndex, &LeafNode)> {
616        self.tree
617            .leaves()
618            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|ln| (index, ln)))
619    }
620
621    /// Returns an iterator over the (non-blank) [`ParentNode`]s in the tree.
622    pub fn full_parents(&self) -> impl Iterator<Item = (ParentNodeIndex, &ParentNode)> {
623        self.tree
624            .parents()
625            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|pn| (index, pn)))
626    }
627
628    /// Returns an iterator over the [`ParentNodeIndex`]es of blank [`ParentNode`]s in the tree.
629    pub fn blank_parents<'a>(&'a self) -> impl Iterator<Item = ParentNodeIndex> + 'a {
630        self.tree
631            .parents()
632            .filter_map(|(index, tsn)| tsn.node().as_ref().map_or(Some(index), |_| None))
633    }
634
635    /// Returns an iterator over the [`LeafNodeIndex`]es of blank [`LeafNode`]s in the tree.
636    pub fn blank_leaves<'a>(&'a self) -> impl Iterator<Item = LeafNodeIndex> + 'a {
637        self.tree
638            .leaves()
639            .filter_map(|(index, tsn)| tsn.node().as_ref().map_or(Some(index), |_| None))
640    }
641
642    /// Returns the index of the last full leaf in the tree.
643    fn rightmost_full_leaf(&self) -> LeafNodeIndex {
644        let mut index = LeafNodeIndex::new(0);
645        for (leaf_index, leaf) in self.tree.leaves() {
646            if leaf.node().as_ref().is_some() {
647                index = leaf_index;
648            }
649        }
650        index
651    }
652
653    /// Returns a list of [`Member`]s containing only full nodes.
654    ///
655    /// XXX: For performance reasons we probably want to have this in a borrowing
656    ///      version as well. But it might well go away again.
657    pub(crate) fn full_leaf_members(&self) -> impl Iterator<Item = Member> + '_ {
658        self.tree
659            .leaves()
660            // Filter out blank nodes
661            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|node| (index, node)))
662            // Map to `Member`
663            .map(|(index, leaf_node)| {
664                Member::new(
665                    index,
666                    leaf_node.encryption_key().as_slice().to_vec(),
667                    leaf_node.signature_key().as_slice().to_vec(),
668                    leaf_node.credential().clone(),
669                )
670            })
671    }
672
673    /// Returns the nodes in the tree ordered according to the
674    /// array-representation of the underlying binary tree.
675    pub fn export_ratchet_tree(&self) -> RatchetTree {
676        let mut nodes = Vec::new();
677
678        // Determine the index of the rightmost full leaf.
679        let max_length = self.rightmost_full_leaf();
680
681        // We take all the leaves including the rightmost full leaf, blank
682        // leaves beyond that are trimmed.
683        let mut leaves = self
684            .tree
685            .leaves()
686            .map(|(_, leaf)| leaf)
687            .take(max_length.usize() + 1);
688
689        // Get the first leaf.
690        if let Some(leaf) = leaves.next() {
691            nodes.push(leaf.node().clone().map(Node::leaf_node));
692        } else {
693            // The tree was empty.
694            return RatchetTree::trimmed(vec![]);
695        }
696
697        // Blank parent node used for padding
698        let default_parent = TreeSyncParentNode::default();
699
700        // Get the parents.
701        let parents = self
702            .tree
703            .parents()
704            // Drop the index
705            .map(|(_, parent)| parent)
706            // Take the parents up to the max length
707            .take(max_length.usize())
708            // Pad the parents with blank nodes if needed
709            .chain(
710                (self.tree.parents().count()..self.tree.leaves().count() - 1)
711                    .map(|_| &default_parent),
712            );
713
714        // Interleave the leaves and parents.
715        for (leaf, parent) in leaves.zip(parents) {
716            nodes.push(parent.node().clone().map(Node::parent_node));
717            nodes.push(leaf.node().clone().map(Node::leaf_node));
718        }
719
720        RatchetTree::trimmed(nodes)
721    }
722
723    /// Return a reference to the leaf at the given `LeafNodeIndex` or `None` if the
724    /// leaf is blank.
725    pub(crate) fn leaf(&self, leaf_index: LeafNodeIndex) -> Option<&LeafNode> {
726        let tsn = self.tree.leaf(leaf_index);
727        tsn.node().as_ref()
728    }
729
730    /// Returns a [`TreeSyncError`] if the `leaf_index` is not a leaf in this
731    /// tree or empty.
732    pub(crate) fn is_leaf_in_tree(&self, leaf_index: LeafNodeIndex) -> bool {
733        is_node_in_tree(leaf_index.into(), self.tree.tree_size())
734    }
735
736    /// Return a vector containing all [`EncryptionKey`]s for which the owner of
737    /// the given `leaf_index` should have private key material.
738    pub(crate) fn owned_encryption_keys(&self, leaf_index: LeafNodeIndex) -> Vec<EncryptionKey> {
739        self.empty_diff()
740            .encryption_keys(leaf_index)
741            .cloned()
742            .collect::<Vec<EncryptionKey>>()
743    }
744
745    /// Derives [`EncryptionKeyPair`]s for the nodes in the shared direct path
746    /// of the leaves with index `leaf_index` and `sender_index`.  This function
747    /// also checks that the derived public keys match the existing public keys.
748    ///
749    /// Returns the `CommitSecret` derived from the path secret of the root
750    /// node, as well as the derived [`EncryptionKeyPair`]s. Returns an error if
751    /// the target leaf is outside of the tree.
752    ///
753    /// Returns TreeSyncSetPathError::PublicKeyMismatch if the derived keys don't
754    /// match with the existing ones.
755    ///
756    /// Returns TreeSyncSetPathError::LibraryError if the sender_index is not
757    /// in the tree.
758    pub(crate) fn derive_path_secrets(
759        &self,
760        crypto: &impl OpenMlsCrypto,
761        ciphersuite: Ciphersuite,
762        mut path_secret: PathSecret,
763        sender_index: LeafNodeIndex,
764        leaf_index: LeafNodeIndex,
765    ) -> Result<(Vec<EncryptionKeyPair>, CommitSecret), DerivePathError> {
766        // We assume both nodes are in the tree, since the sender_index must be in the tree
767        // Skip the nodes in the subtree path for which we are an unmerged leaf.
768        let subtree_path = self.tree.subtree_path(leaf_index, sender_index);
769        let mut keypairs = Vec::new();
770        for parent_index in subtree_path {
771            // We know the node is in the tree, since it is in the subtree path
772            let tsn = self.tree.parent_by_index(parent_index);
773            // We only care about non-blank nodes.
774            if let Some(ref parent_node) = tsn.node() {
775                // If our own leaf index is not in the list of unmerged leaves
776                // then we should have the secret for this node.
777                if !parent_node.unmerged_leaves().contains(&leaf_index) {
778                    let keypair = path_secret.derive_key_pair(crypto, ciphersuite)?;
779                    // The derived public key should match the one in the node.
780                    // If not, the tree is corrupt.
781                    if parent_node.encryption_key() != keypair.public_key() {
782                        return Err(DerivePathError::PublicKeyMismatch);
783                    } else {
784                        // If everything is ok, set the private key and derive
785                        // the next path secret.
786                        keypairs.push(keypair);
787                        path_secret = path_secret.derive_path_secret(crypto, ciphersuite)?;
788                    }
789                };
790                // If the leaf is blank or our index is in the list of unmerged
791                // leaves, go to the next node.
792            }
793        }
794        Ok((keypairs, path_secret.into()))
795    }
796
797    /// Return a reference to the parent node at the given `ParentNodeIndex` or
798    /// `None` if the node is blank.
799    pub(crate) fn parent(&self, node_index: ParentNodeIndex) -> Option<&ParentNode> {
800        let tsn = self.tree.parent(node_index);
801        tsn.node().as_ref()
802    }
803}
804
#[cfg(test)]
impl TreeSync {
    /// Test-only accessor for the leaf count reported by the underlying tree.
    pub(crate) fn leaf_count(&self) -> u32 {
        let Self { tree, .. } = self;
        tree.leaf_count()
    }
}
811
#[cfg(test)]
mod test {
    use super::*;

    /// Trimming an empty node list is expected to panic in debug builds.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_ratchet_tree_internal_empty() {
        RatchetTree::trimmed(vec![]);
    }

    /// Trimming an empty node list must not panic in release builds.
    #[cfg(not(debug_assertions))]
    #[test]
    fn test_ratchet_tree_internal_empty() {
        RatchetTree::trimmed(vec![]);
    }

    /// A node list that is empty once trimmed is expected to panic in debug
    /// builds.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_ratchet_tree_internal_empty_after_trim() {
        RatchetTree::trimmed(vec![None]);
    }

    /// A node list that is empty once trimmed must not panic in release
    /// builds.
    #[cfg(not(debug_assertions))]
    #[test]
    fn test_ratchet_tree_internal_empty_after_trim() {
        RatchetTree::trimmed(vec![None]);
    }

    #[openmls_test::openmls_test]
    fn test_ratchet_tree_trailing_blank_nodes() {
        let provider = &Provider::default();
        let (key_package, _, _) = crate::key_packages::tests::key_package(ciphersuite, provider);
        let leaf = NodeIn::from(Node::leaf_node(LeafNode::from(key_package)));

        // Each case pairs a raw node list with whether `try_from_nodes` should
        // accept it: lists without any nodes, or ending in a blank node, are
        // expected to be rejected.
        let cases = [
            (vec![], false),
            (vec![None], false),
            (vec![None, None], false),
            (vec![None, None, None], false),
            (vec![Some(leaf.clone())], true),
            (vec![Some(leaf.clone()), None], false),
            (vec![Some(leaf.clone()), None, Some(leaf.clone())], true),
            (vec![Some(leaf.clone()), None, Some(leaf), None], false),
        ];

        for (nodes, should_be_valid) in cases {
            let accepted = RatchetTree::try_from_nodes(
                ciphersuite,
                provider.crypto(),
                nodes,
                &GroupId::random(provider.rand()),
            )
            .is_ok();
            assert_eq!(accepted, should_be_valid);
        }
    }
}