1use std::{borrow::Cow, collections::HashMap, fmt, marker::PhantomData};
9
10use openmls_traits::{
11 crypto::OpenMlsCrypto,
12 types::{Ciphersuite, CryptoError},
13};
14use serde::{
15 de::{self, Visitor},
16 ser::SerializeSeq,
17 Deserialize, Deserializer, Serialize, Serializer,
18};
19use thiserror::Error;
20use zeroize::Zeroize;
21
22use crate::{
23 binary_tree::array_representation::TreeSize, ciphersuite::Secret,
24 tree::secret_tree::derive_child_secrets,
25};
26
27use input::AsIndexBytes;
28use prefix::Prefix;
29
30pub use prefix::Prefix16;
31
32mod input;
33mod prefix;
34
35#[derive(Debug, Clone, Error, PartialEq)]
37pub enum PprfError {
38 #[error("Index out of bounds")]
40 IndexOutOfBounds,
41 #[error("Evaluating on punctured input")]
43 PuncturedInput,
44 #[error("Error deriving child node: {0}")]
46 ChildDerivationError(#[from] CryptoError),
47}
48
49#[derive(Debug, Clone)]
53#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
54struct PprfNode(Secret);
55
56impl Serialize for PprfNode {
57 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
58 serde_bytes::Bytes::new(self.0.as_slice()).serialize(serializer)
59 }
60}
61
62impl<'de> Deserialize<'de> for PprfNode {
63 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
64 let bytes: Cow<[u8]> = serde_bytes::Deserialize::deserialize(deserializer)?;
67 let secret = Secret::from_slice(bytes.as_ref());
68 if let Cow::Owned(mut bytes) = bytes {
69 bytes.zeroize() }
71 Ok(Self(secret))
72 }
73}
74
75impl PprfNode {
76 fn derive_children(
78 &self,
79 crypto: &impl OpenMlsCrypto,
80 ciphersuite: Ciphersuite,
81 ) -> Result<(Self, Self), CryptoError> {
82 let own_secret = &self.0;
83 let (left_secret, right_secret) = derive_child_secrets(own_secret, crypto, ciphersuite)?;
84 Ok((Self(left_secret), Self(right_secret)))
85 }
86}
87
88#[derive(Debug, Serialize, Deserialize, Clone)]
96#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
97pub(crate) struct Pprf<P: Prefix> {
98 #[serde(
99 serialize_with = "serialize_hashmap",
100 deserialize_with = "deserialize_hashmap"
101 )]
102 nodes: HashMap<P, PprfNode>, width: usize,
104}
105
/// Returns bit `bit_index` of `index`, counting from the most significant bit
/// of the first byte (big-endian bit order).
///
/// # Panics
/// Panics if `bit_index / 8` is not a valid index into `index`.
fn get_bit(index: &[u8], bit_index: usize) -> bool {
    let (byte_pos, bit_in_byte) = (bit_index / 8, bit_index % 8);
    let mask = 0x80u8 >> bit_in_byte;
    index[byte_pos] & mask != 0
}
112
113impl<P: Prefix> Pprf<P> {
114 pub(super) fn new_with_size(secret: Secret, size: TreeSize) -> Self {
116 let width = size.leaf_count() as usize;
117 Pprf {
118 width,
120 nodes: [(P::new(), PprfNode(secret))].into(),
121 }
122 }
123
124 #[cfg(test)]
125 pub(super) fn new_for_test(secret: Secret) -> Self {
126 let width = secret.as_slice().len();
127 Pprf {
128 width,
130 nodes: [(P::new(), PprfNode(secret))].into(),
131 }
132 }
133
134 pub(super) fn evaluate<Input: AsIndexBytes>(
136 &mut self,
137 crypto: &impl OpenMlsCrypto,
138 ciphersuite: Ciphersuite,
139 input: &Input,
140 ) -> Result<Secret, PprfError> {
141 let input = input.as_index_bytes();
142 if input.len() > P::MAX_INPUT_LEN {
143 return Err(PprfError::IndexOutOfBounds);
144 }
145
146 let leaf_index = input;
148
149 let mut prefix = P::new();
150 let mut current_node;
151 let mut depth = 0;
152
153 loop {
155 if let Some(node) = self.nodes.remove(&prefix) {
156 if depth == P::MAX_DEPTH {
157 return Ok(node.0);
158 } current_node = node;
160 break;
161 }
162
163 if depth == P::MAX_DEPTH {
166 return Err(PprfError::PuncturedInput);
167 }
168
169 let bit = get_bit(&leaf_index, depth);
170 prefix.push_bit(bit);
171 depth += 1;
172 }
173
174 for d in depth..P::MAX_DEPTH {
176 let (left, right) = current_node.derive_children(crypto, ciphersuite).unwrap();
177 let bit = get_bit(&leaf_index, d);
178
179 let (next_node, copath_node) = if bit { (right, left) } else { (left, right) };
180
181 let mut copath_prefix = prefix.clone();
182 copath_prefix.push_bit(!bit);
183 let node_at_copath_prefix = self.nodes.insert(copath_prefix.clone(), copath_node);
184 debug_assert!(node_at_copath_prefix.is_none());
185
186 current_node = next_node;
187 prefix.push_bit(bit);
188 }
189
190 Ok(current_node.0)
191 }
192}
193
194fn serialize_hashmap<T, U, S>(map: &HashMap<T, U>, serializer: S) -> Result<S::Ok, S::Error>
195where
196 T: Serialize,
197 U: Serialize,
198 S: Serializer,
199{
200 let mut seq = serializer.serialize_seq(Some(map.len()))?;
201 for (k, v) in map {
202 seq.serialize_element(&(k, v))?;
203 }
204 seq.end()
205}
206
207fn deserialize_hashmap<'de, T, U, D>(deserializer: D) -> Result<HashMap<T, U>, D::Error>
208where
209 T: Eq + std::hash::Hash + Deserialize<'de>,
210 U: Deserialize<'de>,
211 D: Deserializer<'de>,
212{
213 struct TupleSeqVisitor<T, U> {
214 marker: PhantomData<(T, U)>,
215 }
216
217 impl<'de, T, U> Visitor<'de> for TupleSeqVisitor<T, U>
218 where
219 T: Eq + std::hash::Hash + Deserialize<'de>,
220 U: Deserialize<'de>,
221 {
222 type Value = HashMap<T, U>;
223
224 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
225 formatter.write_str("a sequence of tuples")
226 }
227
228 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
229 where
230 A: de::SeqAccess<'de>,
231 {
232 let mut map = HashMap::with_capacity(seq.size_hint().unwrap_or(0));
233 while let Some((k, v)) = seq.next_element::<(T, U)>()? {
234 map.insert(k, v);
235 }
236 Ok(map)
237 }
238 }
239
240 deserializer.deserialize_seq(TupleSeqVisitor {
241 marker: PhantomData,
242 })
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use openmls_test::openmls_test;
    use rand::{
        rngs::{OsRng, StdRng},
        Rng, SeedableRng, TryRngCore,
    };

    /// Returns an RNG seeded from the OS. The seed is printed so a failing run
    /// can be reproduced.
    fn seeded_rng() -> StdRng {
        let seed: [u8; 32] = OsRng.unwrap_mut().random();
        println!("Seed: {:?}", seed);
        StdRng::from_seed(seed)
    }

    fn random_vec(rng: &mut impl Rng, len: usize) -> Vec<u8> {
        let mut bytes = vec![0u8; len];
        rng.fill_bytes(&mut bytes);
        bytes
    }

    fn dummy_secret(rng: &mut impl Rng, ciphersuite: Ciphersuite) -> Secret {
        Secret::from_slice(&random_vec(rng, ciphersuite.hash_length()))
    }

    fn dummy_index<P: Prefix>(rng: &mut impl Rng) -> Vec<u8> {
        random_vec(rng, P::MAX_INPUT_LEN)
    }

    #[openmls_test]
    fn evaluates_single_path() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = dummy_index::<Prefix16>(&mut rng);
        let crypto = provider.crypto();

        let result = pprf.evaluate(crypto, ciphersuite, &index);
        assert!(result.is_ok());
        // Leaf secrets have the ciphersuite's hash length, like the root.
        assert_eq!(
            result.as_ref().unwrap().as_slice().len(),
            ciphersuite.hash_length()
        );
    }

    #[openmls_test]
    fn re_evaluation_of_same_index_returns_error() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = dummy_index::<Prefix16>(&mut rng);
        let crypto = provider.crypto();

        let _first = pprf.evaluate(crypto, ciphersuite, &index).unwrap();
        let second = pprf
            .evaluate(crypto, ciphersuite, &index)
            .expect_err("Evaluation on same input should fail");

        assert!(matches!(second, PprfError::PuncturedInput));
    }

    #[openmls_test]
    fn different_indices_produce_different_results() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index1 = dummy_index::<Prefix16>(&mut rng);
        // Two independent random draws can collide. The second evaluation would
        // then hit a punctured leaf, so redraw until the indices differ.
        let index2 = loop {
            let candidate = dummy_index::<Prefix16>(&mut rng);
            if candidate != index1 {
                break candidate;
            }
        };
        let crypto = provider.crypto();

        let leaf1 = pprf.evaluate(crypto, ciphersuite, &index1).unwrap();
        let leaf2 = pprf.evaluate(crypto, ciphersuite, &index2).unwrap();

        assert_ne!(leaf1.as_slice(), leaf2.as_slice());
    }

    #[openmls_test]
    fn rejects_out_of_bounds_index() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = random_vec(&mut rng, Prefix16::MAX_INPUT_LEN + 1);
        let crypto = provider.crypto();

        let result = pprf.evaluate(crypto, ciphersuite, &index);
        assert!(matches!(result, Err(PprfError::IndexOutOfBounds)));
    }

    #[openmls_test]
    fn pprf_serialization() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = random_vec(&mut rng, Prefix16::MAX_INPUT_LEN);

        let crypto = provider.crypto();

        let result = pprf.evaluate(crypto, ciphersuite, &index);
        assert!(result.is_ok());

        let serialized = serde_json::to_string(&pprf).unwrap();
        println!("Serialized: {}", serialized);
        let deserialized: Pprf<Prefix16> = serde_json::from_str(&serialized).unwrap();

        assert_eq!(pprf, deserialized);
    }
}