1#[cfg(any(feature = "test-utils", test))]
22use std::fmt;
23
24use openmls_traits::{
25 crypto::OpenMlsCrypto,
26 signatures::Signer,
27 types::{Ciphersuite, CryptoError},
28};
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
32
33use self::{
34 diff::{StagedTreeSyncDiff, TreeSyncDiff},
35 node::{
36 leaf_node::{
37 Capabilities, NewLeafNodeParams, TreeInfoTbs, TreePosition, VerifiableLeafNode,
38 },
39 NodeIn,
40 },
41 treesync_node::{TreeSyncLeafNode, TreeSyncNode, TreeSyncParentNode},
42};
43use crate::binary_tree::array_representation::ParentNodeIndex;
44#[cfg(any(feature = "test-utils", test))]
45use crate::{binary_tree::array_representation::level, test_utils::bytes_to_hex};
46use crate::{
47 binary_tree::{
48 array_representation::{is_node_in_tree, LeafNodeIndex, TreeSize},
49 MlsBinaryTree, MlsBinaryTreeError,
50 },
51 ciphersuite::{signable::Verifiable, Secret},
52 credentials::CredentialWithKey,
53 error::LibraryError,
54 extensions::Extensions,
55 group::{GroupId, Member},
56 key_packages::Lifetime,
57 messages::{PathSecret, PathSecretError},
58 schedule::CommitSecret,
59 storage::OpenMlsProvider,
60};
61
62mod hashes;
64use errors::*;
65
66pub(crate) mod diff;
68pub(crate) mod node;
69pub(crate) mod treekem;
70pub(crate) mod treesync_node;
71
72use node::encryption_keys::EncryptionKeyPair;
73
74pub mod errors;
76#[cfg(feature = "test-utils")]
77pub use node::encryption_keys::test_utils;
78pub use node::encryption_keys::EncryptionKey;
79
80pub use node::{
82 leaf_node::{
83 LeafNode, LeafNodeParameters, LeafNodeParametersBuilder, LeafNodeSource,
84 LeafNodeUpdateError,
85 },
86 parent_node::ParentNode,
87 Node,
88};
89
90#[cfg(any(feature = "test-utils", test))]
92pub mod tests_and_kats;
93
/// A verified ratchet tree, stored as a flat vector of optional nodes.
///
/// A `None` entry is a blank node. Instances built through
/// `RatchetTree::try_from_nodes` hold leaf nodes at even indices and parent
/// nodes at odd indices, and never end in a blank node.
///
/// The type can be TLS-serialized but not TLS-deserialized; incoming trees are
/// received as [`RatchetTreeIn`] and converted with
/// [`RatchetTreeIn::into_verified`].
#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct RatchetTree(Vec<Option<Node>>);
97
/// Errors returned while turning a vector of unverified nodes into a
/// [`RatchetTree`].
#[derive(Error, Debug, PartialEq, Clone)]
pub enum RatchetTreeError {
    /// The node vector was empty.
    #[error("The ratchet tree has no nodes.")]
    MissingNodes,
    /// The last entry of the node vector was blank (`None`).
    #[error("The ratchet tree has trailing blank nodes.")]
    TrailingBlankNodes,
    /// The signature of a leaf node could not be verified.
    #[error("Invalid node signature.")]
    InvalidNodeSignature,
    /// A leaf node was found at an odd index or a parent node at an even index.
    #[error("Wrong node type.")]
    WrongNodeType,
}
114
115impl RatchetTree {
116 fn trimmed(mut nodes: Vec<Option<Node>>) -> Self {
120 match nodes.iter().enumerate().rfind(|(_, node)| node.is_some()) {
122 Some((rightmost_nonempty_position, _)) => {
123 nodes.resize(rightmost_nonempty_position + 1, None);
125 }
126 None => {
127 nodes.clear();
129 }
130 }
131
132 debug_assert!(!nodes.is_empty(), "Caller should have ensured that `RatchetTree::trimmed` is not called with a vector that is empty after removing all trailing blank nodes.");
133 Self(nodes)
134 }
135
136 pub(crate) fn try_from_nodes(
138 ciphersuite: Ciphersuite,
139 crypto: &impl OpenMlsCrypto,
140 nodes: Vec<Option<NodeIn>>,
141 group_id: &GroupId,
142 ) -> Result<Self, RatchetTreeError> {
143 match nodes.last() {
147 Some(None) => {
148 Err(RatchetTreeError::TrailingBlankNodes)
150 }
151 None => {
152 Err(RatchetTreeError::MissingNodes)
154 }
155 Some(Some(_)) => {
156 let mut verified_nodes = Vec::new();
161 for (index, node) in nodes.into_iter().enumerate() {
162 let verified_node = match (index % 2, node) {
163 (0, Some(NodeIn::LeafNode(leaf_node))) => {
165 let tree_position = TreePosition::new(
166 group_id.clone(),
167 LeafNodeIndex::new((index / 2) as u32),
168 );
169 let verifiable_leaf_node = leaf_node.into_verifiable_leaf_node();
170 let signature_key = verifiable_leaf_node
171 .signature_key()
172 .clone()
173 .into_signature_public_key_enriched(
174 ciphersuite.signature_algorithm(),
175 );
176 Some(Node::leaf_node(match verifiable_leaf_node {
177 VerifiableLeafNode::KeyPackage(leaf_node) => leaf_node
178 .verify(crypto, &signature_key)
179 .map_err(|_| RatchetTreeError::InvalidNodeSignature)?,
180 VerifiableLeafNode::Update(mut leaf_node) => {
181 leaf_node.add_tree_position(tree_position);
182 leaf_node
183 .verify(crypto, &signature_key)
184 .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
185 }
186 VerifiableLeafNode::Commit(mut leaf_node) => {
187 leaf_node.add_tree_position(tree_position);
188 leaf_node
189 .verify(crypto, &signature_key)
190 .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
191 }
192 }))
193 }
194 (1, Some(NodeIn::ParentNode(parent_node))) => {
196 Some(Node::ParentNode(parent_node))
197 }
198 (_, None) => None,
200 _ => {
202 return Err(RatchetTreeError::WrongNodeType);
203 }
204 };
205 verified_nodes.push(verified_node);
206 }
207 Ok(Self::trimmed(verified_nodes))
208 }
209 }
210 }
211}
212
/// An unverified ratchet tree as it is received, e.g. after TLS
/// deserialization.
///
/// It holds the same flat layout as [`RatchetTree`] but with unverified
/// [`NodeIn`] entries. Use [`RatchetTreeIn::into_verified`] to obtain a
/// [`RatchetTree`].
#[derive(
    PartialEq,
    Eq,
    Clone,
    Debug,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct RatchetTreeIn(Vec<Option<NodeIn>>);
