1#[cfg(any(feature = "test-utils", test))]
22use std::fmt;
23
24use openmls_traits::{
25 crypto::OpenMlsCrypto,
26 signatures::Signer,
27 types::{Ciphersuite, CryptoError},
28};
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
32
33use self::{
34 diff::{StagedTreeSyncDiff, TreeSyncDiff},
35 node::{
36 leaf_node::{
37 Capabilities, NewLeafNodeParams, TreeInfoTbs, TreePosition, VerifiableLeafNode,
38 },
39 NodeIn,
40 },
41 treesync_node::{TreeSyncLeafNode, TreeSyncNode, TreeSyncParentNode},
42};
43use crate::binary_tree::array_representation::ParentNodeIndex;
44#[cfg(any(feature = "test-utils", test))]
45use crate::{binary_tree::array_representation::level, test_utils::bytes_to_hex};
46use crate::{
47 binary_tree::{
48 array_representation::{is_node_in_tree, LeafNodeIndex, TreeSize},
49 MlsBinaryTree, MlsBinaryTreeError,
50 },
51 ciphersuite::{signable::Verifiable, Secret},
52 credentials::CredentialWithKey,
53 error::LibraryError,
54 extensions::Extensions,
55 group::{GroupId, Member},
56 key_packages::Lifetime,
57 messages::{PathSecret, PathSecretError},
58 schedule::CommitSecret,
59 storage::OpenMlsProvider,
60};
61
62mod hashes;
64use errors::*;
65
66pub(crate) mod diff;
68pub(crate) mod node;
69pub(crate) mod treekem;
70pub(crate) mod treesync_node;
71
72use node::encryption_keys::EncryptionKeyPair;
73
74pub mod errors;
76#[cfg(feature = "test-utils")]
77pub use node::encryption_keys::test_utils;
78pub use node::encryption_keys::EncryptionKey;
79
80pub use node::{
82 leaf_node::{
83 LeafNode, LeafNodeParameters, LeafNodeParametersBuilder, LeafNodeSource,
84 LeafNodeUpdateError,
85 },
86 parent_node::ParentNode,
87 Node,
88};
89
90#[cfg(any(feature = "test-utils", test))]
92pub mod tests_and_kats;
93
94#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
96pub struct RatchetTree(Vec<Option<Node>>);
97
98#[derive(Error, Debug, PartialEq, Clone)]
100pub enum RatchetTreeError {
101 #[error("The ratchet tree has no nodes.")]
103 MissingNodes,
104 #[error("The ratchet tree has trailing blank nodes.")]
106 TrailingBlankNodes,
107 #[error("Invalid node signature.")]
109 InvalidNodeSignature,
110 #[error("Wrong node type.")]
112 WrongNodeType,
113}
114
115impl RatchetTree {
116 fn trimmed(mut nodes: Vec<Option<Node>>) -> Self {
120 match nodes.iter().enumerate().rfind(|(_, node)| node.is_some()) {
122 Some((rightmost_nonempty_position, _)) => {
123 nodes.resize(rightmost_nonempty_position + 1, None);
125 }
126 None => {
127 nodes.clear();
129 }
130 }
131
132 debug_assert!(!nodes.is_empty(), "Caller should have ensured that `RatchetTree::trimmed` is not called with a vector that is empty after removing all trailing blank nodes.");
133 Self(nodes)
134 }
135
136 pub(crate) fn try_from_nodes(
138 ciphersuite: Ciphersuite,
139 crypto: &impl OpenMlsCrypto,
140 nodes: Vec<Option<NodeIn>>,
141 group_id: &GroupId,
142 ) -> Result<Self, RatchetTreeError> {
143 match nodes.last() {
147 Some(None) => {
148 Err(RatchetTreeError::TrailingBlankNodes)
150 }
151 None => {
152 Err(RatchetTreeError::MissingNodes)
154 }
155 Some(Some(_)) => {
156 let mut verified_nodes = Vec::new();
161 for (index, node) in nodes.into_iter().enumerate() {
162 let verified_node = match (index % 2, node) {
163 (0, Some(NodeIn::LeafNode(leaf_node))) => {
165 let tree_position = TreePosition::new(
166 group_id.clone(),
167 LeafNodeIndex::new((index / 2) as u32),
168 );
169 let verifiable_leaf_node = leaf_node.into_verifiable_leaf_node();
170 let signature_key = verifiable_leaf_node
171 .signature_key()
172 .clone()
173 .into_signature_public_key_enriched(
174 ciphersuite.signature_algorithm(),
175 );
176 Some(Node::leaf_node(match verifiable_leaf_node {
177 VerifiableLeafNode::KeyPackage(leaf_node) => leaf_node
178 .verify(crypto, &signature_key)
179 .map_err(|_| RatchetTreeError::InvalidNodeSignature)?,
180 VerifiableLeafNode::Update(mut leaf_node) => {
181 leaf_node.add_tree_position(tree_position);
182 leaf_node
183 .verify(crypto, &signature_key)
184 .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
185 }
186 VerifiableLeafNode::Commit(mut leaf_node) => {
187 leaf_node.add_tree_position(tree_position);
188 leaf_node
189 .verify(crypto, &signature_key)
190 .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
191 }
192 }))
193 }
194 (1, Some(NodeIn::ParentNode(parent_node))) => {
196 Some(Node::ParentNode(parent_node))
197 }
198 (_, None) => None,
200 _ => {
202 return Err(RatchetTreeError::WrongNodeType);
203 }
204 };
205 verified_nodes.push(verified_node);
206 }
207 Ok(Self::trimmed(verified_nodes))
208 }
209 }
210 }
211}
212
213#[derive(
216 PartialEq,
217 Eq,
218 Clone,
219 Debug,
220 Serialize,
221 Deserialize,
222 TlsDeserialize,
223 TlsDeserializeBytes,
224 TlsSerialize,
225 TlsSize,
226)]
227pub struct RatchetTreeIn(Vec<Option<NodeIn>>);
228
229impl RatchetTreeIn {
230 pub fn into_verified(
233 self,
234 ciphersuite: Ciphersuite,
235 crypto: &impl OpenMlsCrypto,
236 group_id: &GroupId,
237 ) -> Result<RatchetTree, RatchetTreeError> {
238 RatchetTree::try_from_nodes(ciphersuite, crypto, self.0, group_id)
239 }
240
241 fn from_ratchet_tree(ratchet_tree: RatchetTree) -> Self {
242 let nodes = ratchet_tree
243 .0
244 .into_iter()
245 .map(|node| node.map(NodeIn::from))
246 .collect();
247 Self(nodes)
248 }
249
250 #[cfg(test)]
251 pub(crate) fn from_nodes(nodes: Vec<Option<NodeIn>>) -> Self {
252 Self(nodes)
253 }
254}
255
256impl From<RatchetTree> for RatchetTreeIn {
257 fn from(ratchet_tree: RatchetTree) -> Self {
258 RatchetTreeIn::from_ratchet_tree(ratchet_tree)
259 }
260}
261
#[cfg(any(feature = "test-utils", test))]
impl From<RatchetTreeIn> for RatchetTree {
    /// Converts an incoming tree into a `RatchetTree` **without** verifying it.
    ///
    /// This conversion only exists in test builds or with the `test-utils` feature. Otherwise,
    /// [`RatchetTreeIn::into_verified`] is the way to obtain a `RatchetTree` from incoming nodes.
    fn from(ratchet_tree_in: RatchetTreeIn) -> Self {
        let RatchetTreeIn(incoming_nodes) = ratchet_tree_in;
        let nodes: Vec<Option<Node>> = incoming_nodes
            .into_iter()
            .map(|slot| slot.map(Node::from))
            .collect();
        Self(nodes)
    }
}
276
#[cfg(any(feature = "test-utils", test))]
/// Floor of the base-2 logarithm of `x`; `log2(0)` is defined as `0`.
fn log2(x: u32) -> usize {
    x.checked_ilog2().map_or(0, |exponent| exponent as usize)
}
284
#[cfg(any(feature = "test-utils", test))]
/// Returns `2^log2(size) - 1`: the largest power of two not exceeding `size`, minus one
/// (`0` when `size` is `0` or `1`).
///
/// The `Display` impl of `RatchetTree` calls this with the number of nodes in the tree to
/// locate the root position. NOTE(review): confirm other callers also pass a node count
/// rather than a leaf count.
pub(crate) fn root(size: u32) -> u32 {
    let exponent = log2(size);
    (1u32 << exponent) - 1
}
289
#[cfg(any(feature = "test-utils", test))]
impl fmt::Display for RatchetTree {
    /// Pretty-prints the tree for debugging, one node per line.
    ///
    /// Each line has the following parts, in order:
    ///
    /// - the zero-padded node index;
    /// - a tag: `L` for a leaf, `P` for a parent, `_` for a blank node. Parent and blank
    ///   nodes at the root position get an extra `(*)`;
    /// - the hex-encoded public key (`PK`) and parent hash (`PH`);
    /// - a glyph indented by the node's level: `◼︎` for present nodes, `❑` for blank ones.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Number of spaces of indentation per tree level in front of the glyph.
        const SPACES_PER_LEVEL: usize = 3;

