1use openmls_traits::crypto::OpenMlsCrypto;
2use openmls_traits::types::{Ciphersuite, CryptoError};
3use thiserror::Error;
4use tls_codec::{Error as TlsCodecError, TlsSerialize, TlsSize};
5
6use super::*;
7use crate::{
8 binary_tree::{
9 array_representation::{
10 direct_path, left, right, root, ParentNodeIndex, TreeNodeIndex, TreeSize,
11 },
12 LeafNodeIndex,
13 },
14 framing::*,
15 schedule::*,
16 tree::sender_ratchet::*,
17};
18
19#[derive(Error, Debug, Eq, PartialEq, Clone)]
21pub enum SecretTreeError {
22 #[error("Generation is too old to be processed.")]
24 TooDistantInThePast,
25 #[error("Generation is too far in the future to be processed.")]
27 TooDistantInTheFuture,
28 #[error("Index out of bounds")]
30 IndexOutOfBounds,
31 #[error("The requested secret was deleted to preserve forward secrecy.")]
33 SecretReuseError,
34 #[error("Cannot create decryption secrets from own sender ratchet or encryption secrets from the sender ratchets of other members.")]
36 RatchetTypeError,
37 #[error("Ratchet generation has reached `u32::MAX`.")]
39 RatchetTooLong,
40 #[error("An unrecoverable error has occurred due to a bug in the implementation.")]
42 LibraryError,
43 #[error(transparent)]
45 CodecError(#[from] TlsCodecError),
46 #[error(transparent)]
48 CryptoError(#[from] CryptoError),
49}
50
/// Selects which of a member's two sender ratchets an operation refers to.
#[derive(Debug, Copy, Clone)]
pub(crate) enum SecretType {
    /// The handshake ratchet (selected for `Commit` and `Proposal` content).
    HandshakeSecret,
    /// The application ratchet (selected for `Application` content).
    ApplicationSecret,
}
56
57impl From<&ContentType> for SecretType {
58 fn from(content_type: &ContentType) -> SecretType {
59 match content_type {
60 ContentType::Application => SecretType::ApplicationSecret,
61 ContentType::Commit => SecretType::HandshakeSecret,
62 ContentType::Proposal => SecretType::HandshakeSecret,
63 }
64 }
65}
66
67impl From<&PublicMessage> for SecretType {
68 fn from(public_message: &PublicMessage) -> SecretType {
69 SecretType::from(&public_message.content_type())
70 }
71}
72
73pub(crate) fn derive_child_secrets(
74 parent_secret: &Secret,
75 crypto: &impl OpenMlsCrypto,
76 ciphersuite: Ciphersuite,
77) -> Result<(Secret, Secret), CryptoError> {
78 let left_child = parent_secret.kdf_expand_label(
79 crypto,
80 ciphersuite,
81 "tree",
82 b"left",
83 ciphersuite.hash_length(),
84 )?;
85 let right_child = parent_secret.kdf_expand_label(
86 crypto,
87 ciphersuite,
88 "tree",
89 b"right",
90 ciphersuite.hash_length(),
91 )?;
92 Ok((left_child, right_child))
93}
94
95#[inline]
98pub(crate) fn derive_tree_secret(
99 ciphersuite: Ciphersuite,
100 secret: &Secret,
101 label: &str,
102 generation: u32,
103 length: usize,
104 crypto: &impl OpenMlsCrypto,
105) -> Result<Secret, SecretTreeError> {
106 log::debug!(
107 "Derive tree secret with label \"{label}\" in generation {generation} of length {length}"
108 );
109 log_crypto!(trace, "Input secret {:x?}", secret.as_slice());
110
111 let secret = secret.kdf_expand_label(
112 crypto,
113 ciphersuite,
114 label,
115 &generation.to_be_bytes(),
116 length,
117 )?;
118 log_crypto!(trace, "Derived secret {:x?}", secret.as_slice());
119 Ok(secret)
120}
121
122#[derive(Debug, TlsSerialize, TlsSize)]
123pub(crate) struct TreeContext {
124 pub(crate) node: u32,
125 pub(crate) generation: u32,
126}
127
128#[derive(Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
129#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
130pub(crate) struct SecretTreeNode {
131 pub(crate) secret: Secret,
132}
133
134#[derive(Serialize, Deserialize)]
135#[cfg_attr(any(feature = "test-utils", test), derive(PartialEq, Clone))]
136#[cfg_attr(any(feature = "crypto-debug", test), derive(Debug))]
137pub(crate) struct SecretTree {
138 own_index: LeafNodeIndex,
139 leaf_nodes: Vec<Option<SecretTreeNode>>,
140 parent_nodes: Vec<Option<SecretTreeNode>>,
141 handshake_sender_ratchets: Vec<Option<SenderRatchet>>,
142 application_sender_ratchets: Vec<Option<SenderRatchet>>,
143 size: TreeSize,
144}
145
146impl SecretTree {
147 pub(crate) fn new(
152 encryption_secret: EncryptionSecret,
153 size: TreeSize,
154 own_index: LeafNodeIndex,
155 ) -> Self {
156 let leaf_count = size.leaf_count() as usize;
157 let leaf_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
158 let parent_nodes = std::iter::repeat_with(|| None).take(leaf_count).collect();
159 let handshake_sender_ratchets = std::iter::repeat_with(|| None).take(leaf_count).collect();
160 let application_sender_ratchets =
161 std::iter::repeat_with(|| None).take(leaf_count).collect();
162
163 let mut secret_tree = SecretTree {
164 own_index,
165 leaf_nodes,
166 parent_nodes,
167 handshake_sender_ratchets,
168 application_sender_ratchets,
169 size,
170 };
171
172 let _ = secret_tree.set_node(
176 root(size),
177 Some(SecretTreeNode {
178 secret: encryption_secret.consume_secret(),
179 }),
180 );
181
182 secret_tree
183 }
184
/// Test helper: returns the current generation of the ratchet of the given
/// type for leaf `index`, or `0` if that ratchet has not been initialized.
///
/// # Panics
///
/// Panics if `index` is outside the ratchet vector.
#[cfg(test)]
pub(crate) fn generation(&self, index: LeafNodeIndex, secret_type: SecretType) -> u32 {
    self.ratchet_opt(index, secret_type)
        .expect("Index out of bounds.")
        .map_or(0, |sender_ratchet| sender_ratchet.generation())
}
196
197 fn initialize_sender_ratchets(
200 &mut self,
201 ciphersuite: Ciphersuite,
202 crypto: &impl OpenMlsCrypto,
203 index: LeafNodeIndex,
204 ) -> Result<(), SecretTreeError> {
205 log::trace!("Initializing sender ratchets for {index:?} with {ciphersuite}");
206 if index.u32() >= self.size.leaf_count() {
207 log::error!("Index is larger than the tree size.");
208 return Err(SecretTreeError::IndexOutOfBounds);
209 }
210 if self
212 .ratchet_opt(index, SecretType::HandshakeSecret)?
213 .is_some()
214 && self
215 .ratchet_opt(index, SecretType::ApplicationSecret)?
216 .is_some()
217 {
218 log::trace!("The sender ratchets are initialized already.");
219 return Ok(());
220 }
221
222 if self.get_node(index.into())?.is_none() {
224 let mut empty_nodes: Vec<ParentNodeIndex> = Vec::new();
227 let direct_path = direct_path(index, self.size);
228 log::trace!("Direct path for node {index:?}: {direct_path:?}");
229 for parent_node in direct_path {
230 empty_nodes.push(parent_node);
231 if self.get_node(parent_node.into())?.is_some() {
233 break;
234 }
235 }
236
237 empty_nodes.reverse();
239
240 for n in empty_nodes {
242 log::trace!("Derive down for parent node {n:?}.");
243 self.derive_down(ciphersuite, crypto, n)?;
244 }
245 }
246
247 let node_secret = match self.get_node(index.into())? {
249 Some(node) => &node.secret,
250 None => {
252 return Err(SecretTreeError::LibraryError);
253 }
254 };
255
256 log::trace!("Deriving leaf node secrets for leaf {index:?}");
257
258 let handshake_ratchet_secret = node_secret.kdf_expand_label(
259 crypto,
260 ciphersuite,
261 "handshake",
262 b"",
263 ciphersuite.hash_length(),
264 )?;
265 let application_ratchet_secret = node_secret.kdf_expand_label(
266 crypto,
267 ciphersuite,
268 "application",
269 b"",
270 ciphersuite.hash_length(),
271 )?;
272
273 log_crypto!(
274 trace,
275 "handshake ratchet secret {handshake_ratchet_secret:x?}"
276 );
277 log_crypto!(
278 trace,
279 "application ratchet secret {application_ratchet_secret:x?}"
280 );
281
282 let (handshake_sender_ratchet, application_sender_ratchet) = if index == self.own_index {
285 let handshake_sender_ratchet = SenderRatchet::EncryptionRatchet(
286 RatchetSecret::initial_ratchet_secret(handshake_ratchet_secret),
287 );
288 let application_sender_ratchet = SenderRatchet::EncryptionRatchet(
289 RatchetSecret::initial_ratchet_secret(application_ratchet_secret),
290 );
291
292 (handshake_sender_ratchet, application_sender_ratchet)
293 } else {
294 let handshake_sender_ratchet =
295 SenderRatchet::DecryptionRatchet(DecryptionRatchet::new(handshake_ratchet_secret));
296 let application_sender_ratchet = SenderRatchet::DecryptionRatchet(
297 DecryptionRatchet::new(application_ratchet_secret),
298 );
299
300 (handshake_sender_ratchet, application_sender_ratchet)
301 };
302
303 *self
304 .handshake_sender_ratchets
305 .get_mut(index.usize())
306 .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(handshake_sender_ratchet);
307 *self
308 .application_sender_ratchets
309 .get_mut(index.usize())
310 .ok_or(SecretTreeError::IndexOutOfBounds)? = Some(application_sender_ratchet);
311
312 self.set_node(index.into(), None)
314 }
315
316 pub(crate) fn secret_for_decryption(
320 &mut self,
321 ciphersuite: Ciphersuite,
322 crypto: &impl OpenMlsCrypto,
323 index: LeafNodeIndex,
324 secret_type: SecretType,
325 generation: u32,
326 configuration: &SenderRatchetConfiguration,
327 ) -> Result<RatchetKeyMaterial, SecretTreeError> {
328 log::debug!(
329 "Generating {secret_type:?} decryption secret for {index:?} in generation {generation} with {ciphersuite}",
330 );
331 if index.u32() >= self.size.leaf_count() {
333 log::error!("Sender index is not in the tree.");
334 return Err(SecretTreeError::IndexOutOfBounds);
335 }
336 if self.ratchet_opt(index, secret_type)?.is_none() {
337 log::trace!(" initialize sender ratchets");
338 self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
339 }
340 match self.ratchet_mut(index, secret_type)? {
341 SenderRatchet::EncryptionRatchet(_) => {
342 log::error!("This is the wrong ratchet type.");
343 Err(SecretTreeError::RatchetTypeError)
344 }
345 SenderRatchet::DecryptionRatchet(dec_ratchet) => {
346 log::trace!(" getting secret for decryption");
347 dec_ratchet.secret_for_decryption(ciphersuite, crypto, generation, configuration)
348 }
349 }
350 }
351
352 pub(crate) fn secret_for_encryption(
355 &mut self,
356 ciphersuite: Ciphersuite,
357 crypto: &impl OpenMlsCrypto,
358 index: LeafNodeIndex,
359 secret_type: SecretType,
360 ) -> Result<(u32, RatchetKeyMaterial), SecretTreeError> {
361 if self.ratchet_opt(index, secret_type)?.is_none() {
362 self.initialize_sender_ratchets(ciphersuite, crypto, index)?;
363 }
364 match self.ratchet_mut(index, secret_type)? {
365 SenderRatchet::DecryptionRatchet(_) => {
366 log::error!("Invalid ratchet type. Got decryption, expected encryption.");
367 Err(SecretTreeError::RatchetTypeError)
368 }
369 SenderRatchet::EncryptionRatchet(enc_ratchet) => {
370 enc_ratchet.ratchet_forward(crypto, ciphersuite)
371 }
372 }
373 }
374
375 fn ratchet_mut(
378 &mut self,
379 index: LeafNodeIndex,
380 secret_type: SecretType,
381 ) -> Result<&mut SenderRatchet, SecretTreeError> {
382 let sender_ratchets = match secret_type {
383 SecretType::HandshakeSecret => &mut self.handshake_sender_ratchets,
384 SecretType::ApplicationSecret => &mut self.application_sender_ratchets,
385 };
386 sender_ratchets
387 .get_mut(index.usize())
388 .and_then(|r| r.as_mut())
389 .ok_or(SecretTreeError::IndexOutOfBounds)
390 }
391
392 fn ratchet_opt(
394 &self,
395 index: LeafNodeIndex,
396 secret_type: SecretType,
397 ) -> Result<Option<&SenderRatchet>, SecretTreeError> {
398 let sender_ratchets = match secret_type {
399 SecretType::HandshakeSecret => &self.handshake_sender_ratchets,
400 SecretType::ApplicationSecret => &self.application_sender_ratchets,
401 };
402 match sender_ratchets.get(index.usize()) {
403 Some(sender_ratchet_option) => Ok(sender_ratchet_option.as_ref()),
404 None => Err(SecretTreeError::IndexOutOfBounds),
405 }
406 }
407
408 fn derive_down(
411 &mut self,
412 ciphersuite: Ciphersuite,
413 crypto: &impl OpenMlsCrypto,
414 index_in_tree: ParentNodeIndex,
415 ) -> Result<(), SecretTreeError> {
416 log::debug!(
417 "Deriving tree secret for parent node {} with {}",
418 index_in_tree.u32(),
419 ciphersuite
420 );
421 let node_secret = match &self.get_node(index_in_tree.into())? {
422 Some(node) => &node.secret,
423 None => {
425 return Err(SecretTreeError::LibraryError);
426 }
427 };
428 log_crypto!(trace, "Node secret: {:x?}", node_secret.as_slice());
429 let left_index = left(index_in_tree);
430 let right_index = right(index_in_tree);
431 let (left_secret, right_secret) = derive_child_secrets(node_secret, crypto, ciphersuite)?;
432 log_crypto!(
433 trace,
434 "Left node ({}) secret: {:x?}",
435 left_index.test_u32(),
436 left_secret.as_slice()
437 );
438 log_crypto!(
439 trace,
440 "Right node ({}) secret: {:x?}",
441 right_index.test_u32(),
442 right_secret.as_slice()
443 );
444
445 self.set_node(
447 left_index,
448 Some(SecretTreeNode {
449 secret: left_secret,
450 }),
451 )?;
452
453 self.set_node(
455 right_index,
456 Some(SecretTreeNode {
457 secret: right_secret,
458 }),
459 )?;
460
461 self.set_node(index_in_tree.into(), None)
463 }
464
465 fn get_node(&self, index: TreeNodeIndex) -> Result<Option<&SecretTreeNode>, SecretTreeError> {
466 match index {
467 TreeNodeIndex::Leaf(leaf_index) => Ok(self
468 .leaf_nodes
469 .get(leaf_index.usize())
470 .ok_or(SecretTreeError::IndexOutOfBounds)?
471 .as_ref()),
472 TreeNodeIndex::Parent(parent_index) => Ok(self
473 .parent_nodes
474 .get(parent_index.usize())
475 .ok_or(SecretTreeError::IndexOutOfBounds)?
476 .as_ref()),
477 }
478 }
479
480 fn set_node(
481 &mut self,
482 index: TreeNodeIndex,
483 node: Option<SecretTreeNode>,
484 ) -> Result<(), SecretTreeError> {
485 match index {
486 TreeNodeIndex::Leaf(leaf_index) => {
487 *self
488 .leaf_nodes
489 .get_mut(leaf_index.usize())
490 .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
491 }
492 TreeNodeIndex::Parent(parent_index) => {
493 *self
494 .parent_nodes
495 .get_mut(parent_index.usize())
496 .ok_or(SecretTreeError::IndexOutOfBounds)? = node;
497 }
498 }
499 Ok(())
500 }
501}