1#[cfg(any(feature = "test-utils", test))]
22use std::fmt;
23
24use openmls_traits::{
25 crypto::OpenMlsCrypto,
26 signatures::Signer,
27 types::{Ciphersuite, CryptoError},
28};
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
32
33use self::{
34 diff::{StagedTreeSyncDiff, TreeSyncDiff},
35 node::{
36 leaf_node::{
37 Capabilities, NewLeafNodeParams, TreeInfoTbs, TreePosition, VerifiableLeafNode,
38 },
39 NodeIn,
40 },
41 treesync_node::{TreeSyncLeafNode, TreeSyncNode, TreeSyncParentNode},
42};
43use crate::binary_tree::array_representation::ParentNodeIndex;
44#[cfg(any(feature = "test-utils", test))]
45use crate::{binary_tree::array_representation::level, test_utils::bytes_to_hex};
46use crate::{
47 binary_tree::{
48 array_representation::{is_node_in_tree, tree::TreeNode, LeafNodeIndex, TreeSize},
49 MlsBinaryTree, MlsBinaryTreeError,
50 },
51 ciphersuite::{signable::Verifiable, Secret},
52 credentials::CredentialWithKey,
53 error::LibraryError,
54 extensions::Extensions,
55 group::{GroupId, Member},
56 key_packages::Lifetime,
57 messages::{PathSecret, PathSecretError},
58 schedule::CommitSecret,
59 storage::OpenMlsProvider,
60};
61
62mod hashes;
64use errors::*;
65
66pub(crate) mod diff;
68pub(crate) mod node;
69pub(crate) mod treekem;
70pub(crate) mod treesync_node;
71
72use node::encryption_keys::EncryptionKeyPair;
73
74pub mod errors;
76#[cfg(feature = "test-utils")]
77pub use node::encryption_keys::test_utils;
78pub use node::encryption_keys::EncryptionKey;
79
80pub use node::{
82 leaf_node::{
83 LeafNode, LeafNodeParameters, LeafNodeParametersBuilder, LeafNodeSource,
84 LeafNodeUpdateError,
85 },
86 parent_node::ParentNode,
87 Node,
88};
89
90#[cfg(any(feature = "test-utils", test))]
92pub mod tests_and_kats;
93
94#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
96pub struct RatchetTree(Vec<Option<Node>>);
97
98#[derive(Error, Debug, PartialEq, Clone)]
100pub enum RatchetTreeError {
101 #[error("The ratchet tree has no nodes.")]
103 MissingNodes,
104 #[error("The ratchet tree has trailing blank nodes.")]
106 TrailingBlankNodes,
107 #[error("Invalid node signature.")]
109 InvalidNodeSignature,
110 #[error("Wrong node type.")]
112 WrongNodeType,
113}
114
115impl RatchetTree {
116 fn trimmed(mut nodes: Vec<Option<Node>>) -> Self {
120 match nodes.iter().enumerate().rfind(|(_, node)| node.is_some()) {
122 Some((rightmost_nonempty_position, _)) => {
123 nodes.resize(rightmost_nonempty_position + 1, None);
125 }
126 None => {
127 nodes.clear();
129 }
130 }
131
132 debug_assert!(!nodes.is_empty(), "Caller should have ensured that `RatchetTree::trimmed` is not called with a vector that is empty after removing all trailing blank nodes.");
133 Self(nodes)
134 }
135
136 pub(crate) fn try_from_nodes(
138 ciphersuite: Ciphersuite,
139 crypto: &impl OpenMlsCrypto,
140 nodes: Vec<Option<NodeIn>>,
141 group_id: &GroupId,
142 ) -> Result<Self, RatchetTreeError> {
143 match nodes.last() {
147 Some(None) => {
148 Err(RatchetTreeError::TrailingBlankNodes)
150 }
151 None => {
152 Err(RatchetTreeError::MissingNodes)
154 }
155 Some(Some(_)) => {
156 let mut verified_nodes = Vec::new();
161 for (index, node) in nodes.into_iter().enumerate() {
162 let verified_node = match (index % 2, node) {
163 (0, Some(NodeIn::LeafNode(leaf_node))) => {
165 let tree_position = TreePosition::new(
166 group_id.clone(),
167 LeafNodeIndex::new((index / 2) as u32),
168 );
169 let verifiable_leaf_node = leaf_node.into_verifiable_leaf_node();
170 let signature_key = verifiable_leaf_node
171 .signature_key()
172 .clone()
173 .into_signature_public_key_enriched(
174 ciphersuite.signature_algorithm(),
175 );
176 Some(Node::LeafNode(match verifiable_leaf_node {
177 VerifiableLeafNode::KeyPackage(leaf_node) => leaf_node
178 .verify(crypto, &signature_key)
179 .map_err(|_| RatchetTreeError::InvalidNodeSignature)?,
180 VerifiableLeafNode::Update(mut leaf_node) => {
181 leaf_node.add_tree_position(tree_position);
182 leaf_node
183 .verify(crypto, &signature_key)
184 .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
185 }
186 VerifiableLeafNode::Commit(mut leaf_node) => {
187 leaf_node.add_tree_position(tree_position);
188 leaf_node
189 .verify(crypto, &signature_key)
190 .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
191 }
192 }))
193 }
194 (1, Some(NodeIn::ParentNode(parent_node))) => {
196 Some(Node::ParentNode(parent_node))
197 }
198 (_, None) => None,
200 _ => {
202 return Err(RatchetTreeError::WrongNodeType);
203 }
204 };
205 verified_nodes.push(verified_node);
206 }
207 Ok(Self::trimmed(verified_nodes))
208 }
209 }
210 }
211}
212
213#[derive(
216 PartialEq,
217 Eq,
218 Clone,
219 Debug,
220 Serialize,
221 Deserialize,
222 TlsDeserialize,
223 TlsDeserializeBytes,
224 TlsSerialize,
225 TlsSize,
226)]
227pub struct RatchetTreeIn(Vec<Option<NodeIn>>);
228
229impl RatchetTreeIn {
230 pub fn into_verified(
233 self,
234 ciphersuite: Ciphersuite,
235 crypto: &impl OpenMlsCrypto,
236 group_id: &GroupId,
237 ) -> Result<RatchetTree, RatchetTreeError> {
238 RatchetTree::try_from_nodes(ciphersuite, crypto, self.0, group_id)
239 }
240
241 fn from_ratchet_tree(ratchet_tree: RatchetTree) -> Self {
242 let nodes = ratchet_tree
243 .0
244 .into_iter()
245 .map(|node| node.map(NodeIn::from))
246 .collect();
247 Self(nodes)
248 }
249
250 #[cfg(test)]
251 pub(crate) fn from_nodes(nodes: Vec<Option<NodeIn>>) -> Self {
252 Self(nodes)
253 }
254}
255
256impl From<RatchetTree> for RatchetTreeIn {
257 fn from(ratchet_tree: RatchetTree) -> Self {
258 RatchetTreeIn::from_ratchet_tree(ratchet_tree)
259 }
260}
261
/// Test-only conversion that skips all verification (no signature checks and
/// no trimming of trailing blank nodes).
#[cfg(any(feature = "test-utils", test))]
impl From<RatchetTreeIn> for RatchetTree {
    fn from(ratchet_tree_in: RatchetTreeIn) -> Self {
        let RatchetTreeIn(unverified_nodes) = ratchet_tree_in;
        let mut nodes: Vec<Option<Node>> = Vec::with_capacity(unverified_nodes.len());
        for slot in unverified_nodes {
            nodes.push(slot.map(Node::from));
        }
        Self(nodes)
    }
}
276
/// Floor of the base-2 logarithm of `x`, with `log2(0)` defined as `0`.
#[cfg(any(feature = "test-utils", test))]
fn log2(x: u32) -> usize {
    match x {
        0 => 0,
        nonzero => (u32::BITS - 1 - nonzero.leading_zeros()) as usize,
    }
}
284
/// Index of the root node for a tree stored as an array of `size` nodes.
///
/// Computed as `2^floor(log2(size)) - 1`, which yields `0` for `size == 0`
/// because `log2(0)` is `0`.
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn root(size: u32) -> u32 {
    let largest_power_of_two: u32 = 1 << log2(size);
    largest_power_of_two - 1
}
289
#[cfg(any(feature = "test-utils", test))]
impl fmt::Display for RatchetTree {
    /// Renders one line per node, for debugging. Each line contains:
    /// - the zero-padded node index;
    /// - a kind marker (`L` leaf, `P` parent, `_` blank), with `(*)` flagging
    ///   the root on parent and blank nodes;
    /// - the public key and parent hash in hex;
    /// - a glyph indented proportionally to the node's level.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let indent_per_level = 3;
        let root_index = root(self.0.len() as u32);
        // Placeholder used for both the key and the parent-hash columns of blank nodes.
        let blank_column = str::repeat("__", 32);

