openmls/tree/
secret_tree.rs

1use openmls_traits::crypto::OpenMlsCrypto;
2use openmls_traits::types::{Ciphersuite, CryptoError};
3use thiserror::Error;
4use tls_codec::{Error as TlsCodecError, TlsSerialize, TlsSize};
5
6use super::*;
7use crate::{
8    binary_tree::{
9        array_representation::{
10            direct_path, left, right, root, ParentNodeIndex, TreeNodeIndex, TreeSize,
11        },
12        LeafNodeIndex,
13    },
14    framing::*,
15    schedule::*,
16    tree::sender_ratchet::*,
17};
18
19/// Secret tree error
20#[derive(Error, Debug, Eq, PartialEq, Clone)]
21pub enum SecretTreeError {
22    /// Generation is too old to be processed.
23    #[error("Generation is too old to be processed.")]
24    TooDistantInThePast,
25    /// Generation is too far in the future to be processed.
26    #[error("Generation is too far in the future to be processed.")]
27    TooDistantInTheFuture,
28    /// Index out of bounds
29    #[error("Index out of bounds")]
30    IndexOutOfBounds,
31    /// The requested secret was deleted to preserve forward secrecy.
32    #[error("The requested secret was deleted to preserve forward secrecy.")]
33    SecretReuseError,
34    /// Cannot create decryption secrets from own sender ratchet or encryption secrets from the sender ratchets of other members.
35    #[error("Cannot create decryption secrets from own sender ratchet or encryption secrets from the sender ratchets of other members.")]
36    RatchetTypeError,
37    /// Ratchet generation has reached `u32::MAX`.
38    #[error("Ratchet generation has reached `u32::MAX`.")]
39    RatchetTooLong,
40    /// An unrecoverable error has occurred due to a bug in the implementation.
41    #[error("An unrecoverable error has occurred due to a bug in the implementation.")]
42    LibraryError,
43    /// See [`TlsCodecError`] for more details.
44    #[error(transparent)]
45    CodecError(#[from] TlsCodecError),
46    /// See [`CryptoError`] for more details.
47    #[error(transparent)]
48    CryptoError(#[from] CryptoError),
49}
50
/// Selects which of a leaf's two sender ratchets a secret is taken from.
#[derive(Clone, Copy, Debug)]
pub(crate) enum SecretType {
    /// The ratchet used for handshake messages (commits and proposals).
    HandshakeSecret,
    /// The ratchet used for application messages.
    ApplicationSecret,
}
56
57impl From<&ContentType> for SecretType {
58    fn from(content_type: &ContentType) -> SecretType {
59        match content_type {
60            ContentType::Application => SecretType::ApplicationSecret,
61            ContentType::Commit => SecretType::HandshakeSecret,
62            ContentType::Proposal => SecretType::HandshakeSecret,
63        }
64    }
65}
66
67impl From<&PublicMessage> for SecretType {
68    fn from(public_message: &PublicMessage) -> SecretType {
69        SecretType::from(&public_message.content_type())
70    }
71}
72
73/// Derives secrets for inner nodes of a SecretTree. This function corresponds
74/// to the `DeriveTreeSecret` defined in Section 10.1 of the MLS specification.
75#[inline]
76pub(crate) fn derive_tree_secret(
77    ciphersuite: Ciphersuite,
78    secret: &Secret,
79    label: &str,
80    generation: u32,
81    length: usize,
82    crypto: &impl OpenMlsCrypto,
83) -> Result<Secret, SecretTreeError> {
84    log::debug!(
85        "Derive tree secret with label \"{}\" in generation {} of length {}",
86        label,
87        generation,
88        length
89    );
90    log_crypto!(trace, "Input secret {:x?}", secret.as_slice());
91
92    let secret = secret.kdf_expand_label(
93        crypto,
94        ciphersuite,
95        label,
96        &generation.to_be_bytes(),
97        length,
98    )?;
99    log_crypto!(trace, "Derived secret {:x?}", secret.as_slice());
100    Ok(secret)
101}
102
103#[derive(Debug, TlsSerialize, TlsSize)]
104pub(crate) struct TreeContext {
105    pub(crate) node: u32,
106    pub(crate) generation: u32,
107}
108
109#[derive(Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
110#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
111pub(crate) struct SecretTreeNode {
112    pub(crate) secret: Secret,
113}
114
115#[derive(Serialize, Deserialize)]
116#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
117#[cfg_attr(any(feature = "crypto-debug", test), derive(Debug))]
118pub(crate) struct SecretTree {
119    own_index: LeafNodeIndex,
120    leaf_nodes: Vec<Option<SecretTreeNode>>,
121    parent_nodes: Vec<Option<SecretTreeNode>>,
122    handshake_sender_ratchets: Vec<Option<SenderRatchet>>,
123    application_sender_ratchets: Vec<Option<SenderRatchet>>,
124    size: TreeSize,
125}
126
127impl SecretTree {
128    /// Creates a new SecretTree based on an `encryption_secret` and group size
129    /// `size`. The inner nodes of the tree and the SenderRatchets only get
130    /// initialized when secrets are requested either through `secret()`
131    /// or `next_secret()`.
132    pub(crate) fn new(
133        encryption_secret: EncryptionSecret,
134        size: TreeSize,
135        own_index: LeafNodeIndex,
136    ) -> Self {
137        let leaf_count = size.leaf_count() as usize;
138        let leaf_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
139        let parent_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
140        let handshake_sender_ratchets = std::iter::repeat_with(|| None).take(leaf_count).collect();
141        let application_sender_ratchets =
142            std::iter::repeat_with(|| None).take(leaf_count).collect();
143
144        let mut secret_tree = SecretTree {
145            own_index,
146            leaf_nodes,
147            parent_nodes,
148            handshake_sender_ratchets,
149            application_sender_ratchets,
150            size,
151        };
152
153        // Set the encryption secret in the root node. We ignore the Result
154        // here, since the we rely on the tree math to be correct, i.e.
155        // root(size) < size.
156        let _ = secret_tree.set_node(
157            root(size),
158            Some(SecretTreeNode {
159                secret: encryption_secret.consume_secret(),
160            }),
161        );
162
163        secret_tree
164    }
165
166    /// Get current generation for a specific SenderRatchet
167    #[cfg(test)]
168    pub(crate) fn generation(&self, index: LeafNodeIndex, secret_type: SecretType) -> u32 {
169        match self
170            .ratchet_opt(index, secret_type)
171            .expect("Index out of bounds.")
172        {
173            Some(sender_ratchet) => sender_ratchet.generation(),
174            None => 0,
175        }
176    }
177
178    /// Initializes a specific SenderRatchet pair for a given index by
179    /// calculating and deleting the appropriate values in the SecretTree
180    fn initialize_sender_ratchets(
181        &mut self,
182        ciphersuite: Ciphersuite,
183        crypto: &impl OpenMlsCrypto,
184        index: LeafNodeIndex,
185    ) -> Result<(), SecretTreeError> {
186        log::trace!("Initializing sender ratchets for {index:?} with {ciphersuite}");
187        if index.u32() >= self.size.leaf_count() {
188            log::error!("Index is larger than the tree size.");
189            return Err(SecretTreeError::IndexOutOfBounds);
190        }
191        // Check if SenderRatchets are already initialized
192        if self
193            .ratchet_opt(index, SecretType::HandshakeSecret)?
194            .is_some()
195            && self
196                .ratchet_opt(index, SecretType::ApplicationSecret)?
197                .is_some()
198        {
199            log::trace!("The sender ratchets are initialized already.");
200            return Ok(());
201        }
202
203        // If we don't have a secret in the leaf node, we derive it
204        if self.get_node(index.into())?.is_none() {
205            // Collect empty nodes in the direct path until a non-empty node is
206            // found
207            let mut empty_nodes: Vec<ParentNodeIndex> = Vec::new();
208            let direct_path = direct_path(index, self.size);
209            log::trace!("Direct path for node {index:?}: {:?}", direct_path);
210            for parent_node in direct_path {
211                empty_nodes.push(parent_node);
212                // Stop if we find a non-empty node
213                if self.get_node(parent_node.into())?.is_some() {
214                    break;
215                }
216            }
217
218            // Invert direct path
219            empty_nodes.reverse();
220
221            // Derive the secrets down all the way to the leaf node
222            for n in empty_nodes {
223                log::trace!("Derive down for parent node {n:?}.");
224                self.derive_down(ciphersuite, crypto, n)?;
225            }
226        }
227
228        // Calculate node secret and initialize SenderRatchets
229        let node_secret = match self.get_node(index.into())? {
230            Some(node) => &node.secret,
231            // We just derived all necessary nodes so this should not happen
232            None => {
233                return Err(SecretTreeError::LibraryError);
234            }
235        };
236
237        log::trace!("Deriving leaf node secrets for leaf {index:?}");
238
239        let handshake_ratchet_secret = node_secret.kdf_expand_label(
240            crypto,
241            ciphersuite,
242            "handshake",
243            b"",
244            ciphersuite.hash_length(),
245        )?;
246        let application_ratchet_secret = node_secret.kdf_expand_label(
247            crypto,
248            ciphersuite,
249            "application",
250            b"",
251            ciphersuite.hash_length(),
252        )?;
253
254        log_crypto!(
255            trace,
256            "handshake ratchet secret {handshake_ratchet_secret:x?}"
257        );
258        log_crypto!(
259            trace,
260            "application ratchet secret {application_ratchet_secret:x?}"
261        );
262
263        // Initialize SenderRatchets, we differentiate between the own SenderRatchets
264        // and the SenderRatchets of other members
265        let (handshake_sender_ratchet, application_sender_ratchet) = if index == self.own_index {
266            let handshake_sender_ratchet = SenderRatchet::EncryptionRatchet(
267                RatchetSecret::initial_ratchet_secret(handshake_ratchet_secret),
268            );
269            let application_sender_ratchet = SenderRatchet::EncryptionRatchet(
270                RatchetSecret::initial_ratchet_secret(application_ratchet_secret),
271            );
272
273            (handshake_sender_ratchet, application_sender_ratchet)
274        } else {
275            let handshake_sender_ratchet =
276                SenderRatchet::DecryptionRatchet(DecryptionRatchet::new(handshake_ratchet_secret));
277            let application_sender_ratchet = SenderRatchet::DecryptionRatchet(
278                DecryptionRatchet::new(application_ratchet_secret),
279            );
280
281            (handshake_sender_ratchet, application_sender_ratchet)
282        };
283
284        *self
285            .handshake_sender_ratchets
286            .get_mut(index.usize())
287            .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(handshake_sender_ratchet);
288        *self
289            .application_sender_ratchets
290            .get_mut(index.usize())
291            .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(application_sender_ratchet);
292
293        // Delete leaf node
294        self.set_node(index.into(), None)
295    }
296
297    /// Return RatchetSecrets for a given index and generation. This should be
298    /// called when decrypting an PrivateMessage received from another member.
299    /// Returns an error if index or generation are out of bound.
300    pub(crate) fn secret_for_decryption(
301        &mut self,
302        ciphersuite: Ciphersuite,
303        crypto: &impl OpenMlsCrypto,
304        index: LeafNodeIndex,
305        secret_type: SecretType,
306        generation: u32,
307        configuration: &SenderRatchetConfiguration,
308    ) -> Result<RatchetKeyMaterial, SecretTreeError> {
309        log::debug!(
310            "Generating {:?} decryption secret for {:?} in generation {} with {}",
311            secret_type,
312            index,
313            generation,
314            ciphersuite,
315        );
316        // Check tree bounds
317        if index.u32() >= self.size.leaf_count() {
318            log::error!("Sender index is not in the tree.");
319            return Err(SecretTreeError::IndexOutOfBounds);
320        }
321        if self.ratchet_opt(index, secret_type)?.is_none() {
322            log::trace!("   initialize sender ratchets");
323            self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
324        }
325        match self.ratchet_mut(index, secret_type)? {
326            SenderRatchet::EncryptionRatchet(_) => {
327                log::error!("This is the wrong ratchet type.");
328                Err(SecretTreeError::RatchetTypeError)
329            }
330            SenderRatchet::DecryptionRatchet(dec_ratchet) => {
331                log::trace!("   getting secret for decryption");
332                dec_ratchet.secret_for_decryption(ciphersuite, crypto, generation, configuration)
333            }
334        }
335    }
336
337    /// Return the next RatchetSecrets that should be used for encryption and
338    /// then increments the generation.
339    pub(crate) fn secret_for_encryption(
340        &mut self,
341        ciphersuite: Ciphersuite,
342        crypto: &impl OpenMlsCrypto,
343        index: LeafNodeIndex,
344        secret_type: SecretType,
345    ) -> Result<(u32, RatchetKeyMaterial), SecretTreeError> {
346        if self.ratchet_opt(index, secret_type)?.is_none() {
347            self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
348        }
349        match self.ratchet_mut(index, secret_type)? {
350            SenderRatchet::DecryptionRatchet(_) => {
351                log::error!("Invalid ratchet type. Got decryption, expected encryption.");
352                Err(SecretTreeError::RatchetTypeError)
353            }
354            SenderRatchet::EncryptionRatchet(enc_ratchet) => {
355                enc_ratchet.ratchet_forward(crypto, ciphersuite)
356            }
357        }
358    }
359
360    /// Returns a mutable reference to a specific SenderRatchet. The
361    /// SenderRatchet needs to be initialized.
362    fn ratchet_mut(
363        &mut self,
364        index: LeafNodeIndex,
365        secret_type: SecretType,
366    ) -> Result<&mut SenderRatchet, SecretTreeError> {
367        let sender_ratchets = match secret_type {
368            SecretType::HandshakeSecret => &mut self.handshake_sender_ratchets,
369            SecretType::ApplicationSecret => &mut self.application_sender_ratchets,
370        };
371        sender_ratchets
372            .get_mut(index.usize())
373            .and_then(|r| r.as_mut())
374            .ok_or(SecretTreeError::IndexOutOfBounds)
375    }
376
377    /// Returns an optional reference to a specific SenderRatchet
378    fn ratchet_opt(
379        &self,
380        index: LeafNodeIndex,
381        secret_type: SecretType,
382    ) -> Result<Option<&SenderRatchet>, SecretTreeError> {
383        let sender_ratchets = match secret_type {
384            SecretType::HandshakeSecret => &self.handshake_sender_ratchets,
385            SecretType::ApplicationSecret => &self.application_sender_ratchets,
386        };
387        match sender_ratchets.get(index.usize()) {
388            Some(sender_ratchet_option) => Ok(sender_ratchet_option.as_ref()),
389            None => Err(SecretTreeError::IndexOutOfBounds),
390        }
391    }
392
393    /// Derives the secrets for the child nodes in a SecretTree and blanks the
394    /// parent node.
395    fn derive_down(
396        &mut self,
397        ciphersuite: Ciphersuite,
398        crypto: &impl OpenMlsCrypto,
399        index_in_tree: ParentNodeIndex,
400    ) -> Result<(), SecretTreeError> {
401        log::debug!(
402            "Deriving tree secret for parent node {} with {}",
403            index_in_tree.u32(),
404            ciphersuite
405        );
406        let hash_len = ciphersuite.hash_length();
407        let node_secret = match &self.get_node(index_in_tree.into())? {
408            Some(node) => &node.secret,
409            // This function only gets called top to bottom, so this should not happen
410            None => {
411                return Err(SecretTreeError::LibraryError);
412            }
413        };
414        log_crypto!(trace, "Node secret: {:x?}", node_secret.as_slice());
415        let left_index = left(index_in_tree);
416        let right_index = right(index_in_tree);
417        let left_secret =
418            node_secret.kdf_expand_label(crypto, ciphersuite, "tree", b"left", hash_len)?;
419        let right_secret =
420            node_secret.kdf_expand_label(crypto, ciphersuite, "tree", b"right", hash_len)?;
421        log_crypto!(
422            trace,
423            "Left node ({}) secret: {:x?}",
424            left_index.test_u32(),
425            left_secret.as_slice()
426        );
427        log_crypto!(
428            trace,
429            "Right node ({}) secret: {:x?}",
430            right_index.test_u32(),
431            right_secret.as_slice()
432        );
433
434        // Populate left child
435        self.set_node(
436            left_index,
437            Some(SecretTreeNode {
438                secret: left_secret,
439            }),
440        )?;
441
442        // Populate right child
443        self.set_node(
444            right_index,
445            Some(SecretTreeNode {
446                secret: right_secret,
447            }),
448        )?;
449
450        // Delete parent node
451        self.set_node(index_in_tree.into(), None)
452    }
453
454    fn get_node(&self, index: TreeNodeIndex) -> Result<Option<&SecretTreeNode>, SecretTreeError> {
455        match index {
456            TreeNodeIndex::Leaf(leaf_index) => Ok(self
457                .leaf_nodes
458                .get(leaf_index.usize())
459                .ok_or(SecretTreeError::IndexOutOfBounds)?
460                .as_ref()),
461            TreeNodeIndex::Parent(parent_index) => Ok(self
462                .parent_nodes
463                .get(parent_index.usize())
464                .ok_or(SecretTreeError::IndexOutOfBounds)?
465                .as_ref()),
466        }
467    }
468
469    fn set_node(
470        &mut self,
471        index: TreeNodeIndex,
472        node: Option<SecretTreeNode>,
473    ) -> Result<(), SecretTreeError> {
474        match index {
475            TreeNodeIndex::Leaf(leaf_index) => {
476                *self
477                    .leaf_nodes
478                    .get_mut(leaf_index.usize())
479                    .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
480            }
481            TreeNodeIndex::Parent(parent_index) => {
482                *self
483                    .parent_nodes
484                    .get_mut(parent_index.usize())
485                    .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
486            }
487        }
488        Ok(())
489    }
490}