1#[cfg(any(feature = "test-utils", test))]
22use std::fmt;
23
24use openmls_traits::{
25 crypto::OpenMlsCrypto,
26 signatures::Signer,
27 types::{Ciphersuite, CryptoError},
28};
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
32
33use self::{
34 diff::{StagedTreeSyncDiff, TreeSyncDiff},
35 node::{
36 leaf_node::{
37 Capabilities, NewLeafNodeParams, TreeInfoTbs, TreePosition, VerifiableLeafNode,
38 },
39 NodeIn,
40 },
41 treesync_node::{TreeSyncLeafNode, TreeSyncNode, TreeSyncParentNode},
42};
43#[cfg(any(feature = "test-utils", test))]
44use crate::{binary_tree::array_representation::level, test_utils::bytes_to_hex};
45use crate::{
46 binary_tree::array_representation::ParentNodeIndex, treesync::node::leaf_node::LeafNodeIn,
47};
48use crate::{
49 binary_tree::{
50 array_representation::{is_node_in_tree, LeafNodeIndex, TreeSize},
51 MlsBinaryTree, MlsBinaryTreeError,
52 },
53 ciphersuite::{signable::Verifiable, Secret},
54 credentials::CredentialWithKey,
55 error::LibraryError,
56 extensions::Extensions,
57 group::{GroupId, Member},
58 key_packages::Lifetime,
59 messages::{PathSecret, PathSecretError},
60 schedule::CommitSecret,
61 storage::OpenMlsProvider,
62};
63
64mod hashes;
66use errors::*;
67
68pub(crate) mod diff;
70pub(crate) mod node;
71pub(crate) mod treekem;
72pub(crate) mod treesync_node;
73
74use node::encryption_keys::EncryptionKeyPair;
75
76pub mod errors;
78#[cfg(feature = "test-utils")]
79pub use node::encryption_keys::test_utils;
80pub use node::encryption_keys::EncryptionKey;
81
82pub use node::{
84 leaf_node::{
85 LeafNode, LeafNodeParameters, LeafNodeParametersBuilder, LeafNodeSource,
86 LeafNodeUpdateError,
87 },
88 parent_node::ParentNode,
89 Node,
90};
91
92#[cfg(any(feature = "test-utils", test))]
94pub mod tests_and_kats;
95
/// A ratchet tree as a flat vector of nodes in array order.
///
/// Even indices hold leaves and odd indices hold parents; `None` is a blank
/// node. `RatchetTree::try_from_nodes` and `TreeSync::export_ratchet_tree`
/// both strip trailing blank nodes via `RatchetTree::trimmed`.
///
/// This type does not derive `TlsDeserialize`: trees received over the wire
/// are decoded as [`RatchetTreeIn`] and checked with
/// [`RatchetTreeIn::into_verified`].
// NOTE(review): the serde `Deserialize` derive bypasses that verification;
// presumably it is only used for trusted/persisted state — confirm.
#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct RatchetTree(Vec<Option<Node>>);
99
/// Errors returned when turning a list of incoming nodes into a [`RatchetTree`].
#[derive(Error, Debug, PartialEq, Clone)]
pub enum RatchetTreeError {
    /// The node list was empty.
    #[error("The ratchet tree has no nodes.")]
    MissingNodes,
    /// The last entry of the node list was blank.
    #[error("The ratchet tree has trailing blank nodes.")]
    TrailingBlankNodes,
    /// Verifying the signature of a leaf node failed.
    #[error("Invalid node signature.")]
    InvalidNodeSignature,
    /// A leaf was found at an odd index or a parent at an even index.
    #[error("Wrong node type.")]
    WrongNodeType,
}
116
117impl RatchetTree {
118 fn trimmed(mut nodes: Vec<Option<Node>>) -> Self {
122 match nodes.iter().enumerate().rfind(|(_, node)| node.is_some()) {
124 Some((rightmost_nonempty_position, _)) => {
125 nodes.resize(rightmost_nonempty_position + 1, None);
127 }
128 None => {
129 nodes.clear();
131 }
132 }
133
134 debug_assert!(!nodes.is_empty(), "Caller should have ensured that `RatchetTree::trimmed` is not called with a vector that is empty after removing all trailing blank nodes.");
135 Self(nodes)
136 }
137
138 pub(crate) fn try_from_nodes(
140 ciphersuite: Ciphersuite,
141 crypto: &impl OpenMlsCrypto,
142 nodes: Vec<Option<NodeIn>>,
143 group_id: &GroupId,
144 ) -> Result<Self, RatchetTreeError> {
145 match nodes.last() {
149 Some(None) => {
150 Err(RatchetTreeError::TrailingBlankNodes)
152 }
153 None => {
154 Err(RatchetTreeError::MissingNodes)
156 }
157 Some(Some(_)) => {
158 let mut verified_nodes = Vec::new();
163 for (index, node) in nodes.into_iter().enumerate() {
164 let verified_node = match (index % 2, node) {
165 (0, Some(NodeIn::LeafNode(leaf_node))) => {
167 let tree_position = TreePosition::new(
168 group_id.clone(),
169 LeafNodeIndex::new((index / 2) as u32),
170 );
171 let verifiable_leaf_node = leaf_node.into_verifiable_leaf_node();
172 let signature_key = verifiable_leaf_node
173 .signature_key()
174 .clone()
175 .into_signature_public_key_enriched(
176 ciphersuite.signature_algorithm(),
177 );
178 Some(Node::leaf_node(match verifiable_leaf_node {
179 VerifiableLeafNode::KeyPackage(leaf_node) => leaf_node
180 .verify(crypto, &signature_key)
181 .map_err(|_| RatchetTreeError::InvalidNodeSignature)?,
182 VerifiableLeafNode::Update(mut leaf_node) => {
183 leaf_node.add_tree_position(tree_position);
184 leaf_node
185 .verify(crypto, &signature_key)
186 .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
187 }
188 VerifiableLeafNode::Commit(mut leaf_node) => {
189 leaf_node.add_tree_position(tree_position);
190 leaf_node
191 .verify(crypto, &signature_key)
192 .map_err(|_| RatchetTreeError::InvalidNodeSignature)?
193 }
194 }))
195 }
196 (1, Some(NodeIn::ParentNode(parent_node))) => {
198 Some(Node::ParentNode(parent_node))
199 }
200 (_, None) => None,
202 _ => {
204 return Err(RatchetTreeError::WrongNodeType);
205 }
206 };
207 verified_nodes.push(verified_node);
208 }
209 Ok(Self::trimmed(verified_nodes))
210 }
211 }
212 }
213
214 pub fn nodes(&self) -> impl Iterator<Item = &Node> {
216 self.0.iter().flatten()
217 }
218
219 pub fn leaves(&self) -> impl Iterator<Item = &LeafNode> {
221 self.nodes().filter_map(|node| match node {
222 Node::LeafNode(leaf_node) => Some(&**leaf_node),
223 Node::ParentNode(_parent_node) => None,
224 })
225 }
226
227 pub fn parents(&self) -> impl Iterator<Item = &ParentNode> {
229 self.nodes().filter_map(|node| match node {
230 Node::ParentNode(parent_node) => Some(&**parent_node),
231 Node::LeafNode(_leaf_node) => None,
232 })
233 }
234}
235
/// A ratchet tree as received from outside, before any verification.
///
/// It uses the same array layout as [`RatchetTree`]: even indices are leaves,
/// odd indices are parents and `None` is blank. Unlike [`RatchetTree`], this
/// type can be TLS-deserialized. Use [`RatchetTreeIn::into_verified`] to check
/// it and obtain a [`RatchetTree`].
#[derive(
    PartialEq,
    Eq,
    Clone,
    Debug,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct RatchetTreeIn(Vec<Option<NodeIn>>);