        for (position, slot) in self.0.iter().enumerate() {
            let is_root = root_index == position as u32;
            let indent = str::repeat(" ", level(position as u32) * indent_per_level);
            write!(f, "{position:04}")?;

            match slot {
                Some(node) => {
                    let (key_hex, parent_hash_hex) = match node {
                        Node::LeafNode(leaf_node) => {
                            f.write_str("\tL ")?;
                            (
                                bytes_to_hex(leaf_node.encryption_key().as_slice()),
                                leaf_node
                                    .parent_hash()
                                    .map(bytes_to_hex)
                                    .unwrap_or_default(),
                            )
                        }
                        Node::ParentNode(parent_node) => {
                            f.write_str(if is_root { "\tP (*) " } else { "\tP " })?;
                            (
                                bytes_to_hex(parent_node.public_key().as_slice()),
                                bytes_to_hex(parent_node.parent_hash()),
                            )
                        }
                    };
                    // An absent/empty parent hash is shown as a run of spaces.
                    let parent_hash_column = if parent_hash_hex.is_empty() {
                        str::repeat(" ", 32)
                    } else {
                        parent_hash_hex
                    };
                    write!(f, "PK: {key_hex} PH: {parent_hash_column} | {indent}◼︎")?;
                }
                None => {
                    let marker = if is_root { "\t_ (*) " } else { "\t_ " };
                    write!(
                        f,
                        "{marker}PK: {blank_column} PH: {blank_column} | {indent}❑"
                    )?;
                }
            }
            writeln!(f)?;
        }