228
229impl RatchetTreeIn {
230 pub fn into_verified(
233 self,
234 ciphersuite: Ciphersuite,
235 crypto: &impl OpenMlsCrypto,
236 group_id: &GroupId,
237 ) -> Result<RatchetTree, RatchetTreeError> {
238 RatchetTree::try_from_nodes(ciphersuite, crypto, self.0, group_id)
239 }
240
241 fn from_ratchet_tree(ratchet_tree: RatchetTree) -> Self {
242 let nodes = ratchet_tree
243 .0
244 .into_iter()
245 .map(|node| node.map(NodeIn::from))
246 .collect();
247 Self(nodes)
248 }
249
250 #[cfg(test)]
251 pub(crate) fn from_nodes(nodes: Vec<Option<NodeIn>>) -> Self {
252 Self(nodes)
253 }
254}
255
256impl From<RatchetTree> for RatchetTreeIn {
257 fn from(ratchet_tree: RatchetTree) -> Self {
258 RatchetTreeIn::from_ratchet_tree(ratchet_tree)
259 }
260}
261
/// Test-only conversion that turns an unverified tree into a [`RatchetTree`]
/// node by node.
///
/// None of the checks of `RatchetTree::try_from_nodes` are run and trailing
/// blank nodes are kept as they are.
#[cfg(any(feature = "test-utils", test))]
impl From<RatchetTreeIn> for RatchetTree {
    fn from(ratchet_tree_in: RatchetTreeIn) -> Self {
        let RatchetTreeIn(unverified_nodes) = ratchet_tree_in;
        let converted_nodes = unverified_nodes
            .into_iter()
            .map(|slot| slot.map(Node::from))
            .collect();
        Self(converted_nodes)
    }
}
276
/// Returns the position of the highest set bit of `x`, i.e. the binary
/// logarithm of `x` rounded down. Returns `0` for `x == 0`.
#[cfg(any(feature = "test-utils", test))]
fn log2(x: u32) -> usize {
    // `checked_ilog2` yields `None` only for zero, which maps to 0 here.
    x.checked_ilog2().map_or(0, |exponent| exponent as usize)
}
284
/// Returns `2^log2(size) - 1`.
///
/// The `Display` implementation of [`RatchetTree`] passes the number of nodes
/// as `size` and marks the node at the returned index with `(*)`.
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn root(size: u32) -> u32 {
    let exponent = log2(size);
    let power_of_two: u32 = 1 << exponent;
    power_of_two - 1
}
289
/// Test-only textual dump of the tree, one line per node.
///
/// Each line starts with the zero-padded node index, followed by a marker
/// (`L` leaf, `P` parent, `_` blank, with `(*)` on the index returned by
/// [`root`]), the public key and parent hash as hex, and finally a glyph
/// indented according to the node's level.
#[cfg(any(feature = "test-utils", test))]
impl fmt::Display for RatchetTree {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Number of spaces the glyph is shifted per tree level.
        let indent_per_level = 3;
        // The root index only depends on the tree size, so compute it once.
        let root_index = root(self.0.len() as u32);