251
252impl RatchetTreeIn {
253 pub fn into_verified(
256 self,
257 ciphersuite: Ciphersuite,
258 crypto: &impl OpenMlsCrypto,
259 group_id: &GroupId,
260 ) -> Result<RatchetTree, RatchetTreeError> {
261 RatchetTree::try_from_nodes(ciphersuite, crypto, self.0, group_id)
262 }
263
264 pub fn nodes(&self) -> impl Iterator<Item = &NodeIn> {
266 self.0.iter().flatten()
267 }
268
269 pub fn leaves(&self) -> impl Iterator<Item = &LeafNodeIn> {
271 self.nodes().filter_map(|node| match node {
272 NodeIn::LeafNode(leaf_node) => Some(&**leaf_node),
273 NodeIn::ParentNode(_parent_node) => None,
274 })
275 }
276
277 pub fn parents(&self) -> impl Iterator<Item = &ParentNode> {
279 self.nodes().filter_map(|node| match node {
280 NodeIn::ParentNode(parent_node) => Some(&**parent_node),
281 NodeIn::LeafNode(_leaf_node) => None,
282 })
283 }
284
285 fn from_ratchet_tree(ratchet_tree: RatchetTree) -> Self {
286 let nodes = ratchet_tree
287 .0
288 .into_iter()
289 .map(|node| node.map(NodeIn::from))
290 .collect();
291 Self(nodes)
292 }
293
294 #[cfg(test)]
295 pub(crate) fn from_nodes(nodes: Vec<Option<NodeIn>>) -> Self {
296 Self(nodes)
297 }
298}
299
300impl From<RatchetTree> for RatchetTreeIn {
301 fn from(ratchet_tree: RatchetTree) -> Self {
302 RatchetTreeIn::from_ratchet_tree(ratchet_tree)
303 }
304}
305
#[cfg(any(feature = "test-utils", test))]
impl From<RatchetTreeIn> for RatchetTree {
    /// Converts an incoming tree node by node, without signature verification
    /// and without trimming. Only available in tests and with `test-utils`.
    fn from(ratchet_tree_in: RatchetTreeIn) -> Self {
        let RatchetTreeIn(incoming_nodes) = ratchet_tree_in;
        let converted: Vec<Option<Node>> = incoming_nodes
            .into_iter()
            .map(|slot| slot.map(Node::from))
            .collect();
        RatchetTree(converted)
    }
}
320
/// Floor of the base-2 logarithm of `x`, with `log2(0)` defined as `0`.
#[cfg(any(feature = "test-utils", test))]
fn log2(x: u32) -> usize {
    // `checked_ilog2` is `None` only for zero, which this helper maps to 0.
    x.checked_ilog2().map_or(0, |exponent| exponent as usize)
}
327}
328
/// Returns `2^floor(log2(size)) - 1`, which is `0` when `size` is `0` or `1`.
///
/// The `Display` impl of [`RatchetTree`] passes the node count here and uses
/// the result as the index of the root node.
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn root(size: u32) -> u32 {
    let exponent = log2(size) as u32;
    2u32.pow(exponent) - 1
}
333
#[cfg(any(feature = "test-utils", test))]
impl fmt::Display for RatchetTree {
    /// Renders the tree for debugging, one node per line.
    ///
    /// Each line has the following parts, in order:
    /// - the node index, zero-padded to four digits;
    /// - a tab and the node kind (`L` leaf, `P` parent, `_` blank), with `(*)`
    ///   marking the root;
    /// - the hex public key and parent hash. Blank nodes show underscores, and
    ///   the parent hash shows spaces when it is missing or empty;
    /// - a trailing marker (`◼︎` for present nodes, `❑` for blank ones),
    ///   indented in proportion to the node's level.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Columns of indentation per tree level before the trailing marker.
        const INDENT_PER_LEVEL: usize = 3;

