1use openmls_traits::crypto::OpenMlsCrypto;
2use openmls_traits::types::{Ciphersuite, CryptoError};
3use thiserror::Error;
4use tls_codec::{Error as TlsCodecError, TlsSerialize, TlsSize};
5
6use super::*;
7use crate::{
8 binary_tree::{
9 array_representation::{
10 direct_path, left, right, root, ParentNodeIndex, TreeNodeIndex, TreeSize,
11 },
12 LeafNodeIndex,
13 },
14 framing::*,
15 schedule::*,
16 tree::sender_ratchet::*,
17};
18
19#[derive(Error, Debug, Eq, PartialEq, Clone)]
21pub enum SecretTreeError {
22 #[error("Generation is too old to be processed.")]
24 TooDistantInThePast,
25 #[error("Generation is too far in the future to be processed.")]
27 TooDistantInTheFuture,
28 #[error("Index out of bounds")]
30 IndexOutOfBounds,
31 #[error("The requested secret was deleted to preserve forward secrecy.")]
33 SecretReuseError,
34 #[error("Cannot create decryption secrets from own sender ratchet or encryption secrets from the sender ratchets of other members.")]
36 RatchetTypeError,
37 #[error("Ratchet generation has reached `u32::MAX`.")]
39 RatchetTooLong,
40 #[error("An unrecoverable error has occurred due to a bug in the implementation.")]
42 LibraryError,
43 #[error(transparent)]
45 CodecError(#[from] TlsCodecError),
46 #[error(transparent)]
48 CryptoError(#[from] CryptoError),
49}
50
/// Selects which family of sender ratchets a secret is drawn from.
#[derive(Clone, Copy, Debug)]
pub(crate) enum SecretType {
    /// Ratchets used for handshake content (commits and proposals).
    HandshakeSecret,
    /// Ratchets used for application content.
    ApplicationSecret,
}
56
57impl From<&ContentType> for SecretType {
58 fn from(content_type: &ContentType) -> SecretType {
59 match content_type {
60 ContentType::Application => SecretType::ApplicationSecret,
61 ContentType::Commit => SecretType::HandshakeSecret,
62 ContentType::Proposal => SecretType::HandshakeSecret,
63 }
64 }
65}
66
67impl From<&PublicMessage> for SecretType {
68 fn from(public_message: &PublicMessage) -> SecretType {
69 SecretType::from(&public_message.content_type())
70 }
71}
72
73#[inline]
76pub(crate) fn derive_tree_secret(
77 ciphersuite: Ciphersuite,
78 secret: &Secret,
79 label: &str,
80 generation: u32,
81 length: usize,
82 crypto: &impl OpenMlsCrypto,
83) -> Result<Secret, SecretTreeError> {
84 log::debug!(
85 "Derive tree secret with label \"{}\" in generation {} of length {}",
86 label,
87 generation,
88 length
89 );
90 log_crypto!(trace, "Input secret {:x?}", secret.as_slice());
91
92 let secret = secret.kdf_expand_label(
93 crypto,
94 ciphersuite,
95 label,
96 &generation.to_be_bytes(),
97 length,
98 )?;
99 log_crypto!(trace, "Derived secret {:x?}", secret.as_slice());
100 Ok(secret)
101}
102
/// A node index paired with a generation counter, TLS-serializable.
///
/// NOTE: the field order defines the TLS encoding produced by the derive; do not
/// reorder fields.
#[derive(Debug, TlsSerialize, TlsSize)]
pub(crate) struct TreeContext {
    /// Node index.
    pub(crate) node: u32,
    /// Generation counter (presumably a ratchet generation — confirm at call sites).
    pub(crate) generation: u32,
}
108
/// A populated node of the secret tree.
///
/// `SecretTree` stores these as `Option`s, and a slot is reset to `None` once its
/// secret has been used to derive the secrets below it.
#[derive(Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
pub(crate) struct SecretTreeNode {
    /// The secret held by this node.
    pub(crate) secret: Secret,
}
114
/// Lazily-derived tree of secrets from which per-member sender ratchets are built.
///
/// Node secrets are derived top-down on demand and erased once they have been
/// used. Sender ratchets are created per leaf the first time a secret for that
/// leaf is requested.
///
/// NOTE: the type derives serde `Serialize`/`Deserialize`; keep the field order
/// stable so the serialized form does not change.
#[derive(Serialize, Deserialize)]
#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
#[cfg_attr(any(feature = "crypto-debug", test), derive(Debug))]
pub(crate) struct SecretTree {
    // Leaf of the local member; its ratchets are encryption ratchets, all others are
    // decryption ratchets.
    own_index: LeafNodeIndex,
    // Node secrets of the leaves, indexed by leaf index (`None` when absent/erased).
    leaf_nodes: Vec<Option<SecretTreeNode>>,
    // Node secrets of the parent nodes, indexed by parent index (`None` when absent/erased).
    parent_nodes: Vec<Option<SecretTreeNode>>,
    // Per-leaf handshake ratchets, created on first use.
    handshake_sender_ratchets: Vec<Option<SenderRatchet>>,
    // Per-leaf application ratchets, created on first use.
    application_sender_ratchets: Vec<Option<SenderRatchet>>,
    // Size of the tree; bounds-checks leaf indices and locates the root.
    size: TreeSize,
}
126
127impl SecretTree {
128 pub(crate) fn new(
133 encryption_secret: EncryptionSecret,
134 size: TreeSize,
135 own_index: LeafNodeIndex,
136 ) -> Self {
137 let leaf_count = size.leaf_count() as usize;
138 let leaf_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
139 let parent_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
140 let handshake_sender_ratchets = std::iter::repeat_with(|| None).take(leaf_count).collect();
141 let application_sender_ratchets =
142 std::iter::repeat_with(|| None).take(leaf_count).collect();
143
144 let mut secret_tree = SecretTree {
145 own_index,
146 leaf_nodes,
147 parent_nodes,
148 handshake_sender_ratchets,
149 application_sender_ratchets,
150 size,
151 };
152
153 let _ = secret_tree.set_node(
157 root(size),
158 Some(SecretTreeNode {
159 secret: encryption_secret.consume_secret(),
160 }),
161 );
162
163 secret_tree
164 }
165
166 #[cfg(test)]
168 pub(crate) fn generation(&self, index: LeafNodeIndex, secret_type: SecretType) -> u32 {
169 match self
170 .ratchet_opt(index, secret_type)
171 .expect("Index out of bounds.")
172 {
173 Some(sender_ratchet) => sender_ratchet.generation(),
174 None => 0,
175 }
176 }
177
178 fn initialize_sender_ratchets(
181 &mut self,
182 ciphersuite: Ciphersuite,
183 crypto: &impl OpenMlsCrypto,
184 index: LeafNodeIndex,
185 ) -> Result<(), SecretTreeError> {
186 log::trace!("Initializing sender ratchets for {index:?} with {ciphersuite}");
187 if index.u32() >= self.size.leaf_count() {
188 log::error!("Index is larger than the tree size.");
189 return Err(SecretTreeError::IndexOutOfBounds);
190 }
191 if self
193 .ratchet_opt(index, SecretType::HandshakeSecret)?
194 .is_some()
195 && self
196 .ratchet_opt(index, SecretType::ApplicationSecret)?
197 .is_some()
198 {
199 log::trace!("The sender ratchets are initialized already.");
200 return Ok(());
201 }
202
203 if self.get_node(index.into())?.is_none() {
205 let mut empty_nodes: Vec<ParentNodeIndex> = Vec::new();
208 let direct_path = direct_path(index, self.size);
209 log::trace!("Direct path for node {index:?}: {:?}", direct_path);
210 for parent_node in direct_path {
211 empty_nodes.push(parent_node);
212 if self.get_node(parent_node.into())?.is_some() {
214 break;
215 }
216 }
217
218 empty_nodes.reverse();
220
221 for n in empty_nodes {
223 log::trace!("Derive down for parent node {n:?}.");
224 self.derive_down(ciphersuite, crypto, n)?;
225 }
226 }
227
228 let node_secret = match self.get_node(index.into())? {
230 Some(node) => &node.secret,
231 None => {
233 return Err(SecretTreeError::LibraryError);
234 }
235 };
236
237 log::trace!("Deriving leaf node secrets for leaf {index:?}");
238
239 let handshake_ratchet_secret = node_secret.kdf_expand_label(
240 crypto,
241 ciphersuite,
242 "handshake",
243 b"",
244 ciphersuite.hash_length(),
245 )?;
246 let application_ratchet_secret = node_secret.kdf_expand_label(
247 crypto,
248 ciphersuite,
249 "application",
250 b"",
251 ciphersuite.hash_length(),
252 )?;
253
254 log_crypto!(
255 trace,
256 "handshake ratchet secret {handshake_ratchet_secret:x?}"
257 );
258 log_crypto!(
259 trace,
260 "application ratchet secret {application_ratchet_secret:x?}"
261 );
262
263 let (handshake_sender_ratchet, application_sender_ratchet) = if index == self.own_index {
266 let handshake_sender_ratchet = SenderRatchet::EncryptionRatchet(
267 RatchetSecret::initial_ratchet_secret(handshake_ratchet_secret),
268 );
269 let application_sender_ratchet = SenderRatchet::EncryptionRatchet(
270 RatchetSecret::initial_ratchet_secret(application_ratchet_secret),
271 );
272
273 (handshake_sender_ratchet, application_sender_ratchet)
274 } else {
275 let handshake_sender_ratchet =
276 SenderRatchet::DecryptionRatchet(DecryptionRatchet::new(handshake_ratchet_secret));
277 let application_sender_ratchet = SenderRatchet::DecryptionRatchet(
278 DecryptionRatchet::new(application_ratchet_secret),
279 );
280
281 (handshake_sender_ratchet, application_sender_ratchet)
282 };
283
284 *self
285 .handshake_sender_ratchets
286 .get_mut(index.usize())
287 .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(handshake_sender_ratchet);
288 *self
289 .application_sender_ratchets
290 .get_mut(index.usize())
291 .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(application_sender_ratchet);
292
293 self.set_node(index.into(), None)
295 }
296
297 pub(crate) fn secret_for_decryption(
301 &mut self,
302 ciphersuite: Ciphersuite,
303 crypto: &impl OpenMlsCrypto,
304 index: LeafNodeIndex,
305 secret_type: SecretType,
306 generation: u32,
307 configuration: &SenderRatchetConfiguration,
308 ) -> Result<RatchetKeyMaterial, SecretTreeError> {
309 log::debug!(
310 "Generating {:?} decryption secret for {:?} in generation {} with {}",
311 secret_type,
312 index,
313 generation,
314 ciphersuite,
315 );
316 if index.u32() >= self.size.leaf_count() {
318 log::error!("Sender index is not in the tree.");
319 return Err(SecretTreeError::IndexOutOfBounds);
320 }
321 if self.ratchet_opt(index, secret_type)?.is_none() {
322 log::trace!(" initialize sender ratchets");
323 self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
324 }
325 match self.ratchet_mut(index, secret_type)? {
326 SenderRatchet::EncryptionRatchet(_) => {
327 log::error!("This is the wrong ratchet type.");
328 Err(SecretTreeError::RatchetTypeError)
329 }
330 SenderRatchet::DecryptionRatchet(dec_ratchet) => {
331 log::trace!(" getting secret for decryption");
332 dec_ratchet.secret_for_decryption(ciphersuite, crypto, generation, configuration)
333 }
334 }
335 }
336
337 pub(crate) fn secret_for_encryption(
340 &mut self,
341 ciphersuite: Ciphersuite,
342 crypto: &impl OpenMlsCrypto,
343 index: LeafNodeIndex,
344 secret_type: SecretType,
345 ) -> Result<(u32, RatchetKeyMaterial), SecretTreeError> {
346 if self.ratchet_opt(index, secret_type)?.is_none() {
347 self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
348 }
349 match self.ratchet_mut(index, secret_type)? {
350 SenderRatchet::DecryptionRatchet(_) => {
351 log::error!("Invalid ratchet type. Got decryption, expected encryption.");
352 Err(SecretTreeError::RatchetTypeError)
353 }
354 SenderRatchet::EncryptionRatchet(enc_ratchet) => {
355 enc_ratchet.ratchet_forward(crypto, ciphersuite)
356 }
357 }
358 }
359
360 fn ratchet_mut(
363 &mut self,
364 index: LeafNodeIndex,
365 secret_type: SecretType,
366 ) -> Result<&mut SenderRatchet, SecretTreeError> {
367 let sender_ratchets = match secret_type {
368 SecretType::HandshakeSecret => &mut self.handshake_sender_ratchets,
369 SecretType::ApplicationSecret => &mut self.application_sender_ratchets,
370 };
371 sender_ratchets
372 .get_mut(index.usize())
373 .and_then(|r| r.as_mut())
374 .ok_or(SecretTreeError::IndexOutOfBounds)
375 }
376
377 fn ratchet_opt(
379 &self,
380 index: LeafNodeIndex,
381 secret_type: SecretType,
382 ) -> Result<Option<&SenderRatchet>, SecretTreeError> {
383 let sender_ratchets = match secret_type {
384 SecretType::HandshakeSecret => &self.handshake_sender_ratchets,
385 SecretType::ApplicationSecret => &self.application_sender_ratchets,
386 };
387 match sender_ratchets.get(index.usize()) {
388 Some(sender_ratchet_option) => Ok(sender_ratchet_option.as_ref()),
389 None => Err(SecretTreeError::IndexOutOfBounds),
390 }
391 }
392
393 fn derive_down(
396 &mut self,
397 ciphersuite: Ciphersuite,
398 crypto: &impl OpenMlsCrypto,
399 index_in_tree: ParentNodeIndex,
400 ) -> Result<(), SecretTreeError> {
401 log::debug!(
402 "Deriving tree secret for parent node {} with {}",
403 index_in_tree.u32(),
404 ciphersuite
405 );
406 let hash_len = ciphersuite.hash_length();
407 let node_secret = match &self.get_node(index_in_tree.into())? {
408 Some(node) => &node.secret,
409 None => {
411 return Err(SecretTreeError::LibraryError);
412 }
413 };
414 log_crypto!(trace, "Node secret: {:x?}", node_secret.as_slice());
415 let left_index = left(index_in_tree);
416 let right_index = right(index_in_tree);
417 let left_secret =
418 node_secret.kdf_expand_label(crypto, ciphersuite, "tree", b"left", hash_len)?;
419 let right_secret =
420 node_secret.kdf_expand_label(crypto, ciphersuite, "tree", b"right", hash_len)?;
421 log_crypto!(
422 trace,
423 "Left node ({}) secret: {:x?}",
424 left_index.test_u32(),
425 left_secret.as_slice()
426 );
427 log_crypto!(
428 trace,
429 "Right node ({}) secret: {:x?}",
430 right_index.test_u32(),
431 right_secret.as_slice()
432 );
433
434 self.set_node(
436 left_index,
437 Some(SecretTreeNode {
438 secret: left_secret,
439 }),
440 )?;
441
442 self.set_node(
444 right_index,
445 Some(SecretTreeNode {
446 secret: right_secret,
447 }),
448 )?;
449
450 self.set_node(index_in_tree.into(), None)
452 }
453
454 fn get_node(&self, index: TreeNodeIndex) -> Result<Option<&SecretTreeNode>, SecretTreeError> {
455 match index {
456 TreeNodeIndex::Leaf(leaf_index) => Ok(self
457 .leaf_nodes
458 .get(leaf_index.usize())
459 .ok_or(SecretTreeError::IndexOutOfBounds)?
460 .as_ref()),
461 TreeNodeIndex::Parent(parent_index) => Ok(self
462 .parent_nodes
463 .get(parent_index.usize())
464 .ok_or(SecretTreeError::IndexOutOfBounds)?
465 .as_ref()),
466 }
467 }
468
469 fn set_node(
470 &mut self,
471 index: TreeNodeIndex,
472 node: Option<SecretTreeNode>,
473 ) -> Result<(), SecretTreeError> {
474 match index {
475 TreeNodeIndex::Leaf(leaf_index) => {
476 *self
477 .leaf_nodes
478 .get_mut(leaf_index.usize())
479 .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
480 }
481 TreeNodeIndex::Parent(parent_index) => {
482 *self
483 .parent_nodes
484 .get_mut(parent_index.usize())
485 .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
486 }
487 }
488 Ok(())
489 }
490}