        Ok(())
    }
}
359
360#[derive(Debug, Serialize, Deserialize)]
373#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq, Clone))]
374pub(crate) struct TreeSync {
375 tree: MlsBinaryTree<TreeSyncLeafNode, TreeSyncParentNode>,
376 tree_hash: Vec<u8>,
377}
378
379impl TreeSync {
380 pub(crate) fn new(
385 provider: &impl OpenMlsProvider,
386 signer: &impl Signer,
387 ciphersuite: Ciphersuite,
388 credential_with_key: CredentialWithKey,
389 life_time: Lifetime,
390 capabilities: Capabilities,
391 extensions: Extensions,
392 ) -> Result<(Self, CommitSecret, EncryptionKeyPair), LibraryError> {
393 let new_leaf_node_params = NewLeafNodeParams {
394 ciphersuite,
395 credential_with_key,
396 leaf_node_source: LeafNodeSource::KeyPackage(life_time),
398 capabilities,
399 extensions,
400 tree_info_tbs: TreeInfoTbs::KeyPackage,
401 };
402 let (leaf, encryption_key_pair) = LeafNode::new(provider, signer, new_leaf_node_params)?;
403
404 let node = Node::LeafNode(leaf);
405 let path_secret: PathSecret = Secret::random(ciphersuite, provider.rand())
406 .map_err(LibraryError::unexpected_crypto_error)?
407 .into();
408 let commit_secret: CommitSecret = path_secret
409 .derive_path_secret(provider.crypto(), ciphersuite)?
410 .into();
411 let nodes = vec![TreeSyncNode::from(node).into()];
412 let tree = MlsBinaryTree::new(nodes)
413 .map_err(|_| LibraryError::custom("Unexpected error creating the binary tree."))?;
414 let mut tree_sync = Self {
415 tree,
416 tree_hash: vec![],
417 };
418 tree_sync.populate_parent_hashes(provider.crypto(), ciphersuite)?;
420
421 Ok((tree_sync, commit_secret, encryption_key_pair))
422 }
423
424 pub(crate) fn tree_hash(&self) -> &[u8] {
426 self.tree_hash.as_slice()
427 }
428
429 pub(crate) fn merge_diff(&mut self, tree_sync_diff: StagedTreeSyncDiff) {
432 let (diff, new_tree_hash) = tree_sync_diff.into_parts();
433 self.tree_hash = new_tree_hash;
434 self.tree.merge_diff(diff);
435 }
436
437 pub(crate) fn empty_diff(&self) -> TreeSyncDiff {
440 self.into()
441 }
442
443 pub(crate) fn from_ratchet_tree(
447 crypto: &impl OpenMlsCrypto,
448 ciphersuite: Ciphersuite,
449 ratchet_tree: RatchetTree,
450 ) -> Result<Self, TreeSyncFromNodesError> {
451 let mut ts_nodes: Vec<TreeNode<TreeSyncLeafNode, TreeSyncParentNode>> =
453 Vec::with_capacity(ratchet_tree.0.len());
454
455 for (node_index, node_option) in ratchet_tree.0.into_iter().enumerate() {
457 let ts_node_option: TreeNode<TreeSyncLeafNode, TreeSyncParentNode> = match node_option {
458 Some(node) => TreeSyncNode::from(node).into(),
459 None => {
460 if node_index % 2 == 0 {
461 TreeNode::Leaf(TreeSyncLeafNode::blank())
462 } else {
463 TreeNode::Parent(TreeSyncParentNode::blank())
464 }
465 }
466 };
467 ts_nodes.push(ts_node_option);
468 }
469
470 let tree = MlsBinaryTree::new(ts_nodes).map_err(|_| PublicTreeError::MalformedTree)?;
471 let mut tree_sync = Self {
472 tree,
473 tree_hash: vec![],
474 };
475
476 tree_sync
478 .verify_parent_hashes(crypto, ciphersuite)
479 .map_err(|e| match e {
480 TreeSyncParentHashError::LibraryError(e) => e.into(),
481 TreeSyncParentHashError::InvalidParentHash => {
482 TreeSyncFromNodesError::from(PublicTreeError::InvalidParentHash)
483 }
484 })?;
485
486 tree_sync.populate_parent_hashes(crypto, ciphersuite)?;
488 Ok(tree_sync)
489 }
490
491 pub(crate) fn free_leaf_index(&self) -> LeafNodeIndex {
496 let diff = self.empty_diff();
497 diff.free_leaf_index()
498 }
499
500 fn populate_parent_hashes(
502 &mut self,
503 crypto: &impl OpenMlsCrypto,
504 ciphersuite: Ciphersuite,
505 ) -> Result<(), LibraryError> {
506 let diff = self.empty_diff();
507 let staged_diff = diff.into_staged_diff(crypto, ciphersuite)?;
510 self.merge_diff(staged_diff);
512 Ok(())
513 }
514
515 fn verify_parent_hashes(
520 &self,
521 crypto: &impl OpenMlsCrypto,
522 ciphersuite: Ciphersuite,
523 ) -> Result<(), TreeSyncParentHashError> {
524 let diff = self.empty_diff();
538 diff.verify_parent_hashes(crypto, ciphersuite)
540 }
541
542 pub(crate) fn tree_size(&self) -> TreeSize {
544 self.tree.tree_size()
545 }
546
547 pub(crate) fn full_leaves(&self) -> impl Iterator<Item = &LeafNode> {
549 self.tree
550 .leaves()
551 .filter_map(|(_, tsn)| tsn.node().as_ref())
552 }
553
554 pub(crate) fn full_parents(&self) -> impl Iterator<Item = (ParentNodeIndex, &ParentNode)> {
556 self.tree
557 .parents()
558 .filter_map(|(index, tsn)| tsn.node().as_ref().map(|pn| (index, pn)))
559 }
560
561 fn rightmost_full_leaf(&self) -> LeafNodeIndex {
563 let mut index = LeafNodeIndex::new(0);
564 for (leaf_index, leaf) in self.tree.leaves() {
565 if leaf.node().as_ref().is_some() {
566 index = leaf_index;
567 }
568 }
569 index
570 }
571
572 pub(crate) fn full_leave_members(&self) -> impl Iterator<Item = Member> + '_ {
577 self.tree
578 .leaves()
579 .filter_map(|(index, tsn)| tsn.node().as_ref().map(|node| (index, node)))
581 .map(|(index, leaf_node)| {
583 Member::new(
584 index,
585 leaf_node.encryption_key().as_slice().to_vec(),
586 leaf_node.signature_key().as_slice().to_vec(),
587 leaf_node.credential().clone(),
588 )
589 })
590 }
591
592 pub fn export_ratchet_tree(&self) -> RatchetTree {
595 let mut nodes = Vec::new();
596
597 let max_length = self.rightmost_full_leaf();
599
600 let mut leaves = self
603 .tree
604 .leaves()
605 .map(|(_, leaf)| leaf)
606 .take(max_length.usize() + 1);
607
608 if let Some(leaf) = leaves.next() {
610 nodes.push(leaf.node().clone().map(Node::LeafNode));
611 } else {
612 return RatchetTree::trimmed(vec![]);
614 }
615
616 let default_parent = TreeSyncParentNode::default();
618
619 let parents = self
621 .tree
622 .parents()
623 .map(|(_, parent)| parent)
625 .take(max_length.usize())
627 .chain(
629 (self.tree.parents().count()..self.tree.leaves().count() - 1)
630 .map(|_| &default_parent),
631 );
632
633 for (leaf, parent) in leaves.zip(parents) {
635 nodes.push(parent.node().clone().map(Node::ParentNode));
636 nodes.push(leaf.node().clone().map(Node::LeafNode));
637 }
638
639 RatchetTree::trimmed(nodes)
640 }
641
642 pub(crate) fn leaf(&self, leaf_index: LeafNodeIndex) -> Option<&LeafNode> {
645 let tsn = self.tree.leaf(leaf_index);
646 tsn.node().as_ref()
647 }
648
649 pub(crate) fn is_leaf_in_tree(&self, leaf_index: LeafNodeIndex) -> bool {
652 is_node_in_tree(leaf_index.into(), self.tree.tree_size())
653 }
654
655 pub(crate) fn owned_encryption_keys(&self, leaf_index: LeafNodeIndex) -> Vec<EncryptionKey> {
658 self.empty_diff()
659 .encryption_keys(leaf_index)
660 .cloned()
661 .collect::<Vec<EncryptionKey>>()
662 }
663
664 pub(crate) fn derive_path_secrets(
678 &self,
679 crypto: &impl OpenMlsCrypto,
680 ciphersuite: Ciphersuite,
681 mut path_secret: PathSecret,
682 sender_index: LeafNodeIndex,
683 leaf_index: LeafNodeIndex,
684 ) -> Result<(Vec<EncryptionKeyPair>, CommitSecret), DerivePathError> {
685 let subtree_path = self.tree.subtree_path(leaf_index, sender_index);
688 let mut keypairs = Vec::new();
689 for parent_index in subtree_path {
690 let tsn = self.tree.parent_by_index(parent_index);
692 if let Some(ref parent_node) = tsn.node() {
694 if !parent_node.unmerged_leaves().contains(&leaf_index) {
697 let keypair = path_secret.derive_key_pair(crypto, ciphersuite)?;
698 if parent_node.encryption_key() != keypair.public_key() {
701 return Err(DerivePathError::PublicKeyMismatch);
702 } else {
703 keypairs.push(keypair);
706 path_secret = path_secret.derive_path_secret(crypto, ciphersuite)?;
707 }
708 };
709 }
712 }
713 Ok((keypairs, path_secret.into()))
714 }
715
716 pub(crate) fn parent(&self, node_index: ParentNodeIndex) -> Option<&ParentNode> {
719 let tsn = self.tree.parent(node_index);
720 tsn.node().as_ref()
721 }
722}
723
#[cfg(test)]
impl TreeSync {
    /// Test helper: the leaf count reported by the underlying binary tree.
    pub(crate) fn leaf_count(&self) -> u32 {
        let tree = &self.tree;
        tree.leaf_count()
    }
}
730
#[cfg(test)]
mod test {
    use super::*;