        for (position, slot) in self.0.iter().enumerate() {
            let depth = level(position as u32);
            let is_root = root_index == position as u32;

            write!(f, "{position:04}")?;

            match slot {
                Some(node) => {
                    let (public_key, parent_hash_hex) = match node {
                        Node::LeafNode(leaf) => {
                            write!(f, "\tL ")?;
                            // Leaves without a parent hash yield an empty string.
                            let parent_hash_hex =
                                leaf.parent_hash().map(bytes_to_hex).unwrap_or_default();
                            (leaf.encryption_key().as_slice(), parent_hash_hex)
                        }
                        Node::ParentNode(parent) => {
                            if is_root {
                                write!(f, "\tP (*) ")?;
                            } else {
                                write!(f, "\tP ")?;
                            }
                            (
                                parent.public_key().as_slice(),
                                bytes_to_hex(parent.parent_hash()),
                            )
                        }
                    };

                    // An empty parent hash is replaced by a run of spaces.
                    let parent_hash_column = if parent_hash_hex.is_empty() {
                        " ".repeat(32)
                    } else {
                        parent_hash_hex
                    };
                    write!(
                        f,
                        "PK: {} PH: {} | ",
                        bytes_to_hex(public_key),
                        parent_hash_column
                    )?;

                    write!(f, "{}◼︎", " ".repeat(depth * indent_per_level))?;
                }
                None => {
                    // Blank nodes show underscores in both hex columns.
                    let blank_column = "__".repeat(32);
                    if is_root {
                        write!(
                            f,
                            "\t_ (*) PK: {} PH: {} | ",
                            blank_column, blank_column
                        )?;
                    } else {
                        write!(f, "\t_ PK: {} PH: {} | ", blank_column, blank_column)?;
                    }

                    write!(f, "{}❑", " ".repeat(depth * indent_per_level))?;
                }
            }

            writeln!(f)?;
        }

