1use openmls_traits::crypto::OpenMlsCrypto;
5use openmls_traits::types::{Ciphersuite, HpkeCiphertext};
6#[cfg(not(target_arch = "wasm32"))]
7use rayon::prelude::*;
8use serde::{Deserialize, Serialize};
9use thiserror::*;
10use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes};
11
12use super::encryption_keys::{EncryptionKey, EncryptionKeyPair};
13use crate::{
14 binary_tree::array_representation::{LeafNodeIndex, ParentNodeIndex},
15 ciphersuite::HpkePublicKey,
16 error::LibraryError,
17 messages::PathSecret,
18 schedule::CommitSecret,
19 treesync::{hashes::ParentHashInput, treekem::UpdatePathNode},
20};
21
22#[derive(
26 Debug,
27 Eq,
28 PartialEq,
29 Clone,
30 Serialize,
31 Deserialize,
32 TlsSerialize,
33 TlsDeserialize,
34 TlsDeserializeBytes,
35 TlsSize,
36)]
37pub struct ParentNode {
38 pub(super) encryption_key: EncryptionKey,
39 pub(super) parent_hash: VLBytes,
40 pub(super) unmerged_leaves: UnmergedLeaves,
41}
42
43impl From<EncryptionKey> for ParentNode {
44 fn from(public_key: EncryptionKey) -> Self {
45 Self {
46 encryption_key: public_key,
47 parent_hash: vec![].into(),
48 unmerged_leaves: UnmergedLeaves::new(),
49 }
50 }
51}
52
53#[cfg_attr(test, derive(Clone))]
55#[derive(Debug)]
56pub(crate) struct PlainUpdatePathNode {
57 public_key: EncryptionKey,
58 path_secret: PathSecret,
59}
60
61impl PlainUpdatePathNode {
62 pub(in crate::treesync) fn encrypt(
64 &self,
65 crypto: &impl OpenMlsCrypto,
66 ciphersuite: Ciphersuite,
67 public_keys: &[EncryptionKey],
68 group_context: &[u8],
69 ) -> Result<UpdatePathNode, LibraryError> {
70 #[cfg(target_arch = "wasm32")]
71 let public_keys = public_keys.iter();
72 #[cfg(not(target_arch = "wasm32"))]
73 let public_keys = public_keys.par_iter();
74
75 public_keys
76 .map(|pk| {
77 self.path_secret
78 .encrypt(crypto, ciphersuite, pk, group_context)
79 })
80 .collect::<Result<Vec<HpkeCiphertext>, LibraryError>>()
81 .map(|encrypted_path_secrets| UpdatePathNode {
82 public_key: self.public_key.clone(),
83 encrypted_path_secrets,
84 })
85 }
86
87 pub(in crate::treesync) fn path_secret(&self) -> &PathSecret {
89 &self.path_secret
90 }
91
92 #[cfg(test)]
93 pub(crate) fn new(public_key: EncryptionKey, path_secret: PathSecret) -> Self {
94 Self {
95 public_key,
96 path_secret,
97 }
98 }
99}
100
/// Output of [`ParentNode::derive_path`].
///
/// Elements, in order:
/// 1. the derived parent nodes, each paired with its index;
/// 2. the matching plain (unencrypted) update path nodes;
/// 3. the encryption key pairs derived for each node;
/// 4. the commit secret, derived from the secret that follows the last node.
pub(in crate::treesync) type PathDerivationResult = (
    Vec<(ParentNodeIndex, ParentNode)>,
    Vec<PlainUpdatePathNode>,
    Vec<EncryptionKeyPair>,
    CommitSecret,
);
110
111impl ParentNode {
112 pub(crate) fn derive_path(
118 crypto: &impl OpenMlsCrypto,
119 ciphersuite: Ciphersuite,
120 path_secret: PathSecret,
121 path_indices: Vec<ParentNodeIndex>,
122 ) -> Result<PathDerivationResult, LibraryError> {
123 let mut next_path_secret = path_secret;
124 let mut path_secrets = Vec::with_capacity(path_indices.len());
125
126 for _ in 0..path_indices.len() {
127 let path_secret = next_path_secret;
128 next_path_secret = path_secret.derive_path_secret(crypto, ciphersuite)?;
130 path_secrets.push(path_secret);
131 }
132
133 type PathDerivationResults = (
134 Vec<((ParentNodeIndex, ParentNode), EncryptionKeyPair)>,
135 Vec<PlainUpdatePathNode>,
136 );
137
138 #[cfg(not(target_arch = "wasm32"))]
141 let path_secrets = path_secrets.into_par_iter();
142 #[cfg(target_arch = "wasm32")]
143 let path_secrets = path_secrets.into_iter();
144
145 let (path_with_keypairs, update_path_nodes): PathDerivationResults = path_secrets
146 .zip(path_indices)
147 .map(|(path_secret, index)| {
148 let keypair = path_secret.derive_key_pair(crypto, ciphersuite)?;
151 let parent_node = ParentNode::from(keypair.public_key().clone());
152 let update_path_node = PlainUpdatePathNode {
155 public_key: keypair.public_key().clone(),
156 path_secret,
157 };
158 Ok((((index, parent_node), keypair), update_path_node))
159 })
160 .collect::<Result<
161 Vec<(
162 ((ParentNodeIndex, ParentNode), EncryptionKeyPair),
163 PlainUpdatePathNode,
164 )>,
165 LibraryError,
166 >>()?
167 .into_iter()
168 .unzip();
169
170 let (path, keypairs) = path_with_keypairs.into_iter().unzip();
171
172 let commit_secret = next_path_secret.into();
173 Ok((path, update_path_nodes, keypairs, commit_secret))
174 }
175
176 pub(crate) fn public_key(&self) -> &HpkePublicKey {
178 self.encryption_key.key()
179 }
180
181 pub(crate) fn encryption_key(&self) -> &EncryptionKey {
183 &self.encryption_key
184 }
185
186 pub(crate) fn unmerged_leaves(&self) -> &[LeafNodeIndex] {
188 self.unmerged_leaves.list()
189 }
190
191 pub(in crate::treesync) fn set_unmerged_leaves(&mut self, unmerged_leaves: Vec<LeafNodeIndex>) {
193 self.unmerged_leaves.set_list(unmerged_leaves);
194 }
195
196 pub(in crate::treesync) fn add_unmerged_leaf(&mut self, leaf_index: LeafNodeIndex) {
198 self.unmerged_leaves.add(leaf_index);
199 }
200
201 pub(in crate::treesync) fn compute_parent_hash(
203 &self,
204 crypto: &impl OpenMlsCrypto,
205 ciphersuite: Ciphersuite,
206 original_child_resolution: &[u8],
207 ) -> Result<Vec<u8>, LibraryError> {
208 let parent_hash_input = ParentHashInput::new(
209 self.encryption_key.key(),
210 self.parent_hash(),
211 original_child_resolution,
212 );
213 parent_hash_input.hash(crypto, ciphersuite)
214 }
215
216 pub(in crate::treesync) fn set_parent_hash(&mut self, parent_hash: Vec<u8>) {
218 self.parent_hash = parent_hash.into()
219 }
220
221 pub(crate) fn parent_hash(&self) -> &[u8] {
223 self.parent_hash.as_slice()
224 }
225}
226
227#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize, TlsSize, TlsSerialize)]
229pub(in crate::treesync) struct UnmergedLeaves {
230 list: Vec<LeafNodeIndex>,
231}
232
233impl UnmergedLeaves {
234 pub(in crate::treesync) fn new() -> Self {
235 Self { list: Vec::new() }
236 }
237
238 pub(in crate::treesync) fn add(&mut self, leaf_index: LeafNodeIndex) {
239 let position = self.list.binary_search(&leaf_index).unwrap_or_else(|e| e);
243 self.list.insert(position, leaf_index);
244 }
245
246 pub(in crate::treesync) fn list(&self) -> &[LeafNodeIndex] {
247 self.list.as_slice()
248 }
249
250 pub(in crate::treesync) fn set_list(&mut self, list: Vec<LeafNodeIndex>) {
252 self.list = list;
253 }
254}
255
256#[derive(Error, Debug)]
257pub(in crate::treesync) enum UnmergedLeavesError {
258 #[error("The list of leaves is not sorted.")]
260 NotSorted,
261}
262
263impl TryFrom<Vec<LeafNodeIndex>> for UnmergedLeaves {
264 type Error = UnmergedLeavesError;
265
266 fn try_from(list: Vec<LeafNodeIndex>) -> Result<Self, Self::Error> {
267 if !list.windows(2).all(|e| e[0] < e[1]) {
269 return Err(UnmergedLeavesError::NotSorted);
270 }
271 Ok(Self { list })
272 }
273}