    /// `RatchetTree::trimmed` must never be handed an empty vector. Debug builds
    /// panic via the internal `debug_assert!`; release builds tolerate the call.
    #[test]
    #[cfg_attr(debug_assertions, should_panic)]
    fn test_ratchet_tree_internal_empty() {
        RatchetTree::trimmed(vec![]);
    }

    /// Same as above, for input that becomes empty once trailing blanks are removed.
    #[test]
    #[cfg_attr(debug_assertions, should_panic)]
    fn test_ratchet_tree_internal_empty_after_trim() {
        RatchetTree::trimmed(vec![None]);
    }

    /// Checks that `try_from_nodes` rejects malformed node lists with the
    /// specific error variant, and accepts well-formed ones.
    #[openmls_test::openmls_test]
    fn test_ratchet_tree_trailing_blank_nodes(
        ciphersuite: Ciphersuite,
        provider: &impl OpenMlsProvider,
    ) {
        let (key_package, _, _) = crate::key_packages::tests::key_package(ciphersuite, provider);
        let node_in = NodeIn::from(Node::LeafNode(LeafNode::from(key_package)));
        let tests = [
            (vec![], Some(RatchetTreeError::MissingNodes)),
            (vec![None], Some(RatchetTreeError::TrailingBlankNodes)),
            (vec![None, None], Some(RatchetTreeError::TrailingBlankNodes)),
            (
                vec![None, None, None],
                Some(RatchetTreeError::TrailingBlankNodes),
            ),
            (vec![Some(node_in.clone())], None),
            (
                vec![Some(node_in.clone()), None],
                Some(RatchetTreeError::TrailingBlankNodes),
            ),
            (
                vec![Some(node_in.clone()), None, Some(node_in.clone())],
                None,
            ),
            (
                vec![Some(node_in.clone()), None, Some(node_in.clone()), None],
                Some(RatchetTreeError::TrailingBlankNodes),
            ),
            // A leaf at index 1 (a parent position) is a misplaced node.
            (
                vec![
                    Some(node_in.clone()),
                    Some(node_in.clone()),
                    Some(node_in),
                ],
                Some(RatchetTreeError::WrongNodeType),
            ),
        ];

        for (nodes, expected) in tests {
            let node_count = nodes.len();
            let got = RatchetTree::try_from_nodes(
                ciphersuite,
                provider.crypto(),
                nodes,
                &GroupId::random(provider.rand()),
            )
            .err();
            assert_eq!(
                got, expected,
                "unexpected result for an input of {node_count} nodes"
            );
        }
    }
}