openmls/tree/
secret_tree.rs

1use openmls_traits::crypto::OpenMlsCrypto;
2use openmls_traits::types::{Ciphersuite, CryptoError};
3use thiserror::Error;
4use tls_codec::{Error as TlsCodecError, TlsSerialize, TlsSize};
5
6use super::*;
7use crate::{
8    binary_tree::{
9        array_representation::{
10            direct_path, left, right, root, ParentNodeIndex, TreeNodeIndex, TreeSize,
11        },
12        LeafNodeIndex,
13    },
14    framing::*,
15    schedule::*,
16    tree::sender_ratchet::*,
17};
18
19/// Secret tree error
20#[derive(Error, Debug, Eq, PartialEq, Clone)]
21pub enum SecretTreeError {
22    /// Generation is too old to be processed.
23    #[error("Generation is too old to be processed.")]
24    TooDistantInThePast,
25    /// Generation is too far in the future to be processed.
26    #[error("Generation is too far in the future to be processed.")]
27    TooDistantInTheFuture,
28    /// Index out of bounds
29    #[error("Index out of bounds")]
30    IndexOutOfBounds,
31    /// The requested secret was deleted to preserve forward secrecy.
32    #[error("The requested secret was deleted to preserve forward secrecy.")]
33    SecretReuseError,
34    /// Cannot create decryption secrets from own sender ratchet or encryption secrets from the sender ratchets of other members.
35    #[error("Cannot create decryption secrets from own sender ratchet or encryption secrets from the sender ratchets of other members.")]
36    RatchetTypeError,
37    /// Ratchet generation has reached `u32::MAX`.
38    #[error("Ratchet generation has reached `u32::MAX`.")]
39    RatchetTooLong,
40    /// An unrecoverable error has occurred due to a bug in the implementation.
41    #[error("An unrecoverable error has occurred due to a bug in the implementation.")]
42    LibraryError,
43    /// See [`TlsCodecError`] for more details.
44    #[error(transparent)]
45    CodecError(#[from] TlsCodecError),
46    /// See [`CryptoError`] for more details.
47    #[error(transparent)]
48    CryptoError(#[from] CryptoError),
49}
50
/// Selects which of a leaf's two sender ratchets a secret is taken from.
#[derive(Clone, Copy, Debug)]
pub(crate) enum SecretType {
    /// The handshake sender ratchet (used for commits and proposals).
    HandshakeSecret,
    /// The application sender ratchet (used for application messages).
    ApplicationSecret,
}
56
57impl From<&ContentType> for SecretType {
58    fn from(content_type: &ContentType) -> SecretType {
59        match content_type {
60            ContentType::Application => SecretType::ApplicationSecret,
61            ContentType::Commit => SecretType::HandshakeSecret,
62            ContentType::Proposal => SecretType::HandshakeSecret,
63        }
64    }
65}
66
67impl From<&PublicMessage> for SecretType {
68    fn from(public_message: &PublicMessage) -> SecretType {
69        SecretType::from(&public_message.content_type())
70    }
71}
72
73/// Derives secrets for inner nodes of a SecretTree. This function corresponds
74/// to the `DeriveTreeSecret` defined in Section 10.1 of the MLS specification.
75#[inline]
76pub(crate) fn derive_tree_secret(
77    ciphersuite: Ciphersuite,
78    secret: &Secret,
79    label: &str,
80    generation: u32,
81    length: usize,
82    crypto: &impl OpenMlsCrypto,
83) -> Result<Secret, SecretTreeError> {
84    log::debug!(
85        "Derive tree secret with label \"{label}\" in generation {generation} of length {length}"
86    );
87    log_crypto!(trace, "Input secret {:x?}", secret.as_slice());
88
89    let secret = secret.kdf_expand_label(
90        crypto,
91        ciphersuite,
92        label,
93        &generation.to_be_bytes(),
94        length,
95    )?;
96    log_crypto!(trace, "Derived secret {:x?}", secret.as_slice());
97    Ok(secret)
98}
99
100#[derive(Debug, TlsSerialize, TlsSize)]
101pub(crate) struct TreeContext {
102    pub(crate) node: u32,
103    pub(crate) generation: u32,
104}
105
106#[derive(Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
107#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
108pub(crate) struct SecretTreeNode {
109    pub(crate) secret: Secret,
110}
111
112#[derive(Serialize, Deserialize)]
113#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
114#[cfg_attr(any(feature = "crypto-debug", test), derive(Debug))]
115pub(crate) struct SecretTree {
116    own_index: LeafNodeIndex,
117    leaf_nodes: Vec<Option<SecretTreeNode>>,
118    parent_nodes: Vec<Option<SecretTreeNode>>,
119    handshake_sender_ratchets: Vec<Option<SenderRatchet>>,
120    application_sender_ratchets: Vec<Option<SenderRatchet>>,
121    size: TreeSize,
122}
123
124impl SecretTree {
125    /// Creates a new SecretTree based on an `encryption_secret` and group size
126    /// `size`. The inner nodes of the tree and the SenderRatchets only get
127    /// initialized when secrets are requested either through `secret()`
128    /// or `next_secret()`.
129    pub(crate) fn new(
130        encryption_secret: EncryptionSecret,
131        size: TreeSize,
132        own_index: LeafNodeIndex,
133    ) -> Self {
134        let leaf_count = size.leaf_count() as usize;
135        let leaf_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
136        let parent_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
137        let handshake_sender_ratchets = std::iter::repeat_with(|| None).take(leaf_count).collect();
138        let application_sender_ratchets =
139            std::iter::repeat_with(|| None).take(leaf_count).collect();
140
141        let mut secret_tree = SecretTree {
142            own_index,
143            leaf_nodes,
144            parent_nodes,
145            handshake_sender_ratchets,
146            application_sender_ratchets,
147            size,
148        };
149
150        // Set the encryption secret in the root node. We ignore the Result
151        // here, since the we rely on the tree math to be correct, i.e.
152        // root(size) < size.
153        let _ = secret_tree.set_node(
154            root(size),
155            Some(SecretTreeNode {
156                secret: encryption_secret.consume_secret(),
157            }),
158        );
159
160        secret_tree
161    }
162
163    /// Get current generation for a specific SenderRatchet
164    #[cfg(test)]
165    pub(crate) fn generation(&self, index: LeafNodeIndex, secret_type: SecretType) -> u32 {
166        match self
167            .ratchet_opt(index, secret_type)
168            .expect("Index out of bounds.")
169        {
170            Some(sender_ratchet) => sender_ratchet.generation(),
171            None => 0,
172        }
173    }
174
175    /// Initializes a specific SenderRatchet pair for a given index by
176    /// calculating and deleting the appropriate values in the SecretTree
177    fn initialize_sender_ratchets(
178        &mut self,
179        ciphersuite: Ciphersuite,
180        crypto: &impl OpenMlsCrypto,
181        index: LeafNodeIndex,
182    ) -> Result<(), SecretTreeError> {
183        log::trace!("Initializing sender ratchets for {index:?} with {ciphersuite}");
184        if index.u32() >= self.size.leaf_count() {
185            log::error!("Index is larger than the tree size.");
186            return Err(SecretTreeError::IndexOutOfBounds);
187        }
188        // Check if SenderRatchets are already initialized
189        if self
190            .ratchet_opt(index, SecretType::HandshakeSecret)?
191            .is_some()
192            && self
193                .ratchet_opt(index, SecretType::ApplicationSecret)?
194                .is_some()
195        {
196            log::trace!("The sender ratchets are initialized already.");
197            return Ok(());
198        }
199
200        // If we don't have a secret in the leaf node, we derive it
201        if self.get_node(index.into())?.is_none() {
202            // Collect empty nodes in the direct path until a non-empty node is
203            // found
204            let mut empty_nodes: Vec<ParentNodeIndex> = Vec::new();
205            let direct_path = direct_path(index, self.size);
206            log::trace!("Direct path for node {index:?}: {direct_path:?}");
207            for parent_node in direct_path {
208                empty_nodes.push(parent_node);
209                // Stop if we find a non-empty node
210                if self.get_node(parent_node.into())?.is_some() {
211                    break;
212                }
213            }
214
215            // Invert direct path
216            empty_nodes.reverse();
217
218            // Derive the secrets down all the way to the leaf node
219            for n in empty_nodes {
220                log::trace!("Derive down for parent node {n:?}.");
221                self.derive_down(ciphersuite, crypto, n)?;
222            }
223        }
224
225        // Calculate node secret and initialize SenderRatchets
226        let node_secret = match self.get_node(index.into())? {
227            Some(node) => &node.secret,
228            // We just derived all necessary nodes so this should not happen
229            None => {
230                return Err(SecretTreeError::LibraryError);
231            }
232        };
233
234        log::trace!("Deriving leaf node secrets for leaf {index:?}");
235
236        let handshake_ratchet_secret = node_secret.kdf_expand_label(
237            crypto,
238            ciphersuite,
239            "handshake",
240            b"",
241            ciphersuite.hash_length(),
242        )?;
243        let application_ratchet_secret = node_secret.kdf_expand_label(
244            crypto,
245            ciphersuite,
246            "application",
247            b"",
248            ciphersuite.hash_length(),
249        )?;
250
251        log_crypto!(
252            trace,
253            "handshake ratchet secret {handshake_ratchet_secret:x?}"
254        );
255        log_crypto!(
256            trace,
257            "application ratchet secret {application_ratchet_secret:x?}"
258        );
259
260        // Initialize SenderRatchets, we differentiate between the own SenderRatchets
261        // and the SenderRatchets of other members
262        let (handshake_sender_ratchet, application_sender_ratchet) = if index == self.own_index {
263            let handshake_sender_ratchet = SenderRatchet::EncryptionRatchet(
264                RatchetSecret::initial_ratchet_secret(handshake_ratchet_secret),
265            );
266            let application_sender_ratchet = SenderRatchet::EncryptionRatchet(
267                RatchetSecret::initial_ratchet_secret(application_ratchet_secret),
268            );
269
270            (handshake_sender_ratchet, application_sender_ratchet)
271        } else {
272            let handshake_sender_ratchet =
273                SenderRatchet::DecryptionRatchet(DecryptionRatchet::new(handshake_ratchet_secret));
274            let application_sender_ratchet = SenderRatchet::DecryptionRatchet(
275                DecryptionRatchet::new(application_ratchet_secret),
276            );
277
278            (handshake_sender_ratchet, application_sender_ratchet)
279        };
280
281        *self
282            .handshake_sender_ratchets
283            .get_mut(index.usize())
284            .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(handshake_sender_ratchet);
285        *self
286            .application_sender_ratchets
287            .get_mut(index.usize())
288            .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(application_sender_ratchet);
289
290        // Delete leaf node
291        self.set_node(index.into(), None)
292    }
293
294    /// Return RatchetSecrets for a given index and generation. This should be
295    /// called when decrypting an PrivateMessage received from another member.
296    /// Returns an error if index or generation are out of bound.
297    pub(crate) fn secret_for_decryption(
298        &mut self,
299        ciphersuite: Ciphersuite,
300        crypto: &impl OpenMlsCrypto,
301        index: LeafNodeIndex,
302        secret_type: SecretType,
303        generation: u32,
304        configuration: &SenderRatchetConfiguration,
305    ) -> Result<RatchetKeyMaterial, SecretTreeError> {
306        log::debug!(
307            "Generating {secret_type:?} decryption secret for {index:?} in generation {generation} with {ciphersuite}",
308        );
309        // Check tree bounds
310        if index.u32() >= self.size.leaf_count() {
311            log::error!("Sender index is not in the tree.");
312            return Err(SecretTreeError::IndexOutOfBounds);
313        }
314        if self.ratchet_opt(index, secret_type)?.is_none() {
315            log::trace!("   initialize sender ratchets");
316            self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
317        }
318        match self.ratchet_mut(index, secret_type)? {
319            SenderRatchet::EncryptionRatchet(_) => {
320                log::error!("This is the wrong ratchet type.");
321                Err(SecretTreeError::RatchetTypeError)
322            }
323            SenderRatchet::DecryptionRatchet(dec_ratchet) => {
324                log::trace!("   getting secret for decryption");
325                dec_ratchet.secret_for_decryption(ciphersuite, crypto, generation, configuration)
326            }
327        }
328    }
329
330    /// Return the next RatchetSecrets that should be used for encryption and
331    /// then increments the generation.
332    pub(crate) fn secret_for_encryption(
333        &mut self,
334        ciphersuite: Ciphersuite,
335        crypto: &impl OpenMlsCrypto,
336        index: LeafNodeIndex,
337        secret_type: SecretType,
338    ) -> Result<(u32, RatchetKeyMaterial), SecretTreeError> {
339        if self.ratchet_opt(index, secret_type)?.is_none() {
340            self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
341        }
342        match self.ratchet_mut(index, secret_type)? {
343            SenderRatchet::DecryptionRatchet(_) => {
344                log::error!("Invalid ratchet type. Got decryption, expected encryption.");
345                Err(SecretTreeError::RatchetTypeError)
346            }
347            SenderRatchet::EncryptionRatchet(enc_ratchet) => {
348                enc_ratchet.ratchet_forward(crypto, ciphersuite)
349            }
350        }
351    }
352
353    /// Returns a mutable reference to a specific SenderRatchet. The
354    /// SenderRatchet needs to be initialized.
355    fn ratchet_mut(
356        &mut self,
357        index: LeafNodeIndex,
358        secret_type: SecretType,
359    ) -> Result<&mut SenderRatchet, SecretTreeError> {
360        let sender_ratchets = match secret_type {
361            SecretType::HandshakeSecret => &mut self.handshake_sender_ratchets,
362            SecretType::ApplicationSecret => &mut self.application_sender_ratchets,
363        };
364        sender_ratchets
365            .get_mut(index.usize())
366            .and_then(|r| r.as_mut())
367            .ok_or(SecretTreeError::IndexOutOfBounds)
368    }
369
370    /// Returns an optional reference to a specific SenderRatchet
371    fn ratchet_opt(
372        &self,
373        index: LeafNodeIndex,
374        secret_type: SecretType,
375    ) -> Result<Option<&SenderRatchet>, SecretTreeError> {
376        let sender_ratchets = match secret_type {
377            SecretType::HandshakeSecret => &self.handshake_sender_ratchets,
378            SecretType::ApplicationSecret => &self.application_sender_ratchets,
379        };
380        match sender_ratchets.get(index.usize()) {
381            Some(sender_ratchet_option) => Ok(sender_ratchet_option.as_ref()),
382            None => Err(SecretTreeError::IndexOutOfBounds),
383        }
384    }
385
386    /// Derives the secrets for the child nodes in a SecretTree and blanks the
387    /// parent node.
388    fn derive_down(
389        &mut self,
390        ciphersuite: Ciphersuite,
391        crypto: &impl OpenMlsCrypto,
392        index_in_tree: ParentNodeIndex,
393    ) -> Result<(), SecretTreeError> {
394        log::debug!(
395            "Deriving tree secret for parent node {} with {}",
396            index_in_tree.u32(),
397            ciphersuite
398        );
399        let hash_len = ciphersuite.hash_length();
400        let node_secret = match &self.get_node(index_in_tree.into())? {
401            Some(node) => &node.secret,
402            // This function only gets called top to bottom, so this should not happen
403            None => {
404                return Err(SecretTreeError::LibraryError);
405            }
406        };
407        log_crypto!(trace, "Node secret: {:x?}", node_secret.as_slice());
408        let left_index = left(index_in_tree);
409        let right_index = right(index_in_tree);
410        let left_secret =
411            node_secret.kdf_expand_label(crypto, ciphersuite, "tree", b"left", hash_len)?;
412        let right_secret =
413            node_secret.kdf_expand_label(crypto, ciphersuite, "tree", b"right", hash_len)?;
414        log_crypto!(
415            trace,
416            "Left node ({}) secret: {:x?}",
417            left_index.test_u32(),
418            left_secret.as_slice()
419        );
420        log_crypto!(
421            trace,
422            "Right node ({}) secret: {:x?}",
423            right_index.test_u32(),
424            right_secret.as_slice()
425        );
426
427        // Populate left child
428        self.set_node(
429            left_index,
430            Some(SecretTreeNode {
431                secret: left_secret,
432            }),
433        )?;
434
435        // Populate right child
436        self.set_node(
437            right_index,
438            Some(SecretTreeNode {
439                secret: right_secret,
440            }),
441        )?;
442
443        // Delete parent node
444        self.set_node(index_in_tree.into(), None)
445    }
446
447    fn get_node(&self, index: TreeNodeIndex) -> Result<Option<&SecretTreeNode>, SecretTreeError> {
448        match index {
449            TreeNodeIndex::Leaf(leaf_index) => Ok(self
450                .leaf_nodes
451                .get(leaf_index.usize())
452                .ok_or(SecretTreeError::IndexOutOfBounds)?
453                .as_ref()),
454            TreeNodeIndex::Parent(parent_index) => Ok(self
455                .parent_nodes
456                .get(parent_index.usize())
457                .ok_or(SecretTreeError::IndexOutOfBounds)?
458                .as_ref()),
459        }
460    }
461
462    fn set_node(
463        &mut self,
464        index: TreeNodeIndex,
465        node: Option<SecretTreeNode>,
466    ) -> Result<(), SecretTreeError> {
467        match index {
468            TreeNodeIndex::Leaf(leaf_index) => {
469                *self
470                    .leaf_nodes
471                    .get_mut(leaf_index.usize())
472                    .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
473            }
474            TreeNodeIndex::Parent(parent_index) => {
475                *self
476                    .parent_nodes
477                    .get_mut(parent_index.usize())
478                    .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
479            }
480        }
481        Ok(())
482    }
483}