        // The root position is derived from the total number of nodes in the list.
        let root_index = root(self.0.len() as u32);

        for (position, slot) in self.0.iter().enumerate() {
            let position_u32 = position as u32;
            let is_root = position_u32 == root_index;
            let indentation = str::repeat(" ", level(position_u32) * SPACES_PER_LEVEL);

            write!(f, "{position:04}")?;
            match slot {
                None => {
                    let placeholder = str::repeat("__", 32);
                    let tag = if is_root { "\t_ (*) " } else { "\t_ " };
                    write!(f, "{tag}PK: {placeholder} PH: {placeholder} | ")?;
                    write!(f, "{}❑", indentation)?;
                }
                Some(node) => {
                    let (tag, key_bytes, parent_hash_hex) = match node {
                        // Leaves never carry the root marker.
                        Node::LeafNode(leaf_node) => (
                            "\tL ",
                            leaf_node.encryption_key().as_slice(),
                            leaf_node
                                .parent_hash()
                                .map(bytes_to_hex)
                                .unwrap_or_default(),
                        ),
                        Node::ParentNode(parent_node) => (
                            if is_root { "\tP (*) " } else { "\tP " },
                            parent_node.public_key().as_slice(),
                            bytes_to_hex(parent_node.parent_hash()),
                        ),
                    };
                    let shown_parent_hash = if parent_hash_hex.is_empty() {
                        str::repeat(" ", 32)
                    } else {
                        parent_hash_hex
                    };
                    write!(
                        f,
                        "{tag}PK: {} PH: {shown_parent_hash} | ",
                        bytes_to_hex(key_bytes)
                    )?;
                    write!(f, "{}◼︎", indentation)?;
                }
            }
            writeln!(f)?;
        }

