Skip to main content

openmls/tree/
secret_tree.rs

1use openmls_traits::crypto::OpenMlsCrypto;
2use openmls_traits::types::{Ciphersuite, CryptoError};
3use thiserror::Error;
4use tls_codec::{Error as TlsCodecError, TlsSerialize, TlsSize};
5
6use super::*;
7#[cfg(feature = "virtual-clients-draft")]
8use crate::tree::dual_use_ratchet::DualUseRatchet;
9use crate::{
10    binary_tree::{
11        array_representation::{
12            direct_path, left, right, root, ParentNodeIndex, TreeNodeIndex, TreeSize,
13        },
14        LeafNodeIndex,
15    },
16    framing::*,
17    schedule::*,
18    tree::sender_ratchet::*,
19};
20
21/// Secret tree error
22#[derive(Error, Debug, Eq, PartialEq, Clone)]
23pub enum SecretTreeError {
24    /// Generation is too old to be processed.
25    #[error("Generation is too old to be processed.")]
26    TooDistantInThePast,
27    /// Generation is too far in the future to be processed.
28    #[error("Generation is too far in the future to be processed.")]
29    TooDistantInTheFuture,
30    /// Index out of bounds
31    #[error("Index out of bounds")]
32    IndexOutOfBounds,
33    /// The requested secret was deleted to preserve forward secrecy.
34    #[error("The requested secret was deleted to preserve forward secrecy.")]
35    SecretReuseError,
36    /// Cannot create decryption secrets from own sender ratchet or encryption secrets from the sender ratchets of other members.
37    #[error("Cannot create decryption secrets from own sender ratchet or encryption secrets from the sender ratchets of other members.")]
38    RatchetTypeError,
39    /// Ratchet generation has reached `u32::MAX`.
40    #[error("Ratchet generation has reached `u32::MAX`.")]
41    RatchetTooLong,
42    /// An unrecoverable error has occurred due to a bug in the implementation.
43    #[error("An unrecoverable error has occurred due to a bug in the implementation.")]
44    LibraryError,
45    /// See [`TlsCodecError`] for more details.
46    #[error(transparent)]
47    CodecError(#[from] TlsCodecError),
48    /// See [`CryptoError`] for more details.
49    #[error(transparent)]
50    CryptoError(#[from] CryptoError),
51}
52
/// Selects which of a leaf's two sender ratchets a secret is taken from.
#[derive(Clone, Copy, Debug)]
pub(crate) enum SecretType {
    /// Ratchet used for handshake content (commits and proposals).
    HandshakeSecret,
    /// Ratchet used for application content.
    ApplicationSecret,
}
58
59impl From<&ContentType> for SecretType {
60    fn from(content_type: &ContentType) -> SecretType {
61        match content_type {
62            ContentType::Application => SecretType::ApplicationSecret,
63            ContentType::Commit => SecretType::HandshakeSecret,
64            ContentType::Proposal => SecretType::HandshakeSecret,
65        }
66    }
67}
68
69impl From<&PublicMessage> for SecretType {
70    fn from(public_message: &PublicMessage) -> SecretType {
71        SecretType::from(&public_message.content_type())
72    }
73}
74
75pub(crate) fn derive_child_secrets(
76    parent_secret: &Secret,
77    crypto: &impl OpenMlsCrypto,
78    ciphersuite: Ciphersuite,
79) -> Result<(Secret, Secret), CryptoError> {
80    let left_child = parent_secret.kdf_expand_label(
81        crypto,
82        ciphersuite,
83        "tree",
84        b"left",
85        ciphersuite.hash_length(),
86    )?;
87    let right_child = parent_secret.kdf_expand_label(
88        crypto,
89        ciphersuite,
90        "tree",
91        b"right",
92        ciphersuite.hash_length(),
93    )?;
94    Ok((left_child, right_child))
95}
96
97/// Derives secrets for inner nodes of a SecretTree. This function corresponds
98/// to the `DeriveTreeSecret` defined in Section 10.1 of the MLS specification.
99#[inline]
100pub(crate) fn derive_tree_secret(
101    ciphersuite: Ciphersuite,
102    secret: &Secret,
103    label: &str,
104    generation: u32,
105    length: usize,
106    crypto: &impl OpenMlsCrypto,
107) -> Result<Secret, SecretTreeError> {
108    log::debug!(
109        "Derive tree secret with label \"{label}\" in generation {generation} of length {length}"
110    );
111    log_crypto!(trace, "Input secret {:x?}", secret.as_slice());
112
113    let secret = secret.kdf_expand_label(
114        crypto,
115        ciphersuite,
116        label,
117        &generation.to_be_bytes(),
118        length,
119    )?;
120    log_crypto!(trace, "Derived secret {:x?}", secret.as_slice());
121    Ok(secret)
122}
123
124#[derive(Debug, TlsSerialize, TlsSize)]
125pub(crate) struct TreeContext {
126    pub(crate) node: u32,
127    pub(crate) generation: u32,
128}
129
130#[derive(Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
131#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
132pub(crate) struct SecretTreeNode {
133    pub(crate) secret: Secret,
134}
135
136#[derive(Serialize, Deserialize)]
137#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
138#[cfg_attr(any(feature = "crypto-debug", test), derive(Debug))]
139pub(crate) struct SecretTree {
140    own_index: LeafNodeIndex,
141    leaf_nodes: Vec<Option<SecretTreeNode>>,
142    parent_nodes: Vec<Option<SecretTreeNode>>,
143    handshake_sender_ratchets: Vec<Option<SenderRatchet>>,
144    application_sender_ratchets: Vec<Option<SenderRatchet>>,
145    size: TreeSize,
146}
147
148impl SecretTree {
149    /// Creates a new SecretTree based on an `encryption_secret` and group size
150    /// `size`. The inner nodes of the tree and the SenderRatchets only get
151    /// initialized when secrets are requested either through `secret()`
152    /// or `next_secret()`.
153    pub(crate) fn new(
154        encryption_secret: EncryptionSecret,
155        size: TreeSize,
156        own_index: LeafNodeIndex,
157    ) -> Self {
158        let leaf_count = size.leaf_count() as usize;
159        let leaf_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
160        let parent_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
161        let handshake_sender_ratchets = std::iter::repeat_with(|| None).take(leaf_count).collect();
162        let application_sender_ratchets =
163            std::iter::repeat_with(|| None).take(leaf_count).collect();
164
165        let mut secret_tree = SecretTree {
166            own_index,
167            leaf_nodes,
168            parent_nodes,
169            handshake_sender_ratchets,
170            application_sender_ratchets,
171            size,
172        };
173
174        // Set the encryption secret in the root node. We ignore the Result
175        // here, since the we rely on the tree math to be correct, i.e.
176        // root(size) < size.
177        let _ = secret_tree.set_node(
178            root(size),
179            Some(SecretTreeNode {
180                secret: encryption_secret.consume_secret(),
181            }),
182        );
183
184        secret_tree
185    }
186
187    /// Get current generation for a specific SenderRatchet
188    #[cfg(test)]
189    pub(crate) fn generation(&self, index: LeafNodeIndex, secret_type: SecretType) -> u32 {
190        match self
191            .ratchet_opt(index, secret_type)
192            .expect("Index out of bounds.")
193        {
194            Some(sender_ratchet) => sender_ratchet.generation(),
195            None => 0,
196        }
197    }
198
199    /// Initializes a specific SenderRatchet pair for a given index by
200    /// calculating and deleting the appropriate values in the SecretTree
201    fn initialize_sender_ratchets(
202        &mut self,
203        ciphersuite: Ciphersuite,
204        crypto: &impl OpenMlsCrypto,
205        index: LeafNodeIndex,
206    ) -> Result<(), SecretTreeError> {
207        log::trace!("Initializing sender ratchets for {index:?} with {ciphersuite}");
208        if index.u32() >= self.size.leaf_count() {
209            log::error!("Index is larger than the tree size.");
210            return Err(SecretTreeError::IndexOutOfBounds);
211        }
212        // Check if SenderRatchets are already initialized
213        if self
214            .ratchet_opt(index, SecretType::HandshakeSecret)?
215            .is_some()
216            && self
217                .ratchet_opt(index, SecretType::ApplicationSecret)?
218                .is_some()
219        {
220            log::trace!("The sender ratchets are initialized already.");
221            return Ok(());
222        }
223
224        // If we don't have a secret in the leaf node, we derive it
225        if self.get_node(index.into())?.is_none() {
226            // Collect empty nodes in the direct path until a non-empty node is
227            // found
228            let mut empty_nodes: Vec<ParentNodeIndex> = Vec::new();
229            let direct_path = direct_path(index, self.size);
230            log::trace!("Direct path for node {index:?}: {direct_path:?}");
231            for parent_node in direct_path {
232                empty_nodes.push(parent_node);
233                // Stop if we find a non-empty node
234                if self.get_node(parent_node.into())?.is_some() {
235                    break;
236                }
237            }
238
239            // Invert direct path
240            empty_nodes.reverse();
241
242            // Derive the secrets down all the way to the leaf node
243            for n in empty_nodes {
244                log::trace!("Derive down for parent node {n:?}.");
245                self.derive_down(ciphersuite, crypto, n)?;
246            }
247        }
248
249        // Calculate node secret and initialize SenderRatchets
250        let node_secret = match self.get_node(index.into())? {
251            Some(node) => &node.secret,
252            // We just derived all necessary nodes so this should not happen
253            None => {
254                return Err(SecretTreeError::LibraryError);
255            }
256        };
257
258        log::trace!("Deriving leaf node secrets for leaf {index:?}");
259
260        let handshake_ratchet_secret = node_secret.kdf_expand_label(
261            crypto,
262            ciphersuite,
263            "handshake",
264            b"",
265            ciphersuite.hash_length(),
266        )?;
267        let application_ratchet_secret = node_secret.kdf_expand_label(
268            crypto,
269            ciphersuite,
270            "application",
271            b"",
272            ciphersuite.hash_length(),
273        )?;
274
275        log_crypto!(
276            trace,
277            "handshake ratchet secret {handshake_ratchet_secret:x?}"
278        );
279        log_crypto!(
280            trace,
281            "application ratchet secret {application_ratchet_secret:x?}"
282        );
283
284        // Initialize SenderRatchets. We differentiate between the own
285        // SenderRatchets and the SenderRatchets of other members. With the
286        // `virtual-clients-draft` feature, the own SenderRatchets are
287        // [`DualUseRatchet`]s, which can produce key material for both
288        // encryption and decryption.
289        let (handshake_sender_ratchet, application_sender_ratchet) = if index == self.own_index {
290            #[cfg(not(feature = "virtual-clients-draft"))]
291            {
292                (
293                    SenderRatchet::EncryptionRatchet(RatchetSecret::initial_ratchet_secret(
294                        handshake_ratchet_secret,
295                    )),
296                    SenderRatchet::EncryptionRatchet(RatchetSecret::initial_ratchet_secret(
297                        application_ratchet_secret,
298                    )),
299                )
300            }
301            #[cfg(feature = "virtual-clients-draft")]
302            {
303                (
304                    SenderRatchet::DualUse(DualUseRatchet::new(handshake_ratchet_secret)),
305                    SenderRatchet::DualUse(DualUseRatchet::new(application_ratchet_secret)),
306                )
307            }
308        } else {
309            (
310                SenderRatchet::DecryptionRatchet(DecryptionRatchet::new(handshake_ratchet_secret)),
311                SenderRatchet::DecryptionRatchet(DecryptionRatchet::new(
312                    application_ratchet_secret,
313                )),
314            )
315        };
316
317        *self
318            .handshake_sender_ratchets
319            .get_mut(index.usize())
320            .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(handshake_sender_ratchet);
321        *self
322            .application_sender_ratchets
323            .get_mut(index.usize())
324            .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(application_sender_ratchet);
325
326        // Delete leaf node
327        self.set_node(index.into(), None)
328    }
329
330    /// Return RatchetSecrets for a given index and generation. This should be
331    /// called when decrypting an PrivateMessage received from another member.
332    /// Returns an error if index or generation are out of bound.
333    pub(crate) fn secret_for_decryption(
334        &mut self,
335        ciphersuite: Ciphersuite,
336        crypto: &impl OpenMlsCrypto,
337        index: LeafNodeIndex,
338        secret_type: SecretType,
339        generation: u32,
340        configuration: &SenderRatchetConfiguration,
341    ) -> Result<RatchetKeyMaterial, SecretTreeError> {
342        log::debug!(
343            "Generating {secret_type:?} decryption secret for {index:?} in generation {generation} with {ciphersuite}",
344        );
345        // Check tree bounds
346        if index.u32() >= self.size.leaf_count() {
347            log::error!("Sender index is not in the tree.");
348            return Err(SecretTreeError::IndexOutOfBounds);
349        }
350        if self.ratchet_opt(index, secret_type)?.is_none() {
351            log::trace!("   initialize sender ratchets");
352            self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
353        }
354        match self.ratchet_mut(index, secret_type)? {
355            SenderRatchet::EncryptionRatchet(_) => {
356                log::error!("This is the wrong ratchet type.");
357                Err(SecretTreeError::RatchetTypeError)
358            }
359            SenderRatchet::DecryptionRatchet(dec_ratchet) => {
360                log::trace!("   getting secret for decryption");
361                dec_ratchet.secret_for_decryption(ciphersuite, crypto, generation, configuration)
362            }
363            #[cfg(feature = "virtual-clients-draft")]
364            SenderRatchet::DualUse(dual_ratchet) => {
365                log::trace!("   getting secret for decryption (own dual-use ratchet)");
366                dual_ratchet.secret_for_decryption(ciphersuite, crypto, generation, configuration)
367            }
368        }
369    }
370
371    /// Return the next RatchetSecrets that should be used for encryption and
372    /// then increments the generation.
373    pub(crate) fn secret_for_encryption(
374        &mut self,
375        ciphersuite: Ciphersuite,
376        crypto: &impl OpenMlsCrypto,
377        index: LeafNodeIndex,
378        secret_type: SecretType,
379    ) -> Result<(u32, RatchetKeyMaterial), SecretTreeError> {
380        if self.ratchet_opt(index, secret_type)?.is_none() {
381            self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
382        }
383        match self.ratchet_mut(index, secret_type)? {
384            SenderRatchet::DecryptionRatchet(_) => {
385                log::error!("Invalid ratchet type. Got decryption, expected encryption.");
386                Err(SecretTreeError::RatchetTypeError)
387            }
388            SenderRatchet::EncryptionRatchet(enc_ratchet) => {
389                enc_ratchet.ratchet_forward(crypto, ciphersuite)
390            }
391            #[cfg(feature = "virtual-clients-draft")]
392            SenderRatchet::DualUse(dual_ratchet) => {
393                dual_ratchet.secret_for_encryption(ciphersuite, crypto)
394            }
395        }
396    }
397
398    #[cfg(feature = "virtual-clients-draft")]
399    pub(crate) fn delete_own_secret_for_generation(
400        &mut self,
401        secret_type: SecretType,
402        generation: Generation,
403    ) -> Result<(), SecretTreeError> {
404        match self.ratchet_mut(self.own_index, secret_type)? {
405            SenderRatchet::DualUse(dual_ratchet) => {
406                dual_ratchet.delete_secret_for_generation(generation);
407                Ok(())
408            }
409            SenderRatchet::EncryptionRatchet(_) | SenderRatchet::DecryptionRatchet(_) => {
410                Err(SecretTreeError::RatchetTypeError)
411            }
412        }
413    }
414
415    /// Returns a mutable reference to a specific SenderRatchet. The
416    /// SenderRatchet needs to be initialized.
417    fn ratchet_mut(
418        &mut self,
419        index: LeafNodeIndex,
420        secret_type: SecretType,
421    ) -> Result<&mut SenderRatchet, SecretTreeError> {
422        let sender_ratchets = match secret_type {
423            SecretType::HandshakeSecret => &mut self.handshake_sender_ratchets,
424            SecretType::ApplicationSecret => &mut self.application_sender_ratchets,
425        };
426        sender_ratchets
427            .get_mut(index.usize())
428            .and_then(|r| r.as_mut())
429            .ok_or(SecretTreeError::IndexOutOfBounds)
430    }
431
432    /// Returns an optional reference to a specific SenderRatchet
433    fn ratchet_opt(
434        &self,
435        index: LeafNodeIndex,
436        secret_type: SecretType,
437    ) -> Result<Option<&SenderRatchet>, SecretTreeError> {
438        let sender_ratchets = match secret_type {
439            SecretType::HandshakeSecret => &self.handshake_sender_ratchets,
440            SecretType::ApplicationSecret => &self.application_sender_ratchets,
441        };
442        match sender_ratchets.get(index.usize()) {
443            Some(sender_ratchet_option) => Ok(sender_ratchet_option.as_ref()),
444            None => Err(SecretTreeError::IndexOutOfBounds),
445        }
446    }
447
448    /// Derives the secrets for the child nodes in a SecretTree and blanks the
449    /// parent node.
450    fn derive_down(
451        &mut self,
452        ciphersuite: Ciphersuite,
453        crypto: &impl OpenMlsCrypto,
454        index_in_tree: ParentNodeIndex,
455    ) -> Result<(), SecretTreeError> {
456        log::debug!(
457            "Deriving tree secret for parent node {} with {}",
458            index_in_tree.u32(),
459            ciphersuite
460        );
461        let node_secret = match &self.get_node(index_in_tree.into())? {
462            Some(node) => &node.secret,
463            // This function only gets called top to bottom, so this should not happen
464            None => {
465                return Err(SecretTreeError::LibraryError);
466            }
467        };
468        log_crypto!(trace, "Node secret: {:x?}", node_secret.as_slice());
469        let left_index = left(index_in_tree);
470        let right_index = right(index_in_tree);
471        let (left_secret, right_secret) = derive_child_secrets(node_secret, crypto, ciphersuite)?;
472        log_crypto!(
473            trace,
474            "Left node ({}) secret: {:x?}",
475            left_index.test_u32(),
476            left_secret.as_slice()
477        );
478        log_crypto!(
479            trace,
480            "Right node ({}) secret: {:x?}",
481            right_index.test_u32(),
482            right_secret.as_slice()
483        );
484
485        // Populate left child
486        self.set_node(
487            left_index,
488            Some(SecretTreeNode {
489                secret: left_secret,
490            }),
491        )?;
492
493        // Populate right child
494        self.set_node(
495            right_index,
496            Some(SecretTreeNode {
497                secret: right_secret,
498            }),
499        )?;
500
501        // Delete parent node
502        self.set_node(index_in_tree.into(), None)
503    }
504
505    fn get_node(&self, index: TreeNodeIndex) -> Result<Option<&SecretTreeNode>, SecretTreeError> {
506        match index {
507            TreeNodeIndex::Leaf(leaf_index) => Ok(self
508                .leaf_nodes
509                .get(leaf_index.usize())
510                .ok_or(SecretTreeError::IndexOutOfBounds)?
511                .as_ref()),
512            TreeNodeIndex::Parent(parent_index) => Ok(self
513                .parent_nodes
514                .get(parent_index.usize())
515                .ok_or(SecretTreeError::IndexOutOfBounds)?
516                .as_ref()),
517        }
518    }
519
520    fn set_node(
521        &mut self,
522        index: TreeNodeIndex,
523        node: Option<SecretTreeNode>,
524    ) -> Result<(), SecretTreeError> {
525        match index {
526            TreeNodeIndex::Leaf(leaf_index) => {
527                *self
528                    .leaf_nodes
529                    .get_mut(leaf_index.usize())
530                    .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
531            }
532            TreeNodeIndex::Parent(parent_index) => {
533                *self
534                    .parent_nodes
535                    .get_mut(parent_index.usize())
536                    .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
537            }
538        }
539        Ok(())
540    }
541}