1use openmls_traits::crypto::OpenMlsCrypto;
2use openmls_traits::types::{Ciphersuite, CryptoError};
3use thiserror::Error;
4use tls_codec::{Error as TlsCodecError, TlsSerialize, TlsSize};
5
6use super::*;
7#[cfg(feature = "virtual-clients-draft")]
8use crate::tree::dual_use_ratchet::DualUseRatchet;
9use crate::{
10 binary_tree::{
11 array_representation::{
12 direct_path, left, right, root, ParentNodeIndex, TreeNodeIndex, TreeSize,
13 },
14 LeafNodeIndex,
15 },
16 framing::*,
17 schedule::*,
18 tree::sender_ratchet::*,
19};
20
/// Errors raised while deriving secrets from the secret tree or from its
/// sender ratchets.
#[derive(Error, Debug, Eq, PartialEq, Clone)]
pub enum SecretTreeError {
    /// The requested generation is too old to be processed.
    #[error("Generation is too old to be processed.")]
    TooDistantInThePast,
    /// The requested generation is too far in the future to be processed.
    #[error("Generation is too far in the future to be processed.")]
    TooDistantInTheFuture,
    /// A leaf or node index lies outside the tree. Within this file it is
    /// also returned when a requested sender ratchet has not been initialized.
    #[error("Index out of bounds")]
    IndexOutOfBounds,
    /// The requested secret was already deleted to preserve forward secrecy.
    #[error("The requested secret was deleted to preserve forward secrecy.")]
    SecretReuseError,
    /// A decryption secret was requested from an encryption ratchet, or an
    /// encryption secret from a decryption ratchet.
    #[error("Cannot create decryption secrets from own sender ratchet or encryption secrets from the sender ratchets of other members.")]
    RatchetTypeError,
    /// The ratchet generation has reached `u32::MAX`.
    #[error("Ratchet generation has reached `u32::MAX`.")]
    RatchetTooLong,
    /// An internal invariant was violated, i.e. a bug in the implementation
    /// (e.g. a node secret that should have been derived is missing).
    #[error("An unrecoverable error has occurred due to a bug in the implementation.")]
    LibraryError,
    /// A TLS encoding/decoding error.
    #[error(transparent)]
    CodecError(#[from] TlsCodecError),
    /// An error reported by the cryptography provider.
    #[error(transparent)]
    CryptoError(#[from] CryptoError),
}
52
/// Selects which of a leaf's two sender ratchets is used: the handshake
/// ratchet or the application ratchet.
#[derive(Debug, Copy, Clone)]
pub(crate) enum SecretType {
    /// The handshake ratchet (selected for commits and proposals).
    HandshakeSecret,
    /// The application ratchet (selected for application messages).
    ApplicationSecret,
}
58
59impl From<&ContentType> for SecretType {
60 fn from(content_type: &ContentType) -> SecretType {
61 match content_type {
62 ContentType::Application => SecretType::ApplicationSecret,
63 ContentType::Commit => SecretType::HandshakeSecret,
64 ContentType::Proposal => SecretType::HandshakeSecret,
65 }
66 }
67}
68
69impl From<&PublicMessage> for SecretType {
70 fn from(public_message: &PublicMessage) -> SecretType {
71 SecretType::from(&public_message.content_type())
72 }
73}
74
75pub(crate) fn derive_child_secrets(
76 parent_secret: &Secret,
77 crypto: &impl OpenMlsCrypto,
78 ciphersuite: Ciphersuite,
79) -> Result<(Secret, Secret), CryptoError> {
80 let left_child = parent_secret.kdf_expand_label(
81 crypto,
82 ciphersuite,
83 "tree",
84 b"left",
85 ciphersuite.hash_length(),
86 )?;
87 let right_child = parent_secret.kdf_expand_label(
88 crypto,
89 ciphersuite,
90 "tree",
91 b"right",
92 ciphersuite.hash_length(),
93 )?;
94 Ok((left_child, right_child))
95}
96
97#[inline]
100pub(crate) fn derive_tree_secret(
101 ciphersuite: Ciphersuite,
102 secret: &Secret,
103 label: &str,
104 generation: u32,
105 length: usize,
106 crypto: &impl OpenMlsCrypto,
107) -> Result<Secret, SecretTreeError> {
108 log::debug!(
109 "Derive tree secret with label \"{label}\" in generation {generation} of length {length}"
110 );
111 log_crypto!(trace, "Input secret {:x?}", secret.as_slice());
112
113 let secret = secret.kdf_expand_label(
114 crypto,
115 ciphersuite,
116 label,
117 &generation.to_be_bytes(),
118 length,
119 )?;
120 log_crypto!(trace, "Derived secret {:x?}", secret.as_slice());
121 Ok(secret)
122}
123
/// A node index paired with a generation, encodable with TLS serialization.
///
/// Field order defines the TLS encoding (`node` first, then `generation`),
/// so the fields must not be reordered.
/// NOTE(review): not referenced in this part of the file; presumably used as
/// a serialized KDF context elsewhere — confirm.
#[derive(Debug, TlsSerialize, TlsSize)]
pub(crate) struct TreeContext {
    /// Node index.
    pub(crate) node: u32,
    /// Generation.
    pub(crate) generation: u32,
}
129
/// A populated position in the secret tree, holding that position's secret.
#[derive(Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
pub(crate) struct SecretTreeNode {
    /// The secret stored at this node.
    pub(crate) secret: Secret,
}
135
/// Secret tree from which the per-leaf handshake and application sender
/// ratchets are initialized on demand.
///
/// Node secrets are removed (set to `None`) once they have been used to derive
/// their children's secrets or a leaf's ratchets.
#[derive(Serialize, Deserialize)]
#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
#[cfg_attr(any(feature = "crypto-debug", test), derive(Debug))]
pub(crate) struct SecretTree {
    /// Leaf index of the local member; its ratchets are built for encryption.
    own_index: LeafNodeIndex,
    /// Secrets currently stored at leaf positions, indexed by leaf index.
    leaf_nodes: Vec<Option<SecretTreeNode>>,
    /// Secrets currently stored at parent positions, indexed by parent index.
    parent_nodes: Vec<Option<SecretTreeNode>>,
    /// Handshake sender ratchet per leaf; `None` until initialized.
    handshake_sender_ratchets: Vec<Option<SenderRatchet>>,
    /// Application sender ratchet per leaf; `None` until initialized.
    application_sender_ratchets: Vec<Option<SenderRatchet>>,
    /// Tree size; leaf indices must be below `size.leaf_count()`.
    size: TreeSize,
}
147
148impl SecretTree {
149 pub(crate) fn new(
154 encryption_secret: EncryptionSecret,
155 size: TreeSize,
156 own_index: LeafNodeIndex,
157 ) -> Self {
158 let leaf_count = size.leaf_count() as usize;
159 let leaf_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
160 let parent_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
161 let handshake_sender_ratchets = std::iter::repeat_with(|| None).take(leaf_count).collect();
162 let application_sender_ratchets =
163 std::iter::repeat_with(|| None).take(leaf_count).collect();
164
165 let mut secret_tree = SecretTree {
166 own_index,
167 leaf_nodes,
168 parent_nodes,
169 handshake_sender_ratchets,
170 application_sender_ratchets,
171 size,
172 };
173
174 let _ = secret_tree.set_node(
178 root(size),
179 Some(SecretTreeNode {
180 secret: encryption_secret.consume_secret(),
181 }),
182 );
183
184 secret_tree
185 }
186
187 #[cfg(test)]
189 pub(crate) fn generation(&self, index: LeafNodeIndex, secret_type: SecretType) -> u32 {
190 match self
191 .ratchet_opt(index, secret_type)
192 .expect("Index out of bounds.")
193 {
194 Some(sender_ratchet) => sender_ratchet.generation(),
195 None => 0,
196 }
197 }
198
199 fn initialize_sender_ratchets(
202 &mut self,
203 ciphersuite: Ciphersuite,
204 crypto: &impl OpenMlsCrypto,
205 index: LeafNodeIndex,
206 ) -> Result<(), SecretTreeError> {
207 log::trace!("Initializing sender ratchets for {index:?} with {ciphersuite}");
208 if index.u32() >= self.size.leaf_count() {
209 log::error!("Index is larger than the tree size.");
210 return Err(SecretTreeError::IndexOutOfBounds);
211 }
212 if self
214 .ratchet_opt(index, SecretType::HandshakeSecret)?
215 .is_some()
216 && self
217 .ratchet_opt(index, SecretType::ApplicationSecret)?
218 .is_some()
219 {
220 log::trace!("The sender ratchets are initialized already.");
221 return Ok(());
222 }
223
224 if self.get_node(index.into())?.is_none() {
226 let mut empty_nodes: Vec<ParentNodeIndex> = Vec::new();
229 let direct_path = direct_path(index, self.size);
230 log::trace!("Direct path for node {index:?}: {direct_path:?}");
231 for parent_node in direct_path {
232 empty_nodes.push(parent_node);
233 if self.get_node(parent_node.into())?.is_some() {
235 break;
236 }
237 }
238
239 empty_nodes.reverse();
241
242 for n in empty_nodes {
244 log::trace!("Derive down for parent node {n:?}.");
245 self.derive_down(ciphersuite, crypto, n)?;
246 }
247 }
248
249 let node_secret = match self.get_node(index.into())? {
251 Some(node) => &node.secret,
252 None => {
254 return Err(SecretTreeError::LibraryError);
255 }
256 };
257
258 log::trace!("Deriving leaf node secrets for leaf {index:?}");
259
260 let handshake_ratchet_secret = node_secret.kdf_expand_label(
261 crypto,
262 ciphersuite,
263 "handshake",
264 b"",
265 ciphersuite.hash_length(),
266 )?;
267 let application_ratchet_secret = node_secret.kdf_expand_label(
268 crypto,
269 ciphersuite,
270 "application",
271 b"",
272 ciphersuite.hash_length(),
273 )?;
274
275 log_crypto!(
276 trace,
277 "handshake ratchet secret {handshake_ratchet_secret:x?}"
278 );
279 log_crypto!(
280 trace,
281 "application ratchet secret {application_ratchet_secret:x?}"
282 );
283
284 let (handshake_sender_ratchet, application_sender_ratchet) = if index == self.own_index {
290 #[cfg(not(feature = "virtual-clients-draft"))]
291 {
292 (
293 SenderRatchet::EncryptionRatchet(RatchetSecret::initial_ratchet_secret(
294 handshake_ratchet_secret,
295 )),
296 SenderRatchet::EncryptionRatchet(RatchetSecret::initial_ratchet_secret(
297 application_ratchet_secret,
298 )),
299 )
300 }
301 #[cfg(feature = "virtual-clients-draft")]
302 {
303 (
304 SenderRatchet::DualUse(DualUseRatchet::new(handshake_ratchet_secret)),
305 SenderRatchet::DualUse(DualUseRatchet::new(application_ratchet_secret)),
306 )
307 }
308 } else {
309 (
310 SenderRatchet::DecryptionRatchet(DecryptionRatchet::new(handshake_ratchet_secret)),
311 SenderRatchet::DecryptionRatchet(DecryptionRatchet::new(
312 application_ratchet_secret,
313 )),
314 )
315 };
316
317 *self
318 .handshake_sender_ratchets
319 .get_mut(index.usize())
320 .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(handshake_sender_ratchet);
321 *self
322 .application_sender_ratchets
323 .get_mut(index.usize())
324 .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(application_sender_ratchet);
325
326 self.set_node(index.into(), None)
328 }
329
330 pub(crate) fn secret_for_decryption(
334 &mut self,
335 ciphersuite: Ciphersuite,
336 crypto: &impl OpenMlsCrypto,
337 index: LeafNodeIndex,
338 secret_type: SecretType,
339 generation: u32,
340 configuration: &SenderRatchetConfiguration,
341 ) -> Result<RatchetKeyMaterial, SecretTreeError> {
342 log::debug!(
343 "Generating {secret_type:?} decryption secret for {index:?} in generation {generation} with {ciphersuite}",
344 );
345 if index.u32() >= self.size.leaf_count() {
347 log::error!("Sender index is not in the tree.");
348 return Err(SecretTreeError::IndexOutOfBounds);
349 }
350 if self.ratchet_opt(index, secret_type)?.is_none() {
351 log::trace!(" initialize sender ratchets");
352 self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
353 }
354 match self.ratchet_mut(index, secret_type)? {
355 SenderRatchet::EncryptionRatchet(_) => {
356 log::error!("This is the wrong ratchet type.");
357 Err(SecretTreeError::RatchetTypeError)
358 }
359 SenderRatchet::DecryptionRatchet(dec_ratchet) => {
360 log::trace!(" getting secret for decryption");
361 dec_ratchet.secret_for_decryption(ciphersuite, crypto, generation, configuration)
362 }
363 #[cfg(feature = "virtual-clients-draft")]
364 SenderRatchet::DualUse(dual_ratchet) => {
365 log::trace!(" getting secret for decryption (own dual-use ratchet)");
366 dual_ratchet.secret_for_decryption(ciphersuite, crypto, generation, configuration)
367 }
368 }
369 }
370
371 pub(crate) fn secret_for_encryption(
374 &mut self,
375 ciphersuite: Ciphersuite,
376 crypto: &impl OpenMlsCrypto,
377 index: LeafNodeIndex,
378 secret_type: SecretType,
379 ) -> Result<(u32, RatchetKeyMaterial), SecretTreeError> {
380 if self.ratchet_opt(index, secret_type)?.is_none() {
381 self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
382 }
383 match self.ratchet_mut(index, secret_type)? {
384 SenderRatchet::DecryptionRatchet(_) => {
385 log::error!("Invalid ratchet type. Got decryption, expected encryption.");
386 Err(SecretTreeError::RatchetTypeError)
387 }
388 SenderRatchet::EncryptionRatchet(enc_ratchet) => {
389 enc_ratchet.ratchet_forward(crypto, ciphersuite)
390 }
391 #[cfg(feature = "virtual-clients-draft")]
392 SenderRatchet::DualUse(dual_ratchet) => {
393 dual_ratchet.secret_for_encryption(ciphersuite, crypto)
394 }
395 }
396 }
397
398 #[cfg(feature = "virtual-clients-draft")]
399 pub(crate) fn delete_own_secret_for_generation(
400 &mut self,
401 secret_type: SecretType,
402 generation: Generation,
403 ) -> Result<(), SecretTreeError> {
404 match self.ratchet_mut(self.own_index, secret_type)? {
405 SenderRatchet::DualUse(dual_ratchet) => {
406 dual_ratchet.delete_secret_for_generation(generation);
407 Ok(())
408 }
409 SenderRatchet::EncryptionRatchet(_) | SenderRatchet::DecryptionRatchet(_) => {
410 Err(SecretTreeError::RatchetTypeError)
411 }
412 }
413 }
414
415 fn ratchet_mut(
418 &mut self,
419 index: LeafNodeIndex,
420 secret_type: SecretType,
421 ) -> Result<&mut SenderRatchet, SecretTreeError> {
422 let sender_ratchets = match secret_type {
423 SecretType::HandshakeSecret => &mut self.handshake_sender_ratchets,
424 SecretType::ApplicationSecret => &mut self.application_sender_ratchets,
425 };
426 sender_ratchets
427 .get_mut(index.usize())
428 .and_then(|r| r.as_mut())
429 .ok_or(SecretTreeError::IndexOutOfBounds)
430 }
431
432 fn ratchet_opt(
434 &self,
435 index: LeafNodeIndex,
436 secret_type: SecretType,
437 ) -> Result<Option<&SenderRatchet>, SecretTreeError> {
438 let sender_ratchets = match secret_type {
439 SecretType::HandshakeSecret => &self.handshake_sender_ratchets,
440 SecretType::ApplicationSecret => &self.application_sender_ratchets,
441 };
442 match sender_ratchets.get(index.usize()) {
443 Some(sender_ratchet_option) => Ok(sender_ratchet_option.as_ref()),
444 None => Err(SecretTreeError::IndexOutOfBounds),
445 }
446 }
447
448 fn derive_down(
451 &mut self,
452 ciphersuite: Ciphersuite,
453 crypto: &impl OpenMlsCrypto,
454 index_in_tree: ParentNodeIndex,
455 ) -> Result<(), SecretTreeError> {
456 log::debug!(
457 "Deriving tree secret for parent node {} with {}",
458 index_in_tree.u32(),
459 ciphersuite
460 );
461 let node_secret = match &self.get_node(index_in_tree.into())? {
462 Some(node) => &node.secret,
463 None => {
465 return Err(SecretTreeError::LibraryError);
466 }
467 };
468 log_crypto!(trace, "Node secret: {:x?}", node_secret.as_slice());
469 let left_index = left(index_in_tree);
470 let right_index = right(index_in_tree);
471 let (left_secret, right_secret) = derive_child_secrets(node_secret, crypto, ciphersuite)?;
472 log_crypto!(
473 trace,
474 "Left node ({}) secret: {:x?}",
475 left_index.test_u32(),
476 left_secret.as_slice()
477 );
478 log_crypto!(
479 trace,
480 "Right node ({}) secret: {:x?}",
481 right_index.test_u32(),
482 right_secret.as_slice()
483 );
484
485 self.set_node(
487 left_index,
488 Some(SecretTreeNode {
489 secret: left_secret,
490 }),
491 )?;
492
493 self.set_node(
495 right_index,
496 Some(SecretTreeNode {
497 secret: right_secret,
498 }),
499 )?;
500
501 self.set_node(index_in_tree.into(), None)
503 }
504
505 fn get_node(&self, index: TreeNodeIndex) -> Result<Option<&SecretTreeNode>, SecretTreeError> {
506 match index {
507 TreeNodeIndex::Leaf(leaf_index) => Ok(self
508 .leaf_nodes
509 .get(leaf_index.usize())
510 .ok_or(SecretTreeError::IndexOutOfBounds)?
511 .as_ref()),
512 TreeNodeIndex::Parent(parent_index) => Ok(self
513 .parent_nodes
514 .get(parent_index.usize())
515 .ok_or(SecretTreeError::IndexOutOfBounds)?
516 .as_ref()),
517 }
518 }
519
520 fn set_node(
521 &mut self,
522 index: TreeNodeIndex,
523 node: Option<SecretTreeNode>,
524 ) -> Result<(), SecretTreeError> {
525 match index {
526 TreeNodeIndex::Leaf(leaf_index) => {
527 *self
528 .leaf_nodes
529 .get_mut(leaf_index.usize())
530 .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
531 }
532 TreeNodeIndex::Parent(parent_index) => {
533 *self
534 .parent_nodes
535 .get_mut(parent_index.usize())
536 .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
537 }
538 }
539 Ok(())
540 }
541}