// openmls/treesync/node/parent_node.rs

//! This module contains the [`ParentNode`] struct and its implementation, as
//! well as [`PlainUpdatePathNode`], a helper struct used to create
//! [`UpdatePathNode`] instances.
4use openmls_traits::crypto::OpenMlsCrypto;
5use openmls_traits::types::{Ciphersuite, HpkeCiphertext};
6#[cfg(not(target_arch = "wasm32"))]
7use rayon::prelude::*;
8use serde::{Deserialize, Serialize};
9use thiserror::*;
10use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes};
11
12use super::encryption_keys::{EncryptionKey, EncryptionKeyPair};
13use crate::{
14    binary_tree::array_representation::{LeafNodeIndex, ParentNodeIndex},
15    ciphersuite::HpkePublicKey,
16    error::LibraryError,
17    messages::PathSecret,
18    schedule::CommitSecret,
19    treesync::{hashes::ParentHashInput, treekem::UpdatePathNode},
20};
21
22/// This struct implements the MLS parent node. It contains its public key,
23/// parent hash and unmerged leaves. Additionally, it may contain the private
24/// key corresponding to the public key.
25#[derive(
26    Debug,
27    Eq,
28    PartialEq,
29    Clone,
30    Serialize,
31    Deserialize,
32    TlsSerialize,
33    TlsDeserialize,
34    TlsDeserializeBytes,
35    TlsSize,
36)]
37pub struct ParentNode {
38    pub(super) encryption_key: EncryptionKey,
39    pub(super) parent_hash: VLBytes,
40    pub(super) unmerged_leaves: UnmergedLeaves,
41}
42
43impl From<EncryptionKey> for ParentNode {
44    fn from(public_key: EncryptionKey) -> Self {
45        Self {
46            encryption_key: public_key,
47            parent_hash: vec![].into(),
48            unmerged_leaves: UnmergedLeaves::new(),
49        }
50    }
51}
52
53/// Helper struct for the encryption of a [`ParentNode`].
54#[cfg_attr(test, derive(Clone))]
55#[derive(Debug)]
56pub(crate) struct PlainUpdatePathNode {
57    public_key: EncryptionKey,
58    path_secret: PathSecret,
59}
60
61impl PlainUpdatePathNode {
62    /// Encrypt this node and return the resulting [`UpdatePathNode`].
63    pub(in crate::treesync) fn encrypt(
64        &self,
65        crypto: &impl OpenMlsCrypto,
66        ciphersuite: Ciphersuite,
67        public_keys: &[EncryptionKey],
68        group_context: &[u8],
69    ) -> Result<UpdatePathNode, LibraryError> {
70        #[cfg(target_arch = "wasm32")]
71        let public_keys = public_keys.iter();
72        #[cfg(not(target_arch = "wasm32"))]
73        let public_keys = public_keys.par_iter();
74
75        public_keys
76            .map(|pk| {
77                self.path_secret
78                    .encrypt(crypto, ciphersuite, pk, group_context)
79            })
80            .collect::<Result<Vec<HpkeCiphertext>, LibraryError>>()
81            .map(|encrypted_path_secrets| UpdatePathNode {
82                public_key: self.public_key.clone(),
83                encrypted_path_secrets,
84            })
85    }
86
87    /// Return a reference to the `path_secret` of this node.
88    pub(in crate::treesync) fn path_secret(&self) -> &PathSecret {
89        &self.path_secret
90    }
91
92    #[cfg(test)]
93    pub(crate) fn new(public_key: EncryptionKey, path_secret: PathSecret) -> Self {
94        Self {
95            public_key,
96            path_secret,
97        }
98    }
99}
100
101/// The result of a path derivation result containing the vector of
102/// [`ParentNode`], as well as [`PlainUpdatePathNode`] instance and a
103/// [`CommitSecret`].
104pub(in crate::treesync) type PathDerivationResult = (
105    Vec<(ParentNodeIndex, ParentNode)>,
106    Vec<PlainUpdatePathNode>,
107    Vec<EncryptionKeyPair>,
108    CommitSecret,
109);
110
111impl ParentNode {
112    /// Derives a path from the given path secret, where the `node_secret` of
113    /// the first node is immediately derived from the given `path_secret`.
114    ///
115    /// Returns the resulting vector of [`ParentNode`] instances, as well as the
116    /// intermediary [`PathSecret`]s, and the [`CommitSecret`].
117    pub(crate) fn derive_path(
118        crypto: &impl OpenMlsCrypto,
119        ciphersuite: Ciphersuite,
120        path_secret: PathSecret,
121        path_indices: Vec<ParentNodeIndex>,
122    ) -> Result<PathDerivationResult, LibraryError> {
123        let mut next_path_secret = path_secret;
124        let mut path_secrets = Vec::with_capacity(path_indices.len());
125
126        for _ in 0..path_indices.len() {
127            let path_secret = next_path_secret;
128            // Derive the next path secret.
129            next_path_secret = path_secret.derive_path_secret(crypto, ciphersuite)?;
130            path_secrets.push(path_secret);
131        }
132
133        type PathDerivationResults = (
134            Vec<((ParentNodeIndex, ParentNode), EncryptionKeyPair)>,
135            Vec<PlainUpdatePathNode>,
136        );
137
138        // Iterate over the path secrets and derive a key pair
139
140        #[cfg(not(target_arch = "wasm32"))]
141        let path_secrets = path_secrets.into_par_iter();
142        #[cfg(target_arch = "wasm32")]
143        let path_secrets = path_secrets.into_iter();
144
145        let (path_with_keypairs, update_path_nodes): PathDerivationResults = path_secrets
146            .zip(path_indices)
147            .map(|(path_secret, index)| {
148                // Derive a key pair from the path secret. This includes the
149                // intermediate derivation of a node secret.
150                let keypair = path_secret.derive_key_pair(crypto, ciphersuite)?;
151                let parent_node = ParentNode::from(keypair.public_key().clone());
152                // Store the current path secret and the derived public key for
153                // later encryption.
154                let update_path_node = PlainUpdatePathNode {
155                    public_key: keypair.public_key().clone(),
156                    path_secret,
157                };
158                Ok((((index, parent_node), keypair), update_path_node))
159            })
160            .collect::<Result<
161                Vec<(
162                    ((ParentNodeIndex, ParentNode), EncryptionKeyPair),
163                    PlainUpdatePathNode,
164                )>,
165                LibraryError,
166            >>()?
167            .into_iter()
168            .unzip();
169
170        let (path, keypairs) = path_with_keypairs.into_iter().unzip();
171
172        let commit_secret = next_path_secret.into();
173        Ok((path, update_path_nodes, keypairs, commit_secret))
174    }
175
176    /// Return a reference to the `public_key` of this node.
177    pub(crate) fn public_key(&self) -> &HpkePublicKey {
178        self.encryption_key.key()
179    }
180
181    /// Return a reference to the `public_key` of this node.
182    pub(crate) fn encryption_key(&self) -> &EncryptionKey {
183        &self.encryption_key
184    }
185
186    /// Get the list of unmerged leaves.
187    pub(crate) fn unmerged_leaves(&self) -> &[LeafNodeIndex] {
188        self.unmerged_leaves.list()
189    }
190
191    /// Set the list of unmerged leaves.
192    pub(in crate::treesync) fn set_unmerged_leaves(&mut self, unmerged_leaves: Vec<LeafNodeIndex>) {
193        self.unmerged_leaves.set_list(unmerged_leaves);
194    }
195
196    /// Add a [`LeafNodeIndex`] to the node's list of unmerged leaves.
197    pub(in crate::treesync) fn add_unmerged_leaf(&mut self, leaf_index: LeafNodeIndex) {
198        self.unmerged_leaves.add(leaf_index);
199    }
200
201    /// Compute the parent hash value of this node.
202    pub(in crate::treesync) fn compute_parent_hash(
203        &self,
204        crypto: &impl OpenMlsCrypto,
205        ciphersuite: Ciphersuite,
206        original_child_resolution: &[u8],
207    ) -> Result<Vec<u8>, LibraryError> {
208        let parent_hash_input = ParentHashInput::new(
209            self.encryption_key.key(),
210            self.parent_hash(),
211            original_child_resolution,
212        );
213        parent_hash_input.hash(crypto, ciphersuite)
214    }
215
216    /// Set the `parent_hash` of this node.
217    pub(in crate::treesync) fn set_parent_hash(&mut self, parent_hash: Vec<u8>) {
218        self.parent_hash = parent_hash.into()
219    }
220
221    /// Get the parent hash value of this node.
222    pub(crate) fn parent_hash(&self) -> &[u8] {
223        self.parent_hash.as_slice()
224    }
225}
226
227/// A helper struct that maintains a sorted list of unmerged leaves.
228#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize, TlsSize, TlsSerialize)]
229pub(in crate::treesync) struct UnmergedLeaves {
230    list: Vec<LeafNodeIndex>,
231}
232
233impl UnmergedLeaves {
234    pub(in crate::treesync) fn new() -> Self {
235        Self { list: Vec::new() }
236    }
237
238    pub(in crate::treesync) fn add(&mut self, leaf_index: LeafNodeIndex) {
239        // The list of unmerged leaves must be sorted. This is enforced upon
240        // deserialization. We can therefore safely insert the new leaf at the
241        // correct position.
242        let position = self.list.binary_search(&leaf_index).unwrap_or_else(|e| e);
243        self.list.insert(position, leaf_index);
244    }
245
246    pub(in crate::treesync) fn list(&self) -> &[LeafNodeIndex] {
247        self.list.as_slice()
248    }
249
250    /// Set the list of unmerged leaves.
251    pub(in crate::treesync) fn set_list(&mut self, list: Vec<LeafNodeIndex>) {
252        self.list = list;
253    }
254}
255
256#[derive(Error, Debug)]
257pub(in crate::treesync) enum UnmergedLeavesError {
258    /// The list of leaves is not sorted.
259    #[error("The list of leaves is not sorted.")]
260    NotSorted,
261}
262
263impl TryFrom<Vec<LeafNodeIndex>> for UnmergedLeaves {
264    type Error = UnmergedLeavesError;
265
266    fn try_from(list: Vec<LeafNodeIndex>) -> Result<Self, Self::Error> {
267        // The list of unmerged leaves must be sorted.
268        if !list.windows(2).all(|e| e[0] < e[1]) {
269            return Err(UnmergedLeavesError::NotSorted);
270        }
271        Ok(Self { list })
272    }
273}