        let root_index = root(self.0.len() as u32);

        for (index, slot) in self.0.iter().enumerate() {
            let depth = level(index as u32);
            let is_root = index as u32 == root_index;
            let indent = str::repeat(" ", depth * INDENT_PER_LEVEL);

            write!(f, "{index:04}")?;

            match slot {
                None => {
                    if is_root {
                        write!(
                            f,
                            "\t_ (*) PK: {} PH: {} | ",
                            str::repeat("__", 32),
                            str::repeat("__", 32)
                        )?;
                    } else {
                        write!(
                            f,
                            "\t_ PK: {} PH: {} | ",
                            str::repeat("__", 32),
                            str::repeat("__", 32)
                        )?;
                    }
                    write!(f, "{}❑", indent)?;
                }
                Some(node) => {
                    let (key_bytes, parent_hash_hex) = match node {
                        Node::LeafNode(leaf_node) => {
                            write!(f, "\tL ")?;
                            let parent_hash_hex = leaf_node
                                .parent_hash()
                                .map(bytes_to_hex)
                                .unwrap_or_default();
                            (leaf_node.encryption_key().as_slice(), parent_hash_hex)
                        }
                        Node::ParentNode(parent_node) => {
                            if is_root {
                                write!(f, "\tP (*) ")?;
                            } else {
                                write!(f, "\tP ")?;
                            }
                            (
                                parent_node.public_key().as_slice(),
                                bytes_to_hex(parent_node.parent_hash()),
                            )
                        }
                    };

                    let parent_hash_column = if parent_hash_hex.is_empty() {
                        str::repeat(" ", 32)
                    } else {
                        parent_hash_hex
                    };
                    write!(
                        f,
                        "PK: {} PH: {} | ",
                        bytes_to_hex(key_bytes),
                        parent_hash_column
                    )?;
                    write!(f, "{}◼︎", indent)?;
                }
            }

            writeln!(f)?;
        }

