1use std::{borrow::Cow, collections::HashMap, fmt, marker::PhantomData};
9
10use openmls_traits::{
11 crypto::OpenMlsCrypto,
12 types::{Ciphersuite, CryptoError},
13};
14use serde::{
15 de::{self, Visitor},
16 ser::SerializeSeq,
17 Deserialize, Deserializer, Serialize, Serializer,
18};
19use thiserror::Error;
20use zeroize::Zeroize;
21
22use crate::{
23 binary_tree::array_representation::TreeSize, ciphersuite::Secret,
24 tree::secret_tree::derive_child_secrets,
25};
26
27use input::AsIndexBytes;
28use prefix::Prefix;
29
30pub use prefix::{Prefix16, Prefix256};
31
32mod input;
33mod prefix;
34
35#[derive(Debug, Clone, Error, PartialEq)]
37pub enum PprfError {
38 #[error("Index out of bounds")]
40 IndexOutOfBounds,
41 #[error("Evaluating on punctured input")]
43 PuncturedInput,
44 #[error("Error deriving child node: {0}")]
46 ChildDerivationError(#[from] CryptoError),
47 #[error("Prefix length exceeds maximum depth")]
49 PrefixMaxDepthExceeded,
50}
51
52#[derive(Debug, Clone)]
56#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
57struct PprfNode(Secret);
58
59impl Serialize for PprfNode {
60 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
61 serde_bytes::Bytes::new(self.0.as_slice()).serialize(serializer)
62 }
63}
64
65impl<'de> Deserialize<'de> for PprfNode {
66 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
67 let bytes: Cow<[u8]> = serde_bytes::Deserialize::deserialize(deserializer)?;
70 let secret = Secret::from_slice(bytes.as_ref());
71 if let Cow::Owned(mut bytes) = bytes {
72 bytes.zeroize() }
74 Ok(Self(secret))
75 }
76}
77
78impl PprfNode {
79 fn derive_children(
81 &self,
82 crypto: &impl OpenMlsCrypto,
83 ciphersuite: Ciphersuite,
84 ) -> Result<(Self, Self), CryptoError> {
85 let own_secret = &self.0;
86 let (left_secret, right_secret) = derive_child_secrets(own_secret, crypto, ciphersuite)?;
87 Ok((Self(left_secret), Self(right_secret)))
88 }
89}
90
91#[derive(Debug, Serialize, Deserialize, Clone)]
99#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
100pub(crate) struct Pprf<P: Prefix> {
101 #[serde(
102 serialize_with = "serialize_hashmap",
103 deserialize_with = "deserialize_hashmap"
104 )]
105 nodes: HashMap<P, PprfNode>, width: usize,
107}
108
/// Returns bit number `bit_index` of `index`. Bits are counted from the most
/// significant bit of the first byte (big-endian bit order).
///
/// # Panics
/// Panics if `bit_index / 8` is not a valid byte position in `index`.
fn get_bit(index: &[u8], bit_index: usize) -> bool {
    let byte = index[bit_index / 8];
    let mask = 0x80u8 >> (bit_index % 8);
    (byte & mask) != 0
}
115
116impl<P: Prefix> Pprf<P> {
117 pub(crate) fn new_with_size(secret: Secret, size: TreeSize) -> Self {
119 let width = size.leaf_count() as usize;
120 Pprf {
121 width,
123 nodes: [(P::new(), PprfNode(secret))].into(),
124 }
125 }
126
127 #[cfg(test)]
128 pub(super) fn new_for_test(secret: Secret) -> Self {
129 let width = secret.as_slice().len();
130 Pprf {
131 width,
133 nodes: [(P::new(), PprfNode(secret))].into(),
134 }
135 }
136
137 pub(crate) fn evaluate<Input: AsIndexBytes>(
139 &mut self,
140 crypto: &impl OpenMlsCrypto,
141 ciphersuite: Ciphersuite,
142 input: &Input,
143 ) -> Result<Secret, PprfError> {
144 let input = input.as_index_bytes();
145 if input.len() > P::MAX_INPUT_LEN {
146 return Err(PprfError::IndexOutOfBounds);
147 }
148
149 let leaf_index = input;
151
152 let mut prefix = P::new();
153 let mut current_node;
154 let mut depth = 0;
155
156 loop {
158 if let Some(node) = self.nodes.remove(&prefix) {
159 if depth == P::MAX_DEPTH {
160 return Ok(node.0);
161 } current_node = node;
163 break;
164 }
165
166 if depth == P::MAX_DEPTH {
169 return Err(PprfError::PuncturedInput);
170 }
171
172 let bit = get_bit(&leaf_index, depth);
173 prefix.push_bit(bit)?;
174 depth += 1;
175 }
176
177 for d in depth..P::MAX_DEPTH {
179 let (left, right) = current_node.derive_children(crypto, ciphersuite)?;
180 let bit = get_bit(&leaf_index, d);
181
182 let (next_node, copath_node) = if bit { (right, left) } else { (left, right) };
183
184 let mut copath_prefix = prefix.clone();
185 copath_prefix.push_bit(!bit)?;
186 let node_at_copath_prefix = self.nodes.insert(copath_prefix.clone(), copath_node);
187 debug_assert!(node_at_copath_prefix.is_none());
188
189 current_node = next_node;
190 prefix.push_bit(bit)?;
191 }
192
193 Ok(current_node.0)
194 }
195}
196
197fn serialize_hashmap<T, U, S>(map: &HashMap<T, U>, serializer: S) -> Result<S::Ok, S::Error>
198where
199 T: Serialize,
200 U: Serialize,
201 S: Serializer,
202{
203 let mut seq = serializer.serialize_seq(Some(map.len()))?;
204 for (k, v) in map {
205 seq.serialize_element(&(k, v))?;
206 }
207 seq.end()
208}
209
210fn deserialize_hashmap<'de, T, U, D>(deserializer: D) -> Result<HashMap<T, U>, D::Error>
211where
212 T: Eq + std::hash::Hash + Deserialize<'de>,
213 U: Deserialize<'de>,
214 D: Deserializer<'de>,
215{
216 struct TupleSeqVisitor<T, U> {
217 marker: PhantomData<(T, U)>,
218 }
219
220 impl<'de, T, U> Visitor<'de> for TupleSeqVisitor<T, U>
221 where
222 T: Eq + std::hash::Hash + Deserialize<'de>,
223 U: Deserialize<'de>,
224 {
225 type Value = HashMap<T, U>;
226
227 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
228 formatter.write_str("a sequence of tuples")
229 }
230
231 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
232 where
233 A: de::SeqAccess<'de>,
234 {
235 let mut map = HashMap::with_capacity(seq.size_hint().unwrap_or(0));
236 while let Some((k, v)) = seq.next_element::<(T, U)>()? {
237 map.insert(k, v);
238 }
239 Ok(map)
240 }
241 }
242
243 deserializer.deserialize_seq(TupleSeqVisitor {
244 marker: PhantomData,
245 })
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use openmls_test::openmls_test;
    use rand::{rngs::StdRng, Rng, RngExt, SeedableRng};

    /// Returns `len` bytes drawn from `rng`.
    fn random_vec(rng: &mut impl Rng, len: usize) -> Vec<u8> {
        let mut bytes = vec![0u8; len];
        rng.fill_bytes(&mut bytes);
        bytes
    }

    /// Returns a random secret as long as the ciphersuite's hash output.
    fn dummy_secret(rng: &mut impl Rng, ciphersuite: Ciphersuite) -> Secret {
        let bytes = random_vec(rng, ciphersuite.hash_length());
        Secret::from_slice(&bytes)
    }

    /// Returns a random input of the maximum length accepted by `P`.
    fn dummy_index<P: Prefix>(rng: &mut impl Rng) -> Vec<u8> {
        random_vec(rng, P::MAX_INPUT_LEN)
    }

    /// Seeds a fresh RNG from `rand::rng()` and prints the seed, so a failing
    /// run can be reproduced. Then builds a `Pprf<Prefix16>` from a random
    /// root secret.
    fn seeded_pprf(ciphersuite: Ciphersuite) -> (StdRng, Pprf<Prefix16>) {
        let seed: [u8; 32] = rand::rng().random();
        println!("Seed: {:?}", seed);
        let mut rng = StdRng::from_seed(seed);
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        (rng, Pprf::<Prefix16>::new_for_test(root_secret))
    }

    #[openmls_test]
    fn evaluates_single_path() {
        let provider = &Provider::default();
        let (mut rng, mut pprf) = seeded_pprf(ciphersuite);
        let index = dummy_index::<Prefix16>(&mut rng);

        let leaf = pprf
            .evaluate(provider.crypto(), ciphersuite, &index)
            .expect("evaluating a fresh index should succeed");
        assert_eq!(leaf.as_slice().len(), ciphersuite.hash_length());
    }

    #[openmls_test]
    fn re_evaluation_of_same_index_returns_error() {
        let provider = &Provider::default();
        let crypto = provider.crypto();
        let (mut rng, mut pprf) = seeded_pprf(ciphersuite);
        let index = dummy_index::<Prefix16>(&mut rng);

        let _first = pprf.evaluate(crypto, ciphersuite, &index).unwrap();
        let second = pprf
            .evaluate(crypto, ciphersuite, &index)
            .expect_err("Evaluation on same input should fail");
        assert!(matches!(second, PprfError::PuncturedInput));
    }

    #[openmls_test]
    fn different_indices_produce_different_results() {
        let provider = &Provider::default();
        let crypto = provider.crypto();
        let (mut rng, mut pprf) = seeded_pprf(ciphersuite);
        let first_index = dummy_index::<Prefix16>(&mut rng);
        let second_index = dummy_index::<Prefix16>(&mut rng);

        let first_leaf = pprf.evaluate(crypto, ciphersuite, &first_index).unwrap();
        let second_leaf = pprf.evaluate(crypto, ciphersuite, &second_index).unwrap();
        assert_ne!(first_leaf.as_slice(), second_leaf.as_slice());
    }

    #[openmls_test]
    fn rejects_out_of_bounds_index() {
        let provider = &Provider::default();
        let (mut rng, mut pprf) = seeded_pprf(ciphersuite);
        let too_long = random_vec(&mut rng, Prefix16::MAX_INPUT_LEN + 1);

        let outcome = pprf.evaluate(provider.crypto(), ciphersuite, &too_long);
        assert!(matches!(outcome, Err(PprfError::IndexOutOfBounds)));
    }

    #[openmls_test]
    fn pprf_serialization() {
        let provider = &Provider::default();
        let (mut rng, mut pprf) = seeded_pprf(ciphersuite);
        let index = dummy_index::<Prefix16>(&mut rng);

        assert!(pprf.evaluate(provider.crypto(), ciphersuite, &index).is_ok());

        let serialized = serde_json::to_string(&pprf).unwrap();
        println!("Serialized: {}", serialized);
        let deserialized: Pprf<Prefix16> = serde_json::from_str(&serialized).unwrap();
        assert_eq!(pprf, deserialized);
    }
}