openmls/tree/
secret_tree.rs

1use openmls_traits::crypto::OpenMlsCrypto;
2use openmls_traits::types::{Ciphersuite, CryptoError};
3use thiserror::Error;
4use tls_codec::{Error as TlsCodecError, TlsSerialize, TlsSize};
5
6use super::*;
7use crate::{
8    binary_tree::{
9        array_representation::{
10            direct_path, left, right, root, ParentNodeIndex, TreeNodeIndex, TreeSize,
11        },
12        LeafNodeIndex,
13    },
14    framing::*,
15    schedule::*,
16    tree::sender_ratchet::*,
17};
18
19/// Secret tree error
20#[derive(Error, Debug, Eq, PartialEq, Clone)]
21pub enum SecretTreeError {
22    /// Generation is too old to be processed.
23    #[error("Generation is too old to be processed.")]
24    TooDistantInThePast,
25    /// Generation is too far in the future to be processed.
26    #[error("Generation is too far in the future to be processed.")]
27    TooDistantInTheFuture,
28    /// Index out of bounds
29    #[error("Index out of bounds")]
30    IndexOutOfBounds,
31    /// The requested secret was deleted to preserve forward secrecy.
32    #[error("The requested secret was deleted to preserve forward secrecy.")]
33    SecretReuseError,
34    /// Cannot create decryption secrets from own sender ratchet or encryption secrets from the sender ratchets of other members.
35    #[error("Cannot create decryption secrets from own sender ratchet or encryption secrets from the sender ratchets of other members.")]
36    RatchetTypeError,
37    /// Ratchet generation has reached `u32::MAX`.
38    #[error("Ratchet generation has reached `u32::MAX`.")]
39    RatchetTooLong,
40    /// An unrecoverable error has occurred due to a bug in the implementation.
41    #[error("An unrecoverable error has occurred due to a bug in the implementation.")]
42    LibraryError,
43    /// See [`TlsCodecError`] for more details.
44    #[error(transparent)]
45    CodecError(#[from] TlsCodecError),
46    /// See [`CryptoError`] for more details.
47    #[error(transparent)]
48    CryptoError(#[from] CryptoError),
49}
50
/// Selects which of a member's two sender ratchets a secret is taken from.
#[derive(Debug, Copy, Clone)]
pub(crate) enum SecretType {
    /// The handshake ratchet, used for commits and proposals.
    HandshakeSecret,
    /// The application ratchet, used for application messages.
    ApplicationSecret,
}
56
57impl From<&ContentType> for SecretType {
58    fn from(content_type: &ContentType) -> SecretType {
59        match content_type {
60            ContentType::Application => SecretType::ApplicationSecret,
61            ContentType::Commit => SecretType::HandshakeSecret,
62            ContentType::Proposal => SecretType::HandshakeSecret,
63        }
64    }
65}
66
67impl From<&PublicMessage> for SecretType {
68    fn from(public_message: &PublicMessage) -> SecretType {
69        SecretType::from(&public_message.content_type())
70    }
71}
72
73pub(crate) fn derive_child_secrets(
74    parent_secret: &Secret,
75    crypto: &impl OpenMlsCrypto,
76    ciphersuite: Ciphersuite,
77) -> Result<(Secret, Secret), CryptoError> {
78    let left_child = parent_secret.kdf_expand_label(
79        crypto,
80        ciphersuite,
81        "tree",
82        b"left",
83        ciphersuite.hash_length(),
84    )?;
85    let right_child = parent_secret.kdf_expand_label(
86        crypto,
87        ciphersuite,
88        "tree",
89        b"right",
90        ciphersuite.hash_length(),
91    )?;
92    Ok((left_child, right_child))
93}
94
95/// Derives secrets for inner nodes of a SecretTree. This function corresponds
96/// to the `DeriveTreeSecret` defined in Section 10.1 of the MLS specification.
97#[inline]
98pub(crate) fn derive_tree_secret(
99    ciphersuite: Ciphersuite,
100    secret: &Secret,
101    label: &str,
102    generation: u32,
103    length: usize,
104    crypto: &impl OpenMlsCrypto,
105) -> Result<Secret, SecretTreeError> {
106    log::debug!(
107        "Derive tree secret with label \"{label}\" in generation {generation} of length {length}"
108    );
109    log_crypto!(trace, "Input secret {:x?}", secret.as_slice());
110
111    let secret = secret.kdf_expand_label(
112        crypto,
113        ciphersuite,
114        label,
115        &generation.to_be_bytes(),
116        length,
117    )?;
118    log_crypto!(trace, "Derived secret {:x?}", secret.as_slice());
119    Ok(secret)
120}
121
122#[derive(Debug, TlsSerialize, TlsSize)]
123pub(crate) struct TreeContext {
124    pub(crate) node: u32,
125    pub(crate) generation: u32,
126}
127
128#[derive(Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
129#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
130pub(crate) struct SecretTreeNode {
131    pub(crate) secret: Secret,
132}
133
134#[derive(Serialize, Deserialize)]
135#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
136#[cfg_attr(any(feature = "crypto-debug", test), derive(Debug))]
137pub(crate) struct SecretTree {
138    own_index: LeafNodeIndex,
139    leaf_nodes: Vec<Option<SecretTreeNode>>,
140    parent_nodes: Vec<Option<SecretTreeNode>>,
141    handshake_sender_ratchets: Vec<Option<SenderRatchet>>,
142    application_sender_ratchets: Vec<Option<SenderRatchet>>,
143    size: TreeSize,
144}
145
146impl SecretTree {
147    /// Creates a new SecretTree based on an `encryption_secret` and group size
148    /// `size`. The inner nodes of the tree and the SenderRatchets only get
149    /// initialized when secrets are requested either through `secret()`
150    /// or `next_secret()`.
151    pub(crate) fn new(
152        encryption_secret: EncryptionSecret,
153        size: TreeSize,
154        own_index: LeafNodeIndex,
155    ) -> Self {
156        let leaf_count = size.leaf_count() as usize;
157        let leaf_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
158        let parent_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
159        let handshake_sender_ratchets = std::iter::repeat_with(|| None).take(leaf_count).collect();
160        let application_sender_ratchets =
161            std::iter::repeat_with(|| None).take(leaf_count).collect();
162
163        let mut secret_tree = SecretTree {
164            own_index,
165            leaf_nodes,
166            parent_nodes,
167            handshake_sender_ratchets,
168            application_sender_ratchets,
169            size,
170        };
171
172        // Set the encryption secret in the root node. We ignore the Result
173        // here, since the we rely on the tree math to be correct, i.e.
174        // root(size) < size.
175        let _ = secret_tree.set_node(
176            root(size),
177            Some(SecretTreeNode {
178                secret: encryption_secret.consume_secret(),
179            }),
180        );
181
182        secret_tree
183    }
184
185    /// Get current generation for a specific SenderRatchet
186    #[cfg(test)]
187    pub(crate) fn generation(&self, index: LeafNodeIndex, secret_type: SecretType) -> u32 {
188        match self
189            .ratchet_opt(index, secret_type)
190            .expect("Index out of bounds.")
191        {
192            Some(sender_ratchet) => sender_ratchet.generation(),
193            None => 0,
194        }
195    }
196
197    /// Initializes a specific SenderRatchet pair for a given index by
198    /// calculating and deleting the appropriate values in the SecretTree
199    fn initialize_sender_ratchets(
200        &mut self,
201        ciphersuite: Ciphersuite,
202        crypto: &impl OpenMlsCrypto,
203        index: LeafNodeIndex,
204    ) -> Result<(), SecretTreeError> {
205        log::trace!("Initializing sender ratchets for {index:?} with {ciphersuite}");
206        if index.u32() >= self.size.leaf_count() {
207            log::error!("Index is larger than the tree size.");
208            return Err(SecretTreeError::IndexOutOfBounds);
209        }
210        // Check if SenderRatchets are already initialized
211        if self
212            .ratchet_opt(index, SecretType::HandshakeSecret)?
213            .is_some()
214            && self
215                .ratchet_opt(index, SecretType::ApplicationSecret)?
216                .is_some()
217        {
218            log::trace!("The sender ratchets are initialized already.");
219            return Ok(());
220        }
221
222        // If we don't have a secret in the leaf node, we derive it
223        if self.get_node(index.into())?.is_none() {
224            // Collect empty nodes in the direct path until a non-empty node is
225            // found
226            let mut empty_nodes: Vec<ParentNodeIndex> = Vec::new();
227            let direct_path = direct_path(index, self.size);
228            log::trace!("Direct path for node {index:?}: {direct_path:?}");
229            for parent_node in direct_path {
230                empty_nodes.push(parent_node);
231                // Stop if we find a non-empty node
232                if self.get_node(parent_node.into())?.is_some() {
233                    break;
234                }
235            }
236
237            // Invert direct path
238            empty_nodes.reverse();
239
240            // Derive the secrets down all the way to the leaf node
241            for n in empty_nodes {
242                log::trace!("Derive down for parent node {n:?}.");
243                self.derive_down(ciphersuite, crypto, n)?;
244            }
245        }
246
247        // Calculate node secret and initialize SenderRatchets
248        let node_secret = match self.get_node(index.into())? {
249            Some(node) => &node.secret,
250            // We just derived all necessary nodes so this should not happen
251            None => {
252                return Err(SecretTreeError::LibraryError);
253            }
254        };
255
256        log::trace!("Deriving leaf node secrets for leaf {index:?}");
257
258        let handshake_ratchet_secret = node_secret.kdf_expand_label(
259            crypto,
260            ciphersuite,
261            "handshake",
262            b"",
263            ciphersuite.hash_length(),
264        )?;
265        let application_ratchet_secret = node_secret.kdf_expand_label(
266            crypto,
267            ciphersuite,
268            "application",
269            b"",
270            ciphersuite.hash_length(),
271        )?;
272
273        log_crypto!(
274            trace,
275            "handshake ratchet secret {handshake_ratchet_secret:x?}"
276        );
277        log_crypto!(
278            trace,
279            "application ratchet secret {application_ratchet_secret:x?}"
280        );
281
282        // Initialize SenderRatchets, we differentiate between the own SenderRatchets
283        // and the SenderRatchets of other members
284        let (handshake_sender_ratchet, application_sender_ratchet) = if index == self.own_index {
285            let handshake_sender_ratchet = SenderRatchet::EncryptionRatchet(
286                RatchetSecret::initial_ratchet_secret(handshake_ratchet_secret),
287            );
288            let application_sender_ratchet = SenderRatchet::EncryptionRatchet(
289                RatchetSecret::initial_ratchet_secret(application_ratchet_secret),
290            );
291
292            (handshake_sender_ratchet, application_sender_ratchet)
293        } else {
294            let handshake_sender_ratchet =
295                SenderRatchet::DecryptionRatchet(DecryptionRatchet::new(handshake_ratchet_secret));
296            let application_sender_ratchet = SenderRatchet::DecryptionRatchet(
297                DecryptionRatchet::new(application_ratchet_secret),
298            );
299
300            (handshake_sender_ratchet, application_sender_ratchet)
301        };
302
303        *self
304            .handshake_sender_ratchets
305            .get_mut(index.usize())
306            .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(handshake_sender_ratchet);
307        *self
308            .application_sender_ratchets
309            .get_mut(index.usize())
310            .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(application_sender_ratchet);
311
312        // Delete leaf node
313        self.set_node(index.into(), None)
314    }
315
316    /// Return RatchetSecrets for a given index and generation. This should be
317    /// called when decrypting an PrivateMessage received from another member.
318    /// Returns an error if index or generation are out of bound.
319    pub(crate) fn secret_for_decryption(
320        &mut self,
321        ciphersuite: Ciphersuite,
322        crypto: &impl OpenMlsCrypto,
323        index: LeafNodeIndex,
324        secret_type: SecretType,
325        generation: u32,
326        configuration: &SenderRatchetConfiguration,
327    ) -> Result<RatchetKeyMaterial, SecretTreeError> {
328        log::debug!(
329            "Generating {secret_type:?} decryption secret for {index:?} in generation {generation} with {ciphersuite}",
330        );
331        // Check tree bounds
332        if index.u32() >= self.size.leaf_count() {
333            log::error!("Sender index is not in the tree.");
334            return Err(SecretTreeError::IndexOutOfBounds);
335        }
336        if self.ratchet_opt(index, secret_type)?.is_none() {
337            log::trace!("   initialize sender ratchets");
338            self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
339        }
340        match self.ratchet_mut(index, secret_type)? {
341            SenderRatchet::EncryptionRatchet(_) => {
342                log::error!("This is the wrong ratchet type.");
343                Err(SecretTreeError::RatchetTypeError)
344            }
345            SenderRatchet::DecryptionRatchet(dec_ratchet) => {
346                log::trace!("   getting secret for decryption");
347                dec_ratchet.secret_for_decryption(ciphersuite, crypto, generation, configuration)
348            }
349        }
350    }
351
352    /// Return the next RatchetSecrets that should be used for encryption and
353    /// then increments the generation.
354    pub(crate) fn secret_for_encryption(
355        &mut self,
356        ciphersuite: Ciphersuite,
357        crypto: &impl OpenMlsCrypto,
358        index: LeafNodeIndex,
359        secret_type: SecretType,
360    ) -> Result<(u32, RatchetKeyMaterial), SecretTreeError> {
361        if self.ratchet_opt(index, secret_type)?.is_none() {
362            self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
363        }
364        match self.ratchet_mut(index, secret_type)? {
365            SenderRatchet::DecryptionRatchet(_) => {
366                log::error!("Invalid ratchet type. Got decryption, expected encryption.");
367                Err(SecretTreeError::RatchetTypeError)
368            }
369            SenderRatchet::EncryptionRatchet(enc_ratchet) => {
370                enc_ratchet.ratchet_forward(crypto, ciphersuite)
371            }
372        }
373    }
374
375    /// Returns a mutable reference to a specific SenderRatchet. The
376    /// SenderRatchet needs to be initialized.
377    fn ratchet_mut(
378        &mut self,
379        index: LeafNodeIndex,
380        secret_type: SecretType,
381    ) -> Result<&mut SenderRatchet, SecretTreeError> {
382        let sender_ratchets = match secret_type {
383            SecretType::HandshakeSecret => &mut self.handshake_sender_ratchets,
384            SecretType::ApplicationSecret => &mut self.application_sender_ratchets,
385        };
386        sender_ratchets
387            .get_mut(index.usize())
388            .and_then(|r| r.as_mut())
389            .ok_or(SecretTreeError::IndexOutOfBounds)
390    }
391
392    /// Returns an optional reference to a specific SenderRatchet
393    fn ratchet_opt(
394        &self,
395        index: LeafNodeIndex,
396        secret_type: SecretType,
397    ) -> Result<Option<&SenderRatchet>, SecretTreeError> {
398        let sender_ratchets = match secret_type {
399            SecretType::HandshakeSecret => &self.handshake_sender_ratchets,
400            SecretType::ApplicationSecret => &self.application_sender_ratchets,
401        };
402        match sender_ratchets.get(index.usize()) {
403            Some(sender_ratchet_option) => Ok(sender_ratchet_option.as_ref()),
404            None => Err(SecretTreeError::IndexOutOfBounds),
405        }
406    }
407
408    /// Derives the secrets for the child nodes in a SecretTree and blanks the
409    /// parent node.
410    fn derive_down(
411        &mut self,
412        ciphersuite: Ciphersuite,
413        crypto: &impl OpenMlsCrypto,
414        index_in_tree: ParentNodeIndex,
415    ) -> Result<(), SecretTreeError> {
416        log::debug!(
417            "Deriving tree secret for parent node {} with {}",
418            index_in_tree.u32(),
419            ciphersuite
420        );
421        let node_secret = match &self.get_node(index_in_tree.into())? {
422            Some(node) => &node.secret,
423            // This function only gets called top to bottom, so this should not happen
424            None => {
425                return Err(SecretTreeError::LibraryError);
426            }
427        };
428        log_crypto!(trace, "Node secret: {:x?}", node_secret.as_slice());
429        let left_index = left(index_in_tree);
430        let right_index = right(index_in_tree);
431        let (left_secret, right_secret) = derive_child_secrets(node_secret, crypto, ciphersuite)?;
432        log_crypto!(
433            trace,
434            "Left node ({}) secret: {:x?}",
435            left_index.test_u32(),
436            left_secret.as_slice()
437        );
438        log_crypto!(
439            trace,
440            "Right node ({}) secret: {:x?}",
441            right_index.test_u32(),
442            right_secret.as_slice()
443        );
444
445        // Populate left child
446        self.set_node(
447            left_index,
448            Some(SecretTreeNode {
449                secret: left_secret,
450            }),
451        )?;
452
453        // Populate right child
454        self.set_node(
455            right_index,
456            Some(SecretTreeNode {
457                secret: right_secret,
458            }),
459        )?;
460
461        // Delete parent node
462        self.set_node(index_in_tree.into(), None)
463    }
464
465    fn get_node(&self, index: TreeNodeIndex) -> Result<Option<&SecretTreeNode>, SecretTreeError> {
466        match index {
467            TreeNodeIndex::Leaf(leaf_index) => Ok(self
468                .leaf_nodes
469                .get(leaf_index.usize())
470                .ok_or(SecretTreeError::IndexOutOfBounds)?
471                .as_ref()),
472            TreeNodeIndex::Parent(parent_index) => Ok(self
473                .parent_nodes
474                .get(parent_index.usize())
475                .ok_or(SecretTreeError::IndexOutOfBounds)?
476                .as_ref()),
477        }
478    }
479
480    fn set_node(
481        &mut self,
482        index: TreeNodeIndex,
483        node: Option<SecretTreeNode>,
484    ) -> Result<(), SecretTreeError> {
485        match index {
486            TreeNodeIndex::Leaf(leaf_index) => {
487                *self
488                    .leaf_nodes
489                    .get_mut(leaf_index.usize())
490                    .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
491            }
492            TreeNodeIndex::Parent(parent_index) => {
493                *self
494                    .parent_nodes
495                    .get_mut(parent_index.usize())
496                    .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
497            }
498        }
499        Ok(())
500    }
501}