        Ok(())
    }
}
403
/// A ratchet tree (leaves and parent nodes) together with its cached tree hash.
///
/// Changes are made through diffs: [`TreeSync::empty_diff`] creates one, and a
/// staged diff is applied with `TreeSync::merge_diff`.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq, Clone))]
pub struct TreeSync {
    // Leaf and parent nodes, stored in a binary tree.
    tree: MlsBinaryTree<TreeSyncLeafNode, TreeSyncParentNode>,
    // Tree hash taken from the last merged staged diff. The constructors start
    // it empty and fill it in by merging an empty staged diff.
    tree_hash: Vec<u8>,
}
422
impl TreeSync {
    /// Creates a one-member tree whose only leaf is generated here.
    ///
    /// Steps:
    /// 1. Create a `KeyPackage`-sourced leaf node for `credential_with_key` with
    ///    `life_time`, `capabilities` and `extensions`, using `signer`.
    /// 2. Draw a random path secret and derive a commit secret from it.
    /// 3. Compute parent hashes and the tree hash.
    ///
    /// Returns the tree, the commit secret and the encryption key pair of the
    /// new leaf.
    ///
    /// # Errors
    /// Returns a [`LibraryError`] if any of the following fails: leaf creation,
    /// randomness, secret derivation, or building the binary tree.
    pub(crate) fn new(
        provider: &impl OpenMlsProvider,
        signer: &impl Signer,
        ciphersuite: Ciphersuite,
        credential_with_key: CredentialWithKey,
        life_time: Lifetime,
        capabilities: Capabilities,
        extensions: Extensions<LeafNode>,
    ) -> Result<(Self, CommitSecret, EncryptionKeyPair), LibraryError> {
        let new_leaf_node_params = NewLeafNodeParams {
            ciphersuite,
            credential_with_key,
            leaf_node_source: LeafNodeSource::KeyPackage(life_time),
            capabilities,
            extensions,
            tree_info_tbs: TreeInfoTbs::KeyPackage,
        };
        let (leaf, encryption_key_pair) = LeafNode::new(provider, signer, new_leaf_node_params)?;

        let node = Node::leaf_node(leaf);
        let path_secret: PathSecret = Secret::random(ciphersuite, provider.rand())
            .map_err(LibraryError::unexpected_crypto_error)?
            .into();
        // The commit secret is derived one step on from the random path secret.
        let commit_secret: CommitSecret = path_secret
            .derive_path_secret(provider.crypto(), ciphersuite)?
            .into();
        let nodes = vec![TreeSyncNode::from(node).into()];
        let tree = MlsBinaryTree::new(nodes)
            .map_err(|_| LibraryError::custom("Unexpected error creating the binary tree."))?;
        let mut tree_sync = Self {
            tree,
            tree_hash: vec![],
        };
        // Merges an empty staged diff, which also sets `tree_hash`.
        tree_sync.populate_parent_hashes(provider.crypto(), ciphersuite)?;

        Ok((tree_sync, commit_secret, encryption_key_pair))
    }

