1use std::collections::HashMap;
9
10use openmls_traits::{
11 crypto::OpenMlsCrypto,
12 types::{Ciphersuite, CryptoError},
13};
14use serde::{Deserialize, Deserializer, Serialize, Serializer};
15use thiserror::Error;
16use zeroize::ZeroizeOnDrop;
17
18use crate::{
19 binary_tree::array_representation::TreeSize, ciphersuite::Secret,
20 tree::secret_tree::derive_child_secrets,
21};
22
23use input::AsIndexBytes;
24use prefix::Prefix;
25
26pub use prefix::Prefix16;
27
28mod input;
29mod prefix;
30
31#[derive(Debug, Clone, Error, PartialEq)]
33pub enum PprfError {
34 #[error("Index out of bounds")]
36 IndexOutOfBounds,
37 #[error("Evaluating on punctured input")]
39 PuncturedInput,
40 #[error("Error deriving child node: {0}")]
42 ChildDerivationError(#[from] CryptoError),
43}
44
45#[derive(Debug, Serialize, Deserialize, Clone, ZeroizeOnDrop)]
47#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
48#[serde(transparent)]
49struct PprfNode(#[serde(with = "serde_bytes")] Vec<u8>);
50
51impl From<Secret> for PprfNode {
52 fn from(secret: Secret) -> Self {
53 Self(secret.as_slice().to_vec())
54 }
55}
56
57impl From<PprfNode> for Secret {
58 fn from(node: PprfNode) -> Self {
59 Secret::from_slice(&node.0)
60 }
61}
62
63impl PprfNode {
64 fn derive_children(
66 &self,
67 crypto: &impl OpenMlsCrypto,
68 ciphersuite: Ciphersuite,
69 ) -> Result<(Self, Self), CryptoError> {
70 let own_secret = Secret::from_slice(&self.0);
71 let (left_secret, right_secret) = derive_child_secrets(&own_secret, crypto, ciphersuite)?;
72 Ok((left_secret.into(), right_secret.into()))
73 }
74}
75
76#[derive(Debug, Serialize, Deserialize, Clone)]
84#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
85pub(crate) struct Pprf<P: Prefix> {
86 #[serde(
87 serialize_with = "serialize_hashmap",
88 deserialize_with = "deserialize_hashmap"
89 )]
90 nodes: HashMap<P, PprfNode>, width: usize,
92}
93
/// Returns bit number `bit_index` of `index`, where bit 0 is the most
/// significant bit of the first byte.
///
/// # Panics
///
/// Panics if `bit_index / 8` is not a valid index into `index`.
fn get_bit(index: &[u8], bit_index: usize) -> bool {
    let (byte_pos, offset_in_byte) = (bit_index / 8, bit_index % 8);
    let mask = 0x80u8 >> offset_in_byte;
    (index[byte_pos] & mask) != 0
}
100
101impl<P: Prefix> Pprf<P> {
102 pub(super) fn new_with_size(secret: Secret, size: TreeSize) -> Self {
104 let width = size.leaf_count() as usize;
105 Pprf {
106 width,
108 nodes: [(P::new(), PprfNode(secret.as_slice().to_vec()))].into(),
109 }
110 }
111
112 #[cfg(test)]
113 pub(super) fn new_for_test(secret: Secret) -> Self {
114 let width = secret.as_slice().len();
115 Pprf {
116 width,
118 nodes: [(P::new(), secret.into())].into(),
119 }
120 }
121
122 pub(super) fn evaluate<Input: AsIndexBytes>(
124 &mut self,
125 crypto: &impl OpenMlsCrypto,
126 ciphersuite: Ciphersuite,
127 input: &Input,
128 ) -> Result<Secret, PprfError> {
129 let input = input.as_index_bytes();
130 if input.len() > P::MAX_INPUT_LEN {
131 return Err(PprfError::IndexOutOfBounds);
132 }
133
134 let leaf_index = input;
136
137 let mut prefix = P::new();
138 let mut current_node;
139 let mut depth = 0;
140
141 loop {
143 if let Some(node) = self.nodes.remove(&prefix) {
144 if depth == P::MAX_DEPTH {
145 return Ok(node.into());
146 } current_node = node;
148 break;
149 }
150
151 if depth == P::MAX_DEPTH {
154 return Err(PprfError::PuncturedInput);
155 }
156
157 let bit = get_bit(&leaf_index, depth);
158 prefix.push_bit(bit);
159 depth += 1;
160 }
161
162 for d in depth..P::MAX_DEPTH {
164 let (left, right) = current_node.derive_children(crypto, ciphersuite).unwrap();
165 let bit = get_bit(&leaf_index, d);
166
167 let (next_node, copath_node) = if bit { (right, left) } else { (left, right) };
168
169 let mut copath_prefix = prefix.clone();
170 copath_prefix.push_bit(!bit);
171 let node_at_copath_prefix = self.nodes.insert(copath_prefix.clone(), copath_node);
172 debug_assert!(node_at_copath_prefix.is_none());
173
174 current_node = next_node;
175 prefix.push_bit(bit);
176 }
177
178 Ok(current_node.into())
179 }
180}
181
182fn serialize_hashmap<'a, T, U, V, S>(v: &'a V, serializer: S) -> Result<S::Ok, S::Error>
183where
184 T: Serialize,
185 U: Serialize,
186 &'a V: IntoIterator<Item = (T, U)> + 'a,
187 S: Serializer,
188{
189 let vec = v.into_iter().collect::<Vec<_>>();
190 vec.serialize(serializer)
191}
192
193fn deserialize_hashmap<'de, T, U, D>(deserializer: D) -> Result<HashMap<T, U>, D::Error>
194where
195 T: Eq + std::hash::Hash + Deserialize<'de>,
196 U: Deserialize<'de>,
197 D: Deserializer<'de>,
198{
199 Ok(Vec::<(T, U)>::deserialize(deserializer)?
200 .into_iter()
201 .collect::<HashMap<T, U>>())
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use openmls_test::openmls_test;
    use rand::{
        rngs::{OsRng, StdRng},
        Rng, SeedableRng,
    };

    /// Builds a deterministic RNG from a fresh random seed, printing the seed
    /// so that a failing run can be reproduced.
    fn seeded_rng() -> StdRng {
        let seed: [u8; 32] = OsRng.gen();
        println!("Seed: {:?}", seed);
        StdRng::from_seed(seed)
    }

    fn random_vec(rng: &mut impl Rng, len: usize) -> Vec<u8> {
        let mut bytes = vec![0u8; len];
        rng.fill_bytes(&mut bytes);
        bytes
    }

    fn dummy_secret(rng: &mut impl Rng, ciphersuite: Ciphersuite) -> Secret {
        Secret::from_slice(&random_vec(rng, ciphersuite.hash_length()))
    }

    fn dummy_index<P: Prefix>(rng: &mut impl Rng) -> Vec<u8> {
        random_vec(rng, P::MAX_INPUT_LEN)
    }

    #[openmls_test]
    fn evaluates_single_path() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = dummy_index::<Prefix16>(&mut rng);
        let crypto = provider.crypto();

        let result = pprf
            .evaluate(crypto, ciphersuite, &index)
            .expect("evaluation on a fresh index should succeed");
        // Derived secrets are hash-length sized, which depends on the ciphersuite.
        assert_eq!(result.as_slice().len(), ciphersuite.hash_length());
    }

    #[openmls_test]
    fn re_evaluation_of_same_index_returns_error() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = dummy_index::<Prefix16>(&mut rng);
        let crypto = provider.crypto();

        let _first = pprf.evaluate(crypto, ciphersuite, &index).unwrap();
        let second = pprf
            .evaluate(crypto, ciphersuite, &index)
            .expect_err("Evaluation on same input should fail");

        assert!(matches!(second, PprfError::PuncturedInput));
    }

    #[openmls_test]
    fn different_indices_produce_different_results() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index1 = dummy_index::<Prefix16>(&mut rng);
        // Two random indices can collide (small but non-zero probability),
        // which would make the second evaluation fail as punctured; redraw
        // until they differ so the test is not flaky.
        let index2 = loop {
            let candidate = dummy_index::<Prefix16>(&mut rng);
            if candidate != index1 {
                break candidate;
            }
        };
        let crypto = provider.crypto();

        let leaf1 = pprf.evaluate(crypto, ciphersuite, &index1).unwrap();
        let leaf2 = pprf.evaluate(crypto, ciphersuite, &index2).unwrap();

        assert_ne!(leaf1.as_slice(), leaf2.as_slice());
    }

    #[openmls_test]
    fn rejects_out_of_bounds_index() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = random_vec(&mut rng, Prefix16::MAX_INPUT_LEN + 1);
        let crypto = provider.crypto();

        let result = pprf.evaluate(crypto, ciphersuite, &index);
        assert!(matches!(result, Err(PprfError::IndexOutOfBounds)));
    }
}