// openmls/treesync/treesync_node.rs

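//! This module contains the [`TreeSyncNode`] enum, the [`TreeSyncLeafNode`]
//! and [`TreeSyncParentNode`] wrappers it is built from, and conversions
//! between these types and [`Node`].
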
3use std::collections::HashSet;
4
5use openmls_traits::crypto::OpenMlsCrypto;
6use openmls_traits::types::Ciphersuite;
7use serde::{Deserialize, Serialize};
8use tls_codec::VLByteSlice;
9
10use crate::{
11    binary_tree::array_representation::{tree::TreeNode, LeafNodeIndex},
12    error::LibraryError,
13};
14
15use super::{hashes::TreeHashInput, LeafNode, Node, ParentNode};
16
17/// A node in the MLS tree.
18pub(crate) enum TreeSyncNode {
19    Leaf(Box<TreeSyncLeafNode>),
20    Parent(Box<TreeSyncParentNode>),
21}
22
23impl From<Node> for TreeSyncNode {
24    fn from(node: Node) -> Self {
25        match node {
26            Node::LeafNode(leaf) => TreeSyncNode::Leaf(Box::new((*leaf).into())),
27            Node::ParentNode(parent) => TreeSyncNode::Parent(Box::new((*parent).into())),
28        }
29    }
30}
31
32impl From<TreeSyncNode> for Option<Node> {
33    fn from(tsn: TreeSyncNode) -> Self {
34        match tsn {
35            TreeSyncNode::Leaf(leaf) => (*leaf).into(),
36            TreeSyncNode::Parent(parent) => (*parent).into(),
37        }
38    }
39}
40
41impl From<TreeNode<TreeSyncLeafNode, TreeSyncParentNode>> for TreeSyncNode {
42    fn from(tree_node: TreeNode<TreeSyncLeafNode, TreeSyncParentNode>) -> Self {
43        match tree_node {
44            TreeNode::Leaf(leaf) => TreeSyncNode::Leaf(leaf),
45            TreeNode::Parent(parent) => TreeSyncNode::Parent(parent),
46        }
47    }
48}
49
50impl From<TreeSyncNode> for TreeNode<TreeSyncLeafNode, TreeSyncParentNode> {
51    fn from(tsn: TreeSyncNode) -> Self {
52        match tsn {
53            TreeSyncNode::Leaf(leaf) => TreeNode::Leaf(leaf),
54            TreeSyncNode::Parent(parent) => TreeNode::Parent(parent),
55        }
56    }
57}
58
59#[derive(Debug, Clone, Default, Serialize, Deserialize)]
60#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
61/// This intermediate struct on top of `Option<Node>` allows us to cache tree
62/// hash values. Blank nodes are represented by [`TreeSyncNode`] instances where
63/// `node = None`.
64pub(crate) struct TreeSyncLeafNode {
65    node: Option<LeafNode>,
66}
67
68impl TreeSyncLeafNode {
69    /// Create a blank [`TreeSyncLeafNode`].
70    pub(in crate::treesync) fn blank() -> Self {
71        Self::default()
72    }
73
74    /// Return a reference to the contained `Option<Node>`.
75    pub(in crate::treesync) fn node(&self) -> &Option<LeafNode> {
76        &self.node
77    }
78
79    /// Compute the tree hash for this node, thus populating the `tree_hash`
80    /// field.
81    pub(in crate::treesync) fn compute_tree_hash(
82        &self,
83        crypto: &impl OpenMlsCrypto,
84        ciphersuite: Ciphersuite,
85        leaf_index: LeafNodeIndex,
86    ) -> Result<Vec<u8>, LibraryError> {
87        let hash_input = TreeHashInput::new_leaf(&leaf_index, self.node.as_ref());
88        let hash = hash_input.hash(crypto, ciphersuite)?;
89
90        Ok(hash)
91    }
92}
93
94impl From<LeafNode> for TreeSyncLeafNode {
95    fn from(node: LeafNode) -> Self {
96        Self { node: Some(node) }
97    }
98}
99
100impl From<LeafNode> for Box<TreeSyncLeafNode> {
101    fn from(node: LeafNode) -> Self {
102        Box::new(TreeSyncLeafNode { node: Some(node) })
103    }
104}
105
106impl From<TreeSyncLeafNode> for Option<Node> {
107    fn from(tsln: TreeSyncLeafNode) -> Self {
108        tsln.node.map(|n| Node::LeafNode(Box::new(n)))
109    }
110}
111
112#[derive(Debug, Clone, Default, Serialize, Deserialize)]
113#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
114/// This intermediate struct on top of `Option<Node>` allows us to cache tree
115/// hash values. Blank nodes are represented by [`TreeSyncNode`] instances where
116/// `node = None`.
117pub(crate) struct TreeSyncParentNode {
118    node: Option<ParentNode>,
119}
120
121impl TreeSyncParentNode {
122    /// Create a blank [`TreeSyncParentNode`].
123    pub(in crate::treesync) fn blank() -> Self {
124        Self::default()
125    }
126
127    /// Return a reference to the contained `Option<Node>`.
128    pub(in crate::treesync) fn node(&self) -> &Option<ParentNode> {
129        &self.node
130    }
131
132    /// Return a mutable reference to the contained `Option<Node>`.
133    pub(in crate::treesync) fn node_mut(&mut self) -> &mut Option<ParentNode> {
134        &mut self.node
135    }
136
137    /// Compute the tree hash for this node. Leaf nodes from the exclusion list
138    /// are filtered out.
139    pub(in crate::treesync) fn compute_tree_hash(
140        &self,
141        crypto: &impl OpenMlsCrypto,
142        ciphersuite: Ciphersuite,
143        left_hash: Vec<u8>,
144        right_hash: Vec<u8>,
145        exclusion_list: &HashSet<&LeafNodeIndex>,
146    ) -> Result<Vec<u8>, LibraryError> {
147        let hash = if exclusion_list.is_empty() {
148            // If the exclusion list is empty, we can just use the parent node
149            TreeHashInput::new_parent(
150                self.node.as_ref(),
151                VLByteSlice(&left_hash),
152                VLByteSlice(&right_hash),
153            )
154            .hash(crypto, ciphersuite)?
155        } else if let Some(parent_node) = self.node.as_ref() {
156            // If the exclusion list is not empty, we need to create a new
157            // parent node without the excluded indices in the unmerged leaves.
158            let mut new_node = parent_node.clone();
159            let unmerged_leaves = new_node
160                .unmerged_leaves()
161                .iter()
162                .filter(|leaf| !exclusion_list.contains(leaf))
163                .cloned()
164                .collect();
165            new_node.set_unmerged_leaves(unmerged_leaves);
166            TreeHashInput::new_parent(
167                Some(&new_node),
168                VLByteSlice(&left_hash),
169                VLByteSlice(&right_hash),
170            )
171            .hash(crypto, ciphersuite)?
172        } else {
173            // If the node is blank
174            TreeHashInput::new_parent(None, VLByteSlice(&left_hash), VLByteSlice(&right_hash))
175                .hash(crypto, ciphersuite)?
176        };
177
178        Ok(hash)
179    }
180}
181
182impl From<ParentNode> for TreeSyncParentNode {
183    fn from(node: ParentNode) -> Self {
184        Self { node: Some(node) }
185    }
186}
187
188impl From<ParentNode> for Box<TreeSyncParentNode> {
189    fn from(node: ParentNode) -> Self {
190        Box::new(TreeSyncParentNode { node: Some(node) })
191    }
192}
193
194impl From<TreeSyncParentNode> for Option<Node> {
195    fn from(tspn: TreeSyncParentNode) -> Self {
196        tspn.node.map(|n| Node::ParentNode(Box::new(n)))
197    }
198}