    /// Returns the underlying binary tree.
    pub(crate) fn tree(&self) -> &MlsBinaryTree<TreeSyncLeafNode, TreeSyncParentNode> {
        &self.tree
    }

    /// Returns the cached tree hash.
    pub(crate) fn tree_hash(&self) -> &[u8] {
        self.tree_hash.as_slice()
    }

    /// Applies a staged diff.
    ///
    /// The diff's tree hash replaces the cached one, and its binary-tree diff
    /// is merged into `self.tree`.
    pub(crate) fn merge_diff(&mut self, tree_sync_diff: StagedTreeSyncDiff) {
        let (diff, new_tree_hash) = tree_sync_diff.into_parts();
        self.tree_hash = new_tree_hash;
        self.tree.merge_diff(diff);
    }

    /// Creates a diff on top of this tree that does not change anything yet.
    ///
    /// Changes are staged on the diff (e.g. with `into_staged_diff`) and
    /// applied with `merge_diff`.
    pub(crate) fn empty_diff(&self) -> TreeSyncDiff<'_> {
        self.into()
    }

    /// Builds a [`TreeSync`] from a [`RatchetTree`].
    ///
    /// Leaves must sit at even node indices and parents at odd ones. Blank
    /// slots become blank leaves or blank parents. The received parent hashes
    /// are verified first, and then the parent hashes and tree hash are
    /// (re)populated.
    ///
    /// # Errors
    /// - `PublicTreeError::MalformedTree` if a node has the wrong kind for its
    ///   index, or if the binary tree cannot be assembled.
    /// - `PublicTreeError::InvalidParentHash` if parent-hash verification fails.
    /// - A library error for unexpected internal failures.
    pub(crate) fn from_ratchet_tree(
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
        ratchet_tree: RatchetTree,
    ) -> Result<Self, TreeSyncFromNodesError> {
        let total_nodes = ratchet_tree.0.len();
        // Even indices are leaves and odd indices are parents, so there are
        // ceil(n/2) leaves and floor(n/2) parents.
        let mut leaf_nodes = Vec::with_capacity(total_nodes.div_ceil(2));
        let mut parent_nodes = Vec::with_capacity(total_nodes / 2);

        for (node_index, node_option) in ratchet_tree.0.into_iter().enumerate() {
            // Even index: must be a leaf (or blank).
            if node_index % 2 == 0 {
                let leaf = match node_option {
                    Some(node) => match TreeSyncNode::from(node) {
                        TreeSyncNode::Leaf(l) => *l,
                        TreeSyncNode::Parent(_) => {
                            return Err(TreeSyncFromNodesError::from(
                                PublicTreeError::MalformedTree,
                            ))
                        }
                    },
                    None => TreeSyncLeafNode::blank(),
                };
                leaf_nodes.push(leaf);
            } else {
                // Odd index: must be a parent (or blank).
                let parent = match node_option {
                    Some(node) => match TreeSyncNode::from(node) {
                        TreeSyncNode::Parent(p) => *p,
                        TreeSyncNode::Leaf(_) => {
                            return Err(TreeSyncFromNodesError::from(
                                PublicTreeError::MalformedTree,
                            ))
                        }
                    },
                    None => TreeSyncParentNode::blank(),
                };
                parent_nodes.push(parent);
            }
        }

        // NOTE(review): the input may have been trimmed of trailing blanks, so
        // these vectors can be shorter than a full tree; presumably
        // `from_components` accounts for that — confirm.
        let tree = MlsBinaryTree::from_components(leaf_nodes, parent_nodes)
            .map_err(|_| PublicTreeError::MalformedTree)?;
        let mut tree_sync = Self {
            tree,
            tree_hash: vec![],
        };

        // Check the received parent hashes before anything is recomputed.
        tree_sync
            .verify_parent_hashes(crypto, ciphersuite)
            .map_err(|e| match e {
                TreeSyncParentHashError::LibraryError(e) => e.into(),
                TreeSyncParentHashError::InvalidParentHash => {
                    TreeSyncFromNodesError::from(PublicTreeError::InvalidParentHash)
                }
            })?;

        // Merges an empty staged diff, which also sets `tree_hash`.
        tree_sync.populate_parent_hashes(crypto, ciphersuite)?;

        Ok(tree_sync)
    }

    /// Returns the free leaf index reported by an empty diff (see
    /// `TreeSyncDiff::free_leaf_index`).
    pub(crate) fn free_leaf_index(&self) -> LeafNodeIndex {
        let diff = self.empty_diff();
        diff.free_leaf_index()
    }

    /// Stages an empty diff and merges it back into `self`.
    ///
    /// Judging by this function's name, staging fills in the parent hashes.
    /// The merge also replaces `tree_hash` with the staged diff's tree hash.
    fn populate_parent_hashes(
        &mut self,
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
    ) -> Result<(), LibraryError> {
        let diff = self.empty_diff();
        let staged_diff = diff.into_staged_diff(crypto, ciphersuite)?;
        self.merge_diff(staged_diff);

        Ok(())
    }

    /// Verifies the tree's parent hashes using an empty diff.
    ///
    /// # Errors
    /// Returns `TreeSyncParentHashError::InvalidParentHash` if verification
    /// fails, or `TreeSyncParentHashError::LibraryError` for internal errors.
    fn verify_parent_hashes(
        &self,
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
    ) -> Result<(), TreeSyncParentHashError> {
        let diff = self.empty_diff();
        diff.verify_parent_hashes(crypto, ciphersuite)
    }

    /// Returns the size of the underlying binary tree.
    pub(crate) fn tree_size(&self) -> TreeSize {
        self.tree.tree_size()
    }

    /// Iterates over all non-blank leaves together with their leaf indices.
    pub fn full_leaves(&self) -> impl Iterator<Item = (LeafNodeIndex, &LeafNode)> {
        self.tree
            .leaves()
            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|ln| (index, ln)))
    }

    /// Iterates over all non-blank parents together with their parent indices.
    pub fn full_parents(&self) -> impl Iterator<Item = (ParentNodeIndex, &ParentNode)> {
        self.tree
            .parents()
            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|pn| (index, pn)))
    }

    /// Iterates over the indices of all blank parent nodes.
    pub fn blank_parents<'a>(&'a self) -> impl Iterator<Item = ParentNodeIndex> + 'a {
        self.tree
            .parents()
            .filter_map(|(index, tsn)| tsn.node().as_ref().map_or(Some(index), |_| None))
    }

    /// Iterates over the indices of all blank leaves.
    pub fn blank_leaves<'a>(&'a self) -> impl Iterator<Item = LeafNodeIndex> + 'a {
        self.tree
            .leaves()
            .filter_map(|(index, tsn)| tsn.node().as_ref().map_or(Some(index), |_| None))
    }

    /// Returns the index of the rightmost non-blank leaf, or leaf index 0 if
    /// every leaf is blank.
    fn rightmost_full_leaf(&self) -> LeafNodeIndex {
        let mut index = LeafNodeIndex::new(0);
        for (leaf_index, leaf) in self.tree.leaves() {
            if leaf.node().as_ref().is_some() {
                index = leaf_index;
            }
        }
        index
    }

    /// Returns a [`Member`] for every non-blank leaf.
    ///
    /// Each member carries the leaf index, the encryption key bytes, the
    /// signature key bytes and the credential.
    pub(crate) fn full_leaf_members(&self) -> impl Iterator<Item = Member> + '_ {
        self.tree
            .leaves()
            // Skip blank leaves.
            .filter_map(|(index, tsn)| tsn.node().as_ref().map(|node| (index, node)))
            .map(|(index, leaf_node)| {
                Member::new(
                    index,
                    leaf_node.encryption_key().as_slice().to_vec(),
                    leaf_node.signature_key().as_slice().to_vec(),
                    leaf_node.credential().clone(),
                )
            })
    }

    /// Exports the tree as a [`RatchetTree`].
    ///
    /// Only nodes up to and including the rightmost non-blank leaf are emitted,
    /// in array order (leaf, parent, leaf, ...). Trailing blank nodes are then
    /// removed by `RatchetTree::trimmed`.
    pub fn export_ratchet_tree(&self) -> RatchetTree {
        let mut nodes = Vec::new();

        // Leaf index of the last leaf that has to be exported.
        let max_length = self.rightmost_full_leaf();

        // Leaves 0..=max_length.
        let mut leaves = self
            .tree
            .leaves()
            .map(|(_, leaf)| leaf)
            .take(max_length.usize() + 1);

        // The first leaf has no parent to its left.
        if let Some(leaf) = leaves.next() {
            nodes.push(leaf.node().clone().map(Node::leaf_node));
        } else {
            // NOTE(review): reaching this passes an empty vector to `trimmed`,
            // which trips its `debug_assert!` in debug builds.
            return RatchetTree::trimmed(vec![]);
        }

        // Stand-in for parent slots the binary tree does not store
        // (presumably blank — confirm `TreeSyncParentNode::default`).
        let default_parent = TreeSyncParentNode::default();

        // One parent per remaining leaf. If the binary tree holds fewer than
        // `leaves - 1` parents, `chain` pads the sequence with `default_parent`.
        let parents = self
            .tree
            .parents()
            .map(|(_, parent)| parent)
            .take(max_length.usize())
            .chain(
                (self.tree.parents().count()..self.tree.leaves().count() - 1)
                    .map(|_| &default_parent),
            );

        // Array order puts each parent between its neighbouring leaves.
        for (leaf, parent) in leaves.zip(parents) {
            nodes.push(parent.node().clone().map(Node::parent_node));
            nodes.push(leaf.node().clone().map(Node::leaf_node));
        }

        RatchetTree::trimmed(nodes)
    }

    /// Returns the leaf at `leaf_index`, or `None` if it is blank.
    ///
    /// NOTE(review): behaviour for an out-of-range index depends on
    /// `MlsBinaryTree::leaf`; check `is_leaf_in_tree` first if unsure.
    pub(crate) fn leaf(&self, leaf_index: LeafNodeIndex) -> Option<&LeafNode> {
        let tsn = self.tree.leaf(leaf_index);
        tsn.node().as_ref()
    }

    /// Returns whether `leaf_index` lies within the current tree size.
    pub(crate) fn is_leaf_in_tree(&self, leaf_index: LeafNodeIndex) -> bool {
        is_node_in_tree(leaf_index.into(), self.tree.tree_size())
    }

    /// Collects the encryption keys that an empty diff reports for
    /// `leaf_index` (see `TreeSyncDiff::encryption_keys`).
    pub(crate) fn owned_encryption_keys(&self, leaf_index: LeafNodeIndex) -> Vec<EncryptionKey> {
        self.empty_diff()
            .encryption_keys(leaf_index)
            .cloned()
            .collect::<Vec<EncryptionKey>>()
    }

    /// Derives key pairs along `subtree_path(leaf_index, sender_index)`,
    /// starting from `path_secret`.
    ///
    /// Each parent node on that path is handled as follows:
    /// - It is skipped if it is blank, or if it lists `leaf_index` among its
    ///   unmerged leaves.
    /// - Otherwise, a key pair is derived from the current path secret and its
    ///   public key is compared with the node's encryption key.
    /// - On a match, the key pair is kept and the path secret is advanced.
    ///
    /// Whatever path secret remains at the end is returned as the commit secret.
    ///
    /// # Errors
    /// Returns `DerivePathError::PublicKeyMismatch` if a derived public key does
    /// not match the node. Errors from key or secret derivation are propagated.
    pub(crate) fn derive_path_secrets(
        &self,
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
        mut path_secret: PathSecret,
        sender_index: LeafNodeIndex,
        leaf_index: LeafNodeIndex,
    ) -> Result<(Vec<EncryptionKeyPair>, CommitSecret), DerivePathError> {
        let subtree_path = self.tree.subtree_path(leaf_index, sender_index);
        let mut keypairs = Vec::new();
        for parent_index in subtree_path {
            let tsn = self.tree.parent_by_index(parent_index);
            // Blank parents are skipped without consuming a path secret.
            if let Some(ref parent_node) = tsn.node() {
                // Nodes listing `leaf_index` as unmerged are skipped as well.
                if !parent_node.unmerged_leaves().contains(&leaf_index) {
                    let keypair = path_secret.derive_key_pair(crypto, ciphersuite)?;
                    // The derived key must match the public key stored in the tree.
                    if parent_node.encryption_key() != keypair.public_key() {
                        return Err(DerivePathError::PublicKeyMismatch);
                    } else {
                        keypairs.push(keypair);
                        // Advance to the secret for the next node on the path.
                        path_secret = path_secret.derive_path_secret(crypto, ciphersuite)?;
                    }
                };
            }
        }
        Ok((keypairs, path_secret.into()))
    }

    /// Returns the parent node at `node_index`, or `None` if it is blank.
    pub(crate) fn parent(&self, node_index: ParentNodeIndex) -> Option<&ParentNode> {
        let tsn = self.tree.parent(node_index);
        tsn.node().as_ref()
    }
}
804
#[cfg(test)]
impl TreeSync {
    /// Test-only helper: returns the leaf count of the underlying binary tree.
    pub(crate) fn leaf_count(&self) -> u32 {
        self.tree.leaf_count()
    }
}
811
#[cfg(test)]
mod test {
    use super::*;