        Ok(())
    }
}
359
360#[derive(Debug, Serialize, Deserialize)]
373#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq, Clone))]
374pub struct TreeSync {
375 tree: MlsBinaryTree<TreeSyncLeafNode, TreeSyncParentNode>,
376 tree_hash: Vec<u8>,
377}
378
379impl TreeSync {
380 pub(crate) fn new(
385 provider: &impl OpenMlsProvider,
386 signer: &impl Signer,
387 ciphersuite: Ciphersuite,
388 credential_with_key: CredentialWithKey,
389 life_time: Lifetime,
390 capabilities: Capabilities,
391 extensions: Extensions<LeafNode>,
392 ) -> Result<(Self, CommitSecret, EncryptionKeyPair), LibraryError> {
393 let new_leaf_node_params = NewLeafNodeParams {
394 ciphersuite,
395 credential_with_key,
396 leaf_node_source: LeafNodeSource::KeyPackage(life_time),
398 capabilities,
399 extensions,
400 tree_info_tbs: TreeInfoTbs::KeyPackage,
401 };
402 let (leaf, encryption_key_pair) = LeafNode::new(provider, signer, new_leaf_node_params)?;
403
404 let node = Node::leaf_node(leaf);
405 let path_secret: PathSecret = Secret::random(ciphersuite, provider.rand())
406 .map_err(LibraryError::unexpected_crypto_error)?
407 .into();
408 let commit_secret: CommitSecret = path_secret
409 .derive_path_secret(provider.crypto(), ciphersuite)?
410 .into();
411 let nodes = vec![TreeSyncNode::from(node).into()];
412 let tree = MlsBinaryTree::new(nodes)
413 .map_err(|_| LibraryError::custom("Unexpected error creating the binary tree."))?;
414 let mut tree_sync = Self {
415 tree,
416 tree_hash: vec![],
417 };
418 tree_sync.populate_parent_hashes(provider.crypto(), ciphersuite)?;
420
421 Ok((tree_sync, commit_secret, encryption_key_pair))
422 }
423
424 pub(crate) fn tree(&self) -> &MlsBinaryTree<TreeSyncLeafNode, TreeSyncParentNode> {
426 &self.tree
427 }
428
429 pub(crate) fn tree_hash(&self) -> &[u8] {
431 self.tree_hash.as_slice()
432 }
433
434 pub(crate) fn merge_diff(&mut self, tree_sync_diff: StagedTreeSyncDiff) {
437 let (diff, new_tree_hash) = tree_sync_diff.into_parts();
438 self.tree_hash = new_tree_hash;
439 self.tree.merge_diff(diff);
440 }
441
442 pub(crate) fn empty_diff(&self) -> TreeSyncDiff<'_> {
445 self.into()
446 }
447
448 pub(crate) fn from_ratchet_tree(
452 crypto: &impl OpenMlsCrypto,
453 ciphersuite: Ciphersuite,
454 ratchet_tree: RatchetTree,
455 ) -> Result<Self, TreeSyncFromNodesError> {
456 let total_nodes = ratchet_tree.0.len();
458 let mut leaf_nodes = Vec::with_capacity(total_nodes.div_ceil(2));
459 let mut parent_nodes = Vec::with_capacity(total_nodes / 2);
460
461 for (node_index, node_option) in ratchet_tree.0.into_iter().enumerate() {
463 if node_index % 2 == 0 {
464 let leaf = match node_option {
465 Some(node) => match TreeSyncNode::from(node) {
466 TreeSyncNode::Leaf(l) => *l,
467 TreeSyncNode::Parent(_) => {
468 return Err(TreeSyncFromNodesError::from(
469 PublicTreeError::MalformedTree,
470 ))
471 }
472 },
473 None => TreeSyncLeafNode::blank(),
474 };
475 leaf_nodes.push(leaf);
476 } else {
477 let parent = match node_option {
478 Some(node) => match TreeSyncNode::from(node) {
479 TreeSyncNode::Parent(p) => *p,
480 TreeSyncNode::Leaf(_) => {
481 return Err(TreeSyncFromNodesError::from(
482 PublicTreeError::MalformedTree,
483 ))
484 }
485 },
486 None => TreeSyncParentNode::blank(),
487 };
488 parent_nodes.push(parent);
489 }
490 }
491
492 let tree = MlsBinaryTree::from_components(leaf_nodes, parent_nodes)
493 .map_err(|_| PublicTreeError::MalformedTree)?;
494 let mut tree_sync = Self {
495 tree,
496 tree_hash: vec![],
497 };
498
499 tree_sync
501 .verify_parent_hashes(crypto, ciphersuite)
502 .map_err(|e| match e {
503 TreeSyncParentHashError::LibraryError(e) => e.into(),
504 TreeSyncParentHashError::InvalidParentHash => {
505 TreeSyncFromNodesError::from(PublicTreeError::InvalidParentHash)
506 }
507 })?;
508
509 tree_sync.populate_parent_hashes(crypto, ciphersuite)?;
511 Ok(tree_sync)
512 }
513
514 pub(crate) fn free_leaf_index(&self) -> LeafNodeIndex {
519 let diff = self.empty_diff();
520 diff.free_leaf_index()
521 }
522
523 fn populate_parent_hashes(
525 &mut self,
526 crypto: &impl OpenMlsCrypto,
527 ciphersuite: Ciphersuite,
528 ) -> Result<(), LibraryError> {
529 let diff = self.empty_diff();
530 let staged_diff = diff.into_staged_diff(crypto, ciphersuite)?;
533 self.merge_diff(staged_diff);
535 Ok(())
536 }
537
538 fn verify_parent_hashes(
543 &self,
544 crypto: &impl OpenMlsCrypto,
545 ciphersuite: Ciphersuite,
546 ) -> Result<(), TreeSyncParentHashError> {
547 let diff = self.empty_diff();
561 diff.verify_parent_hashes(crypto, ciphersuite)
563 }
564
565 pub(crate) fn tree_size(&self) -> TreeSize {
567 self.tree.tree_size()
568 }
569
570 pub fn full_leaves(&self) -> impl Iterator<Item = (LeafNodeIndex, &LeafNode)> {
572 self.tree
573 .leaves()
574 .filter_map(|(index, tsn)| tsn.node().as_ref().map(|ln| (index, ln)))
575 }
576
577 pub fn full_parents(&self) -> impl Iterator<Item = (ParentNodeIndex, &ParentNode)> {
579 self.tree
580 .parents()
581 .filter_map(|(index, tsn)| tsn.node().as_ref().map(|pn| (index, pn)))
582 }
583
584 pub fn blank_parents<'a>(&'a self) -> impl Iterator<Item = ParentNodeIndex> + 'a {
586 self.tree
587 .parents()
588 .filter_map(|(index, tsn)| tsn.node().as_ref().map_or(Some(index), |_| None))
589 }
590
591 pub fn blank_leaves<'a>(&'a self) -> impl Iterator<Item = LeafNodeIndex> + 'a {
593 self.tree
594 .leaves()
595 .filter_map(|(index, tsn)| tsn.node().as_ref().map_or(Some(index), |_| None))
596 }
597
598 fn rightmost_full_leaf(&self) -> LeafNodeIndex {
600 let mut index = LeafNodeIndex::new(0);
601 for (leaf_index, leaf) in self.tree.leaves() {
602 if leaf.node().as_ref().is_some() {
603 index = leaf_index;
604 }
605 }
606 index
607 }
608
609 pub(crate) fn full_leaf_members(&self) -> impl Iterator<Item = Member> + '_ {
614 self.tree
615 .leaves()
616 .filter_map(|(index, tsn)| tsn.node().as_ref().map(|node| (index, node)))
618 .map(|(index, leaf_node)| {
620 Member::new(
621 index,
622 leaf_node.encryption_key().as_slice().to_vec(),
623 leaf_node.signature_key().as_slice().to_vec(),
624 leaf_node.credential().clone(),
625 )
626 })
627 }
628
629 pub fn export_ratchet_tree(&self) -> RatchetTree {
632 let mut nodes = Vec::new();
633
634 let max_length = self.rightmost_full_leaf();
636
637 let mut leaves = self
640 .tree
641 .leaves()
642 .map(|(_, leaf)| leaf)
643 .take(max_length.usize() + 1);
644
645 if let Some(leaf) = leaves.next() {
647 nodes.push(leaf.node().clone().map(Node::leaf_node));
648 } else {
649 return RatchetTree::trimmed(vec![]);
651 }
652
653 let default_parent = TreeSyncParentNode::default();
655
656 let parents = self
658 .tree
659 .parents()
660 .map(|(_, parent)| parent)
662 .take(max_length.usize())
664 .chain(
666 (self.tree.parents().count()..self.tree.leaves().count() - 1)
667 .map(|_| &default_parent),
668 );
669
670 for (leaf, parent) in leaves.zip(parents) {
672 nodes.push(parent.node().clone().map(Node::parent_node));
673 nodes.push(leaf.node().clone().map(Node::leaf_node));
674 }
675
676 RatchetTree::trimmed(nodes)
677 }
678
679 pub(crate) fn leaf(&self, leaf_index: LeafNodeIndex) -> Option<&LeafNode> {
682 let tsn = self.tree.leaf(leaf_index);
683 tsn.node().as_ref()
684 }
685
686 pub(crate) fn is_leaf_in_tree(&self, leaf_index: LeafNodeIndex) -> bool {
689 is_node_in_tree(leaf_index.into(), self.tree.tree_size())
690 }
691
692 pub(crate) fn owned_encryption_keys(&self, leaf_index: LeafNodeIndex) -> Vec<EncryptionKey> {
695 self.empty_diff()
696 .encryption_keys(leaf_index)
697 .cloned()
698 .collect::<Vec<EncryptionKey>>()
699 }
700
701 pub(crate) fn derive_path_secrets(
715 &self,
716 crypto: &impl OpenMlsCrypto,
717 ciphersuite: Ciphersuite,
718 mut path_secret: PathSecret,
719 sender_index: LeafNodeIndex,
720 leaf_index: LeafNodeIndex,
721 ) -> Result<(Vec<EncryptionKeyPair>, CommitSecret), DerivePathError> {
722 let subtree_path = self.tree.subtree_path(leaf_index, sender_index);
725 let mut keypairs = Vec::new();
726 for parent_index in subtree_path {
727 let tsn = self.tree.parent_by_index(parent_index);
729 if let Some(ref parent_node) = tsn.node() {
731 if !parent_node.unmerged_leaves().contains(&leaf_index) {
734 let keypair = path_secret.derive_key_pair(crypto, ciphersuite)?;
735 if parent_node.encryption_key() != keypair.public_key() {
738 return Err(DerivePathError::PublicKeyMismatch);
739 } else {
740 keypairs.push(keypair);
743 path_secret = path_secret.derive_path_secret(crypto, ciphersuite)?;
744 }
745 };
746 }
749 }
750 Ok((keypairs, path_secret.into()))
751 }
752
753 pub(crate) fn parent(&self, node_index: ParentNodeIndex) -> Option<&ParentNode> {
756 let tsn = self.tree.parent(node_index);
757 tsn.node().as_ref()
758 }
759}
760
#[cfg(test)]
impl TreeSync {
    /// Returns the leaf count reported by the underlying binary tree (test helper).
    pub(crate) fn leaf_count(&self) -> u32 {
        let tree = &self.tree;
        tree.leaf_count()
    }
}
767
#[cfg(test)]
mod test {
    use super::*;

