1use openmls_traits::crypto::OpenMlsCrypto;
2use openmls_traits::types::{Ciphersuite, CryptoError};
3use thiserror::Error;
4use tls_codec::{Error as TlsCodecError, TlsSerialize, TlsSize};
5
6use super::*;
7use crate::{
8 binary_tree::{
9 array_representation::{
10 direct_path, left, right, root, ParentNodeIndex, TreeNodeIndex, TreeSize,
11 },
12 LeafNodeIndex,
13 },
14 framing::*,
15 schedule::*,
16 tree::sender_ratchet::*,
17};
18
/// Errors returned while deriving or looking up secrets in the secret tree.
#[derive(Error, Debug, Eq, PartialEq, Clone)]
pub enum SecretTreeError {
    /// The requested generation is too old to be processed.
    #[error("Generation is too old to be processed.")]
    TooDistantInThePast,
    /// The requested generation is too far in the future to be processed.
    #[error("Generation is too far in the future to be processed.")]
    TooDistantInTheFuture,
    /// A leaf or node index lies outside the tree.
    #[error("Index out of bounds")]
    IndexOutOfBounds,
    /// The requested secret was already deleted to preserve forward secrecy.
    #[error("The requested secret was deleted to preserve forward secrecy.")]
    SecretReuseError,
    /// An encryption ratchet was requested for decryption, or a decryption
    /// ratchet for encryption.
    #[error("Cannot create decryption secrets from own sender ratchet or encryption secrets from the sender ratchets of other members.")]
    RatchetTypeError,
    /// The ratchet generation counter has reached `u32::MAX`.
    #[error("Ratchet generation has reached `u32::MAX`.")]
    RatchetTooLong,
    /// An internal invariant was violated; this indicates a bug.
    #[error("An unrecoverable error has occurred due to a bug in the implementation.")]
    LibraryError,
    /// A TLS codec error, forwarded unchanged.
    #[error(transparent)]
    CodecError(#[from] TlsCodecError),
    /// A cryptographic backend error, forwarded unchanged.
    #[error(transparent)]
    CryptoError(#[from] CryptoError),
}
50
/// Selects which of a leaf's two sender ratchets a secret belongs to.
#[derive(Debug, Copy, Clone)]
pub(crate) enum SecretType {
    /// Ratchet used for handshake content (commits and proposals).
    HandshakeSecret,
    /// Ratchet used for application content.
    ApplicationSecret,
}
56
57impl From<&ContentType> for SecretType {
58 fn from(content_type: &ContentType) -> SecretType {
59 match content_type {
60 ContentType::Application => SecretType::ApplicationSecret,
61 ContentType::Commit => SecretType::HandshakeSecret,
62 ContentType::Proposal => SecretType::HandshakeSecret,
63 }
64 }
65}
66
67impl From<&PublicMessage> for SecretType {
68 fn from(public_message: &PublicMessage) -> SecretType {
69 SecretType::from(&public_message.content_type())
70 }
71}
72
73#[inline]
76pub(crate) fn derive_tree_secret(
77 ciphersuite: Ciphersuite,
78 secret: &Secret,
79 label: &str,
80 generation: u32,
81 length: usize,
82 crypto: &impl OpenMlsCrypto,
83) -> Result<Secret, SecretTreeError> {
84 log::debug!(
85 "Derive tree secret with label \"{label}\" in generation {generation} of length {length}"
86 );
87 log_crypto!(trace, "Input secret {:x?}", secret.as_slice());
88
89 let secret = secret.kdf_expand_label(
90 crypto,
91 ciphersuite,
92 label,
93 &generation.to_be_bytes(),
94 length,
95 )?;
96 log_crypto!(trace, "Derived secret {:x?}", secret.as_slice());
97 Ok(secret)
98}
99
/// A TLS-serializable `(node, generation)` pair.
///
/// NOTE(review): not referenced in this part of the file; presumably used as a
/// serialized derivation context elsewhere — confirm at the call sites.
#[derive(Debug, TlsSerialize, TlsSize)]
pub(crate) struct TreeContext {
    /// Node index.
    pub(crate) node: u32,
    /// Generation number.
    pub(crate) generation: u32,
}
105
/// A populated node of the secret tree.
#[derive(Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
pub(crate) struct SecretTreeNode {
    /// The secret currently stored at this node.
    pub(crate) secret: Secret,
}
111
/// Secret tree of a group: per-node secrets plus per-leaf sender ratchets.
///
/// Everything is derived lazily from the encryption secret stored at the root.
// `Debug` is only derived for `crypto-debug`/tests, presumably so secrets are
// not printed in regular builds.
#[derive(Serialize, Deserialize)]
#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
#[cfg_attr(any(feature = "crypto-debug", test), derive(Debug))]
pub(crate) struct SecretTree {
    /// Leaf of the local member. Its ratchets are encryption ratchets; all
    /// other leaves get decryption ratchets.
    own_index: LeafNodeIndex,
    /// Leaf secrets. `None` until derived, and again once consumed into ratchets.
    leaf_nodes: Vec<Option<SecretTreeNode>>,
    /// Parent secrets. `None` until derived, and again once derived down.
    parent_nodes: Vec<Option<SecretTreeNode>>,
    /// Per-leaf handshake ratchets, created on first use.
    handshake_sender_ratchets: Vec<Option<SenderRatchet>>,
    /// Per-leaf application ratchets, created on first use.
    application_sender_ratchets: Vec<Option<SenderRatchet>>,
    /// Size of the tree the tables above were allocated for.
    size: TreeSize,
}
123
124impl SecretTree {
125 pub(crate) fn new(
130 encryption_secret: EncryptionSecret,
131 size: TreeSize,
132 own_index: LeafNodeIndex,
133 ) -> Self {
134 let leaf_count = size.leaf_count() as usize;
135 let leaf_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
136 let parent_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
137 let handshake_sender_ratchets = std::iter::repeat_with(|| None).take(leaf_count).collect();
138 let application_sender_ratchets =
139 std::iter::repeat_with(|| None).take(leaf_count).collect();
140
141 let mut secret_tree = SecretTree {
142 own_index,
143 leaf_nodes,
144 parent_nodes,
145 handshake_sender_ratchets,
146 application_sender_ratchets,
147 size,
148 };
149
150 let _ = secret_tree.set_node(
154 root(size),
155 Some(SecretTreeNode {
156 secret: encryption_secret.consume_secret(),
157 }),
158 );
159
160 secret_tree
161 }
162
163 #[cfg(test)]
165 pub(crate) fn generation(&self, index: LeafNodeIndex, secret_type: SecretType) -> u32 {
166 match self
167 .ratchet_opt(index, secret_type)
168 .expect("Index out of bounds.")
169 {
170 Some(sender_ratchet) => sender_ratchet.generation(),
171 None => 0,
172 }
173 }
174
175 fn initialize_sender_ratchets(
178 &mut self,
179 ciphersuite: Ciphersuite,
180 crypto: &impl OpenMlsCrypto,
181 index: LeafNodeIndex,
182 ) -> Result<(), SecretTreeError> {
183 log::trace!("Initializing sender ratchets for {index:?} with {ciphersuite}");
184 if index.u32() >= self.size.leaf_count() {
185 log::error!("Index is larger than the tree size.");
186 return Err(SecretTreeError::IndexOutOfBounds);
187 }
188 if self
190 .ratchet_opt(index, SecretType::HandshakeSecret)?
191 .is_some()
192 && self
193 .ratchet_opt(index, SecretType::ApplicationSecret)?
194 .is_some()
195 {
196 log::trace!("The sender ratchets are initialized already.");
197 return Ok(());
198 }
199
200 if self.get_node(index.into())?.is_none() {
202 let mut empty_nodes: Vec<ParentNodeIndex> = Vec::new();
205 let direct_path = direct_path(index, self.size);
206 log::trace!("Direct path for node {index:?}: {direct_path:?}");
207 for parent_node in direct_path {
208 empty_nodes.push(parent_node);
209 if self.get_node(parent_node.into())?.is_some() {
211 break;
212 }
213 }
214
215 empty_nodes.reverse();
217
218 for n in empty_nodes {
220 log::trace!("Derive down for parent node {n:?}.");
221 self.derive_down(ciphersuite, crypto, n)?;
222 }
223 }
224
225 let node_secret = match self.get_node(index.into())? {
227 Some(node) => &node.secret,
228 None => {
230 return Err(SecretTreeError::LibraryError);
231 }
232 };
233
234 log::trace!("Deriving leaf node secrets for leaf {index:?}");
235
236 let handshake_ratchet_secret = node_secret.kdf_expand_label(
237 crypto,
238 ciphersuite,
239 "handshake",
240 b"",
241 ciphersuite.hash_length(),
242 )?;
243 let application_ratchet_secret = node_secret.kdf_expand_label(
244 crypto,
245 ciphersuite,
246 "application",
247 b"",
248 ciphersuite.hash_length(),
249 )?;
250
251 log_crypto!(
252 trace,
253 "handshake ratchet secret {handshake_ratchet_secret:x?}"
254 );
255 log_crypto!(
256 trace,
257 "application ratchet secret {application_ratchet_secret:x?}"
258 );
259
260 let (handshake_sender_ratchet, application_sender_ratchet) = if index == self.own_index {
263 let handshake_sender_ratchet = SenderRatchet::EncryptionRatchet(
264 RatchetSecret::initial_ratchet_secret(handshake_ratchet_secret),
265 );
266 let application_sender_ratchet = SenderRatchet::EncryptionRatchet(
267 RatchetSecret::initial_ratchet_secret(application_ratchet_secret),
268 );
269
270 (handshake_sender_ratchet, application_sender_ratchet)
271 } else {
272 let handshake_sender_ratchet =
273 SenderRatchet::DecryptionRatchet(DecryptionRatchet::new(handshake_ratchet_secret));
274 let application_sender_ratchet = SenderRatchet::DecryptionRatchet(
275 DecryptionRatchet::new(application_ratchet_secret),
276 );
277
278 (handshake_sender_ratchet, application_sender_ratchet)
279 };
280
281 *self
282 .handshake_sender_ratchets
283 .get_mut(index.usize())
284 .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(handshake_sender_ratchet);
285 *self
286 .application_sender_ratchets
287 .get_mut(index.usize())
288 .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(application_sender_ratchet);
289
290 self.set_node(index.into(), None)
292 }
293
294 pub(crate) fn secret_for_decryption(
298 &mut self,
299 ciphersuite: Ciphersuite,
300 crypto: &impl OpenMlsCrypto,
301 index: LeafNodeIndex,
302 secret_type: SecretType,
303 generation: u32,
304 configuration: &SenderRatchetConfiguration,
305 ) -> Result<RatchetKeyMaterial, SecretTreeError> {
306 log::debug!(
307 "Generating {secret_type:?} decryption secret for {index:?} in generation {generation} with {ciphersuite}",
308 );
309 if index.u32() >= self.size.leaf_count() {
311 log::error!("Sender index is not in the tree.");
312 return Err(SecretTreeError::IndexOutOfBounds);
313 }
314 if self.ratchet_opt(index, secret_type)?.is_none() {
315 log::trace!(" initialize sender ratchets");
316 self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
317 }
318 match self.ratchet_mut(index, secret_type)? {
319 SenderRatchet::EncryptionRatchet(_) => {
320 log::error!("This is the wrong ratchet type.");
321 Err(SecretTreeError::RatchetTypeError)
322 }
323 SenderRatchet::DecryptionRatchet(dec_ratchet) => {
324 log::trace!(" getting secret for decryption");
325 dec_ratchet.secret_for_decryption(ciphersuite, crypto, generation, configuration)
326 }
327 }
328 }
329
330 pub(crate) fn secret_for_encryption(
333 &mut self,
334 ciphersuite: Ciphersuite,
335 crypto: &impl OpenMlsCrypto,
336 index: LeafNodeIndex,
337 secret_type: SecretType,
338 ) -> Result<(u32, RatchetKeyMaterial), SecretTreeError> {
339 if self.ratchet_opt(index, secret_type)?.is_none() {
340 self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
341 }
342 match self.ratchet_mut(index, secret_type)? {
343 SenderRatchet::DecryptionRatchet(_) => {
344 log::error!("Invalid ratchet type. Got decryption, expected encryption.");
345 Err(SecretTreeError::RatchetTypeError)
346 }
347 SenderRatchet::EncryptionRatchet(enc_ratchet) => {
348 enc_ratchet.ratchet_forward(crypto, ciphersuite)
349 }
350 }
351 }
352
353 fn ratchet_mut(
356 &mut self,
357 index: LeafNodeIndex,
358 secret_type: SecretType,
359 ) -> Result<&mut SenderRatchet, SecretTreeError> {
360 let sender_ratchets = match secret_type {
361 SecretType::HandshakeSecret => &mut self.handshake_sender_ratchets,
362 SecretType::ApplicationSecret => &mut self.application_sender_ratchets,
363 };
364 sender_ratchets
365 .get_mut(index.usize())
366 .and_then(|r| r.as_mut())
367 .ok_or(SecretTreeError::IndexOutOfBounds)
368 }
369
370 fn ratchet_opt(
372 &self,
373 index: LeafNodeIndex,
374 secret_type: SecretType,
375 ) -> Result<Option<&SenderRatchet>, SecretTreeError> {
376 let sender_ratchets = match secret_type {
377 SecretType::HandshakeSecret => &self.handshake_sender_ratchets,
378 SecretType::ApplicationSecret => &self.application_sender_ratchets,
379 };
380 match sender_ratchets.get(index.usize()) {
381 Some(sender_ratchet_option) => Ok(sender_ratchet_option.as_ref()),
382 None => Err(SecretTreeError::IndexOutOfBounds),
383 }
384 }
385
386 fn derive_down(
389 &mut self,
390 ciphersuite: Ciphersuite,
391 crypto: &impl OpenMlsCrypto,
392 index_in_tree: ParentNodeIndex,
393 ) -> Result<(), SecretTreeError> {
394 log::debug!(
395 "Deriving tree secret for parent node {} with {}",
396 index_in_tree.u32(),
397 ciphersuite
398 );
399 let hash_len = ciphersuite.hash_length();
400 let node_secret = match &self.get_node(index_in_tree.into())? {
401 Some(node) => &node.secret,
402 None => {
404 return Err(SecretTreeError::LibraryError);
405 }
406 };
407 log_crypto!(trace, "Node secret: {:x?}", node_secret.as_slice());
408 let left_index = left(index_in_tree);
409 let right_index = right(index_in_tree);
410 let left_secret =
411 node_secret.kdf_expand_label(crypto, ciphersuite, "tree", b"left", hash_len)?;
412 let right_secret =
413 node_secret.kdf_expand_label(crypto, ciphersuite, "tree", b"right", hash_len)?;
414 log_crypto!(
415 trace,
416 "Left node ({}) secret: {:x?}",
417 left_index.test_u32(),
418 left_secret.as_slice()
419 );
420 log_crypto!(
421 trace,
422 "Right node ({}) secret: {:x?}",
423 right_index.test_u32(),
424 right_secret.as_slice()
425 );
426
427 self.set_node(
429 left_index,
430 Some(SecretTreeNode {
431 secret: left_secret,
432 }),
433 )?;
434
435 self.set_node(
437 right_index,
438 Some(SecretTreeNode {
439 secret: right_secret,
440 }),
441 )?;
442
443 self.set_node(index_in_tree.into(), None)
445 }
446
447 fn get_node(&self, index: TreeNodeIndex) -> Result<Option<&SecretTreeNode>, SecretTreeError> {
448 match index {
449 TreeNodeIndex::Leaf(leaf_index) => Ok(self
450 .leaf_nodes
451 .get(leaf_index.usize())
452 .ok_or(SecretTreeError::IndexOutOfBounds)?
453 .as_ref()),
454 TreeNodeIndex::Parent(parent_index) => Ok(self
455 .parent_nodes
456 .get(parent_index.usize())
457 .ok_or(SecretTreeError::IndexOutOfBounds)?
458 .as_ref()),
459 }
460 }
461
462 fn set_node(
463 &mut self,
464 index: TreeNodeIndex,
465 node: Option<SecretTreeNode>,
466 ) -> Result<(), SecretTreeError> {
467 match index {
468 TreeNodeIndex::Leaf(leaf_index) => {
469 *self
470 .leaf_nodes
471 .get_mut(leaf_index.usize())
472 .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
473 }
474 TreeNodeIndex::Parent(parent_index) => {
475 *self
476 .parent_nodes
477 .get_mut(parent_index.usize())
478 .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
479 }
480 }
481 Ok(())
482 }
483}