        Ok(())
    }
}
359
/// The tree state of a group: a binary tree of leaf and parent nodes together
/// with the hash of that tree.
///
/// `PartialEq` and `Clone` are only derived for tests and the `test-utils`
/// feature.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq, Clone))]
pub(crate) struct TreeSync {
    // The binary tree holding the leaf and parent nodes.
    tree: MlsBinaryTree<TreeSyncLeafNode, TreeSyncParentNode>,
    // Tree hash bytes; replaced whenever a staged diff is merged.
    tree_hash: Vec<u8>,
}
378
379impl TreeSync {
380 pub(crate) fn new(
385 provider: &impl OpenMlsProvider,
386 signer: &impl Signer,
387 ciphersuite: Ciphersuite,
388 credential_with_key: CredentialWithKey,
389 life_time: Lifetime,
390 capabilities: Capabilities,
391 extensions: Extensions,
392 ) -> Result<(Self, CommitSecret, EncryptionKeyPair), LibraryError> {
393 let new_leaf_node_params = NewLeafNodeParams {
394 ciphersuite,
395 credential_with_key,
396 leaf_node_source: LeafNodeSource::KeyPackage(life_time),
398 capabilities,
399 extensions,
400 tree_info_tbs: TreeInfoTbs::KeyPackage,
401 };
402 let (leaf, encryption_key_pair) = LeafNode::new(provider, signer, new_leaf_node_params)?;
403
404 let node = Node::leaf_node(leaf);
405 let path_secret: PathSecret = Secret::random(ciphersuite, provider.rand())
406 .map_err(LibraryError::unexpected_crypto_error)?
407 .into();
408 let commit_secret: CommitSecret = path_secret
409 .derive_path_secret(provider.crypto(), ciphersuite)?
410 .into();
411 let nodes = vec![TreeSyncNode::from(node).into()];
412 let tree = MlsBinaryTree::new(nodes)
413 .map_err(|_| LibraryError::custom("Unexpected error creating the binary tree."))?;
414 let mut tree_sync = Self {
415 tree,
416 tree_hash: vec![],
417 };
418 tree_sync.populate_parent_hashes(provider.crypto(), ciphersuite)?;
420
421 Ok((tree_sync, commit_secret, encryption_key_pair))
422 }
423
424 pub(crate) fn tree_hash(&self) -> &[u8] {
426 self.tree_hash.as_slice()
427 }
428
429 pub(crate) fn merge_diff(&mut self, tree_sync_diff: StagedTreeSyncDiff) {
432 let (diff, new_tree_hash) = tree_sync_diff.into_parts();
433 self.tree_hash = new_tree_hash;
434 self.tree.merge_diff(diff);
435 }
436
437 pub(crate) fn empty_diff(&self) -> TreeSyncDiff<'_> {
440 self.into()
441 }
442
443 pub(crate) fn from_ratchet_tree(
447 crypto: &impl OpenMlsCrypto,
448 ciphersuite: Ciphersuite,
449 ratchet_tree: RatchetTree,
450 ) -> Result<Self, TreeSyncFromNodesError> {
451 let total_nodes = ratchet_tree.0.len();
453 let mut leaf_nodes = Vec::with_capacity(total_nodes.div_ceil(2));
454 let mut parent_nodes = Vec::with_capacity(total_nodes / 2);
455
456 for (node_index, node_option) in ratchet_tree.0.into_iter().enumerate() {
458 if node_index % 2 == 0 {
459 let leaf = match node_option {
460 Some(node) => match TreeSyncNode::from(node) {
461 TreeSyncNode::Leaf(l) => *l,
462 TreeSyncNode::Parent(_) => {
463 return Err(TreeSyncFromNodesError::from(
464 PublicTreeError::MalformedTree,
465 ))
466 }
467 },
468 None => TreeSyncLeafNode::blank(),
469 };
470 leaf_nodes.push(leaf);
471 } else {
472 let parent = match node_option {
473 Some(node) => match TreeSyncNode::from(node) {
474 TreeSyncNode::Parent(p) => *p,
475 TreeSyncNode::Leaf(_) => {
476 return Err(TreeSyncFromNodesError::from(
477 PublicTreeError::MalformedTree,
478 ))
479 }
480 },
481 None => TreeSyncParentNode::blank(),
482 };
483 parent_nodes.push(parent);
484 }
485 }
486
487 let tree = MlsBinaryTree::from_components(leaf_nodes, parent_nodes)
488 .map_err(|_| PublicTreeError::MalformedTree)?;
489 let mut tree_sync = Self {
490 tree,
491 tree_hash: vec![],
492 };
493
494 tree_sync
496 .verify_parent_hashes(crypto, ciphersuite)
497 .map_err(|e| match e {
498 TreeSyncParentHashError::LibraryError(e) => e.into(),
499 TreeSyncParentHashError::InvalidParentHash => {
500 TreeSyncFromNodesError::from(PublicTreeError::InvalidParentHash)
501 }
502 })?;
503
504 tree_sync.populate_parent_hashes(crypto, ciphersuite)?;
506 Ok(tree_sync)
507 }
508
509 pub(crate) fn free_leaf_index(&self) -> LeafNodeIndex {
514 let diff = self.empty_diff();
515 diff.free_leaf_index()
516 }
517
518 fn populate_parent_hashes(
520 &mut self,
521 crypto: &impl OpenMlsCrypto,
522 ciphersuite: Ciphersuite,
523 ) -> Result<(), LibraryError> {
524 let diff = self.empty_diff();
525 let staged_diff = diff.into_staged_diff(crypto, ciphersuite)?;
528 self.merge_diff(staged_diff);
530 Ok(())
531 }
532
533 fn verify_parent_hashes(
538 &self,
539 crypto: &impl OpenMlsCrypto,
540 ciphersuite: Ciphersuite,
541 ) -> Result<(), TreeSyncParentHashError> {
542 let diff = self.empty_diff();
556 diff.verify_parent_hashes(crypto, ciphersuite)
558 }
559
560 pub(crate) fn tree_size(&self) -> TreeSize {
562 self.tree.tree_size()
563 }
564
565 pub(crate) fn full_leaves(&self) -> impl Iterator<Item = &LeafNode> {
567 self.tree
568 .leaves()
569 .filter_map(|(_, tsn)| tsn.node().as_ref())
570 }
571
572 pub(crate) fn full_parents(&self) -> impl Iterator<Item = (ParentNodeIndex, &ParentNode)> {
574 self.tree
575 .parents()
576 .filter_map(|(index, tsn)| tsn.node().as_ref().map(|pn| (index, pn)))
577 }
578
579 fn rightmost_full_leaf(&self) -> LeafNodeIndex {
581 let mut index = LeafNodeIndex::new(0);
582 for (leaf_index, leaf) in self.tree.leaves() {
583 if leaf.node().as_ref().is_some() {
584 index = leaf_index;
585 }
586 }
587 index
588 }
589
590 pub(crate) fn full_leave_members(&self) -> impl Iterator<Item = Member> + '_ {
595 self.tree
596 .leaves()
597 .filter_map(|(index, tsn)| tsn.node().as_ref().map(|node| (index, node)))
599 .map(|(index, leaf_node)| {
601 Member::new(
602 index,
603 leaf_node.encryption_key().as_slice().to_vec(),
604 leaf_node.signature_key().as_slice().to_vec(),
605 leaf_node.credential().clone(),
606 )
607 })
608 }
609
610 pub fn export_ratchet_tree(&self) -> RatchetTree {
613 let mut nodes = Vec::new();
614
615 let max_length = self.rightmost_full_leaf();
617
618 let mut leaves = self
621 .tree
622 .leaves()
623 .map(|(_, leaf)| leaf)
624 .take(max_length.usize() + 1);
625
626 if let Some(leaf) = leaves.next() {
628 nodes.push(leaf.node().clone().map(Node::leaf_node));
629 } else {
630 return RatchetTree::trimmed(vec![]);
632 }
633
634 let default_parent = TreeSyncParentNode::default();
636
637 let parents = self
639 .tree
640 .parents()
641 .map(|(_, parent)| parent)
643 .take(max_length.usize())
645 .chain(
647 (self.tree.parents().count()..self.tree.leaves().count() - 1)
648 .map(|_| &default_parent),
649 );
650
651 for (leaf, parent) in leaves.zip(parents) {
653 nodes.push(parent.node().clone().map(Node::parent_node));
654 nodes.push(leaf.node().clone().map(Node::leaf_node));
655 }
656
657 RatchetTree::trimmed(nodes)
658 }
659
660 pub(crate) fn leaf(&self, leaf_index: LeafNodeIndex) -> Option<&LeafNode> {
663 let tsn = self.tree.leaf(leaf_index);
664 tsn.node().as_ref()
665 }
666
667 pub(crate) fn is_leaf_in_tree(&self, leaf_index: LeafNodeIndex) -> bool {
670 is_node_in_tree(leaf_index.into(), self.tree.tree_size())
671 }
672
673 pub(crate) fn owned_encryption_keys(&self, leaf_index: LeafNodeIndex) -> Vec<EncryptionKey> {
676 self.empty_diff()
677 .encryption_keys(leaf_index)
678 .cloned()
679 .collect::<Vec<EncryptionKey>>()
680 }
681
682 pub(crate) fn derive_path_secrets(
696 &self,
697 crypto: &impl OpenMlsCrypto,
698 ciphersuite: Ciphersuite,
699 mut path_secret: PathSecret,
700 sender_index: LeafNodeIndex,
701 leaf_index: LeafNodeIndex,
702 ) -> Result<(Vec<EncryptionKeyPair>, CommitSecret), DerivePathError> {
703 let subtree_path = self.tree.subtree_path(leaf_index, sender_index);
706 let mut keypairs = Vec::new();
707 for parent_index in subtree_path {
708 let tsn = self.tree.parent_by_index(parent_index);
710 if let Some(ref parent_node) = tsn.node() {
712 if !parent_node.unmerged_leaves().contains(&leaf_index) {
715 let keypair = path_secret.derive_key_pair(crypto, ciphersuite)?;
716 if parent_node.encryption_key() != keypair.public_key() {
719 return Err(DerivePathError::PublicKeyMismatch);
720 } else {
721 keypairs.push(keypair);
724 path_secret = path_secret.derive_path_secret(crypto, ciphersuite)?;
725 }
726 };
727 }
730 }
731 Ok((keypairs, path_secret.into()))
732 }
733
734 pub(crate) fn parent(&self, node_index: ParentNodeIndex) -> Option<&ParentNode> {
737 let tsn = self.tree.parent(node_index);
738 tsn.node().as_ref()
739 }
740}
741
#[cfg(test)]
impl TreeSync {
    /// Test helper: returns the leaf count reported by the underlying binary
    /// tree.
    pub(crate) fn leaf_count(&self) -> u32 {
        let Self { tree, .. } = self;
        tree.leaf_count()
    }
}
748
#[cfg(test)]
mod test {
    use super::*;