    // `RatchetTree::trimmed` rejects an empty result only through a
    // `debug_assert!`. The following two tests therefore expect a panic exactly
    // when debug assertions are enabled, and a silent success otherwise. Each
    // test previously existed twice, once per build mode, under the same name.

    #[test]
    #[cfg_attr(debug_assertions, should_panic)]
    fn test_ratchet_tree_internal_empty() {
        RatchetTree::trimmed(vec![]);
    }

    #[test]
    #[cfg_attr(debug_assertions, should_panic)]
    fn test_ratchet_tree_internal_empty_after_trim() {
        RatchetTree::trimmed(vec![None]);
    }

    // `try_from_nodes` must reject empty node lists and lists ending in a
    // blank node, and accept lists whose last node is non-blank.
    #[openmls_test::openmls_test]
    fn test_ratchet_tree_trailing_blank_nodes() {
        let provider = &Provider::default();
        let (key_package, _, _) = crate::key_packages::tests::key_package(ciphersuite, provider);
        let node_in = NodeIn::from(Node::leaf_node(LeafNode::from(key_package)));
        let tests = [
            (vec![], false),
            (vec![None], false),
            (vec![None, None], false),
            (vec![None, None, None], false),
            (vec![Some(node_in.clone())], true),
            (vec![Some(node_in.clone()), None], false),
            (
                vec![Some(node_in.clone()), None, Some(node_in.clone())],
                true,
            ),
            (
                vec![Some(node_in.clone()), None, Some(node_in), None],
                false,
            ),
        ];

        for (test, expected) in tests.into_iter() {
            let got = RatchetTree::try_from_nodes(
                ciphersuite,
                provider.crypto(),
                test,
                &GroupId::random(provider.rand()),
            )
            .is_ok();
            assert_eq!(got, expected);
        }
    }
}