    /// Debug builds: an empty node list violates `trimmed`'s precondition and trips its
    /// `debug_assert!`.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_ratchet_tree_internal_empty() {
        RatchetTree::trimmed(vec![]);
    }

    /// Debug builds: a list that is empty after trimming trailing blanks also trips the
    /// `debug_assert!`.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_ratchet_tree_internal_empty_after_trim() {
        RatchetTree::trimmed(vec![None]);
    }

    /// Pins the exact outcome of `RatchetTree::try_from_nodes` for several layouts:
    /// - empty input is `MissingNodes`;
    /// - a blank last node is `TrailingBlankNodes`;
    /// - layouts ending in a valid leaf are accepted.
    #[openmls_test::openmls_test]
    fn test_ratchet_tree_trailing_blank_nodes() {
        let provider = &Provider::default();
        let (key_package, _, _) = crate::key_packages::tests::key_package(ciphersuite, provider);
        let node_in = NodeIn::from(Node::leaf_node(LeafNode::from(key_package)));
        // `None` as the expected error means the tree must be accepted.
        let tests = [
            (vec![], Some(RatchetTreeError::MissingNodes)),
            (vec![None], Some(RatchetTreeError::TrailingBlankNodes)),
            (vec![None, None], Some(RatchetTreeError::TrailingBlankNodes)),
            (
                vec![None, None, None],
                Some(RatchetTreeError::TrailingBlankNodes),
            ),
            (vec![Some(node_in.clone())], None),
            (
                vec![Some(node_in.clone()), None],
                Some(RatchetTreeError::TrailingBlankNodes),
            ),
            (
                vec![Some(node_in.clone()), None, Some(node_in.clone())],
                None,
            ),
            (
                vec![Some(node_in.clone()), None, Some(node_in), None],
                Some(RatchetTreeError::TrailingBlankNodes),
            ),
        ];

        for (case, (nodes, expected_error)) in tests.into_iter().enumerate() {
            let got = RatchetTree::try_from_nodes(
                ciphersuite,
                provider.crypto(),
                nodes,
                &GroupId::random(provider.rand()),
            )
            .err();
            assert_eq!(got, expected_error, "unexpected result for case {case}");
        }
    }

    /// Release builds compile out the `debug_assert!`, so an empty list must yield an empty
    /// tree instead of panicking.
    #[cfg(not(debug_assertions))]
    #[test]
    fn test_ratchet_tree_internal_empty() {
        assert_eq!(RatchetTree::trimmed(vec![]), RatchetTree(vec![]));
    }

    /// Release builds: an all-blank list is trimmed down to an empty tree.
    #[cfg(not(debug_assertions))]
    #[test]
    fn test_ratchet_tree_internal_empty_after_trim() {
        assert_eq!(RatchetTree::trimmed(vec![None]), RatchetTree(vec![]));
    }
}