    /// Debug builds: trimming an empty vector trips the debug assertion.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_ratchet_tree_internal_empty() {
        RatchetTree::trimmed(Vec::new());
    }

    /// Release builds: the debug assertion is compiled out, so trimming an
    /// empty vector must not panic.
    #[cfg(not(debug_assertions))]
    #[test]
    fn test_ratchet_tree_internal_empty() {
        RatchetTree::trimmed(Vec::new());
    }

    /// Debug builds: a vector that is empty after trimming trips the debug
    /// assertion.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_ratchet_tree_internal_empty_after_trim() {
        RatchetTree::trimmed(vec![None]);
    }

    /// Release builds: a vector that is empty after trimming must not panic.
    #[cfg(not(debug_assertions))]
    #[test]
    fn test_ratchet_tree_internal_empty_after_trim() {
        RatchetTree::trimmed(vec![None]);
    }

    /// `try_from_nodes` accepts a node vector only if it is non-empty and its
    /// last entry is not blank.
    #[openmls_test::openmls_test]
    fn test_ratchet_tree_trailing_blank_nodes() {
        let provider = &Provider::default();
        let (key_package, _, _) = crate::key_packages::tests::key_package(ciphersuite, provider);
        let leaf = NodeIn::from(Node::leaf_node(LeafNode::from(key_package)));

        // Each case pairs an input node vector with whether it must be accepted.
        let cases = [
            (vec![], false),
            (vec![None], false),
            (vec![None, None], false),
            (vec![None, None, None], false),
            (vec![Some(leaf.clone())], true),
            (vec![Some(leaf.clone()), None], false),
            (vec![Some(leaf.clone()), None, Some(leaf.clone())], true),
            (vec![Some(leaf.clone()), None, Some(leaf), None], false),
        ];

        for (nodes, expected) in cases {
            let outcome = RatchetTree::try_from_nodes(
                ciphersuite,
                provider.crypto(),
                nodes,
                &GroupId::random(provider.rand()),
            );
            let got = outcome.is_ok();
            assert_eq!(got, expected);
        }
    }
}