Skip to main content

openmls/schedule/pprf/
mod.rs

1//! # Puncturable Pseudorandom Function (PPRF) Implementation
2//!
3//! This module implements a PPRF using the same binary tree structure as the
4//! secret tree. In contrast to the secret tree, this implementation is generic
5//! over the size of the tree. Additionally, it is designed to be efficient even
6//! for larger sizes.
7
8use std::{borrow::Cow, collections::HashMap, fmt, marker::PhantomData};
9
10use openmls_traits::{
11    crypto::OpenMlsCrypto,
12    types::{Ciphersuite, CryptoError},
13};
14use serde::{
15    de::{self, Visitor},
16    ser::SerializeSeq,
17    Deserialize, Deserializer, Serialize, Serializer,
18};
19use thiserror::Error;
20use zeroize::Zeroize;
21
22use crate::{
23    binary_tree::array_representation::TreeSize, ciphersuite::Secret,
24    tree::secret_tree::derive_child_secrets,
25};
26
27use input::AsIndexBytes;
28use prefix::Prefix;
29
30pub use prefix::Prefix16;
31
32mod input;
33mod prefix;
34
35/// Error evaluating the PPRF at the given input.
36#[derive(Debug, Clone, Error, PartialEq)]
37pub enum PprfError {
38    /// Index out of bounds.
39    #[error("Index out of bounds")]
40    IndexOutOfBounds,
41    /// Evaluating on punctured input.
42    #[error("Evaluating on punctured input")]
43    PuncturedInput,
44    /// Error deriving child node.
45    #[error("Error deriving child node: {0}")]
46    ChildDerivationError(#[from] CryptoError),
47}
48
49/// A Node in the PPRF tree that contains the node's secret.
50///
51/// Implements transparent and custom serde for efficiently storing and reading secret bytes.
52#[derive(Debug, Clone)]
53#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
54struct PprfNode(Secret);
55
56impl Serialize for PprfNode {
57    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
58        serde_bytes::Bytes::new(self.0.as_slice()).serialize(serializer)
59    }
60}
61
62impl<'de> Deserialize<'de> for PprfNode {
63    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
64        // Some serialization formats do not support 0-copy deserialization of bytes, so we need to
65        // support owned bytes as well.
66        let bytes: Cow<[u8]> = serde_bytes::Deserialize::deserialize(deserializer)?;
67        let secret = Secret::from_slice(bytes.as_ref());
68        if let Cow::Owned(mut bytes) = bytes {
69            bytes.zeroize() // Zeroize owned bytes to prevent leaking secrets.
70        }
71        Ok(Self(secret))
72    }
73}
74
75impl PprfNode {
76    /// Derives the left and right child nodes from the current node.
77    fn derive_children(
78        &self,
79        crypto: &impl OpenMlsCrypto,
80        ciphersuite: Ciphersuite,
81    ) -> Result<(Self, Self), CryptoError> {
82        let own_secret = &self.0;
83        let (left_secret, right_secret) = derive_child_secrets(own_secret, crypto, ciphersuite)?;
84        Ok((Self(left_secret), Self(right_secret)))
85    }
86}
87
88/// The PPRF containing the tree of nodes, where each node contains a secret. It
89/// can be evaluated at a given input only once. The struct will grow in size
90/// with each evaluation.
91///
92/// The struct is generic over the prefix, which determines how individual nodes
93/// are indexed. As prefixes are stored alongside each node, small prefixes help
94/// keep the overall tree small.
95#[derive(Debug, Serialize, Deserialize, Clone)]
96#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
97pub(crate) struct Pprf<P: Prefix> {
98    #[serde(
99        serialize_with = "serialize_hashmap",
100        deserialize_with = "deserialize_hashmap"
101    )]
102    nodes: HashMap<P, PprfNode>, // Mapping of prefix and depth to node
103    width: usize,
104}
105
/// Returns the bit of `index` at position `bit_index`.
///
/// Bits are numbered big-endian: bit 0 is the most significant bit of
/// `index[0]`, and bit 8 is the most significant bit of `index[1]`.
///
/// # Panics
///
/// Panics if `bit_index / 8` is not a valid position in `index`.
fn get_bit(index: &[u8], bit_index: usize) -> bool {
    let (byte_pos, offset) = (bit_index / 8, bit_index % 8);
    let mask = 0x80u8 >> offset;
    (index[byte_pos] & mask) != 0
}
112
113impl<P: Prefix> Pprf<P> {
114    /// Create a new PPRF with the given secret and size.
115    pub(super) fn new_with_size(secret: Secret, size: TreeSize) -> Self {
116        let width = size.leaf_count() as usize;
117        Pprf {
118            // The width of the tree in bytes.
119            width,
120            nodes: [(P::new(), PprfNode(secret))].into(),
121        }
122    }
123
124    #[cfg(test)]
125    pub(super) fn new_for_test(secret: Secret) -> Self {
126        let width = secret.as_slice().len();
127        Pprf {
128            // The width of the tree in bytes.
129            width,
130            nodes: [(P::new(), PprfNode(secret))].into(),
131        }
132    }
133
134    /// Evaluates the PPRF at the given input.
135    pub(super) fn evaluate<Input: AsIndexBytes>(
136        &mut self,
137        crypto: &impl OpenMlsCrypto,
138        ciphersuite: Ciphersuite,
139        input: &Input,
140    ) -> Result<Secret, PprfError> {
141        let input = input.as_index_bytes();
142        if input.len() > P::MAX_INPUT_LEN {
143            return Err(PprfError::IndexOutOfBounds);
144        }
145
146        // We interpret the input as a bit string indexing the leaf in our tree.
147        let leaf_index = input;
148
149        let mut prefix = P::new();
150        let mut current_node;
151        let mut depth = 0;
152
153        // Step 1: Find the deepest existing node in the cache
154        loop {
155            if let Some(node) = self.nodes.remove(&prefix) {
156                if depth == P::MAX_DEPTH {
157                    return Ok(node.0);
158                } // already at leaf
159                current_node = node;
160                break;
161            }
162
163            // If we reach the max depth and we didn't find a node, then
164            // the PPRF was already punctured at this index.
165            if depth == P::MAX_DEPTH {
166                return Err(PprfError::PuncturedInput);
167            }
168
169            let bit = get_bit(&leaf_index, depth);
170            prefix.push_bit(bit);
171            depth += 1;
172        }
173
174        // Step 2: Derive and walk the rest of the path
175        for d in depth..P::MAX_DEPTH {
176            let (left, right) = current_node.derive_children(crypto, ciphersuite).unwrap();
177            let bit = get_bit(&leaf_index, d);
178
179            let (next_node, copath_node) = if bit { (right, left) } else { (left, right) };
180
181            let mut copath_prefix = prefix.clone();
182            copath_prefix.push_bit(!bit);
183            let node_at_copath_prefix = self.nodes.insert(copath_prefix.clone(), copath_node);
184            debug_assert!(node_at_copath_prefix.is_none());
185
186            current_node = next_node;
187            prefix.push_bit(bit);
188        }
189
190        Ok(current_node.0)
191    }
192}
193
194fn serialize_hashmap<T, U, S>(map: &HashMap<T, U>, serializer: S) -> Result<S::Ok, S::Error>
195where
196    T: Serialize,
197    U: Serialize,
198    S: Serializer,
199{
200    let mut seq = serializer.serialize_seq(Some(map.len()))?;
201    for (k, v) in map {
202        seq.serialize_element(&(k, v))?;
203    }
204    seq.end()
205}
206
207fn deserialize_hashmap<'de, T, U, D>(deserializer: D) -> Result<HashMap<T, U>, D::Error>
208where
209    T: Eq + std::hash::Hash + Deserialize<'de>,
210    U: Deserialize<'de>,
211    D: Deserializer<'de>,
212{
213    struct TupleSeqVisitor<T, U> {
214        marker: PhantomData<(T, U)>,
215    }
216
217    impl<'de, T, U> Visitor<'de> for TupleSeqVisitor<T, U>
218    where
219        T: Eq + std::hash::Hash + Deserialize<'de>,
220        U: Deserialize<'de>,
221    {
222        type Value = HashMap<T, U>;
223
224        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
225            formatter.write_str("a sequence of tuples")
226        }
227
228        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
229        where
230            A: de::SeqAccess<'de>,
231        {
232            let mut map = HashMap::with_capacity(seq.size_hint().unwrap_or(0));
233            while let Some((k, v)) = seq.next_element::<(T, U)>()? {
234                map.insert(k, v);
235            }
236            Ok(map)
237        }
238    }
239
240    deserializer.deserialize_seq(TupleSeqVisitor {
241        marker: PhantomData,
242    })
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use openmls_test::openmls_test;
    use rand::{
        rngs::{OsRng, StdRng},
        Rng, SeedableRng, TryRngCore,
    };

    /// Returns `len` random bytes drawn from `rng`.
    fn random_vec(rng: &mut impl Rng, len: usize) -> Vec<u8> {
        let mut bytes = vec![0u8; len];
        rng.fill_bytes(&mut bytes);
        bytes
    }

    /// Returns a random secret with the ciphersuite's hash length.
    fn dummy_secret(rng: &mut impl Rng, ciphersuite: Ciphersuite) -> Secret {
        Secret::from_slice(&random_vec(rng, ciphersuite.hash_length()))
    }

    /// Returns a random index of exactly `P::MAX_INPUT_LEN` bytes.
    fn dummy_index<P: Prefix>(rng: &mut impl Rng) -> Vec<u8> {
        random_vec(rng, P::MAX_INPUT_LEN)
    }

    /// Returns an RNG seeded from the OS. The seed is printed so that a
    /// failing run can be reproduced.
    fn seeded_rng() -> StdRng {
        let seed: [u8; 32] = OsRng.unwrap_mut().random();
        println!("Seed: {:?}", seed);
        StdRng::from_seed(seed)
    }

    #[openmls_test]
    fn evaluates_single_path() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = dummy_index::<Prefix16>(&mut rng);
        let crypto = provider.crypto();

        let result = pprf
            .evaluate(crypto, ciphersuite, &index)
            .expect("first evaluation should succeed");
        // Tree secrets are expanded to KDF.Nh (the ciphersuite's hash length).
        // A fixed 32 would be wrong for ciphersuites with longer hashes.
        assert_eq!(result.as_slice().len(), ciphersuite.hash_length());
    }

    #[openmls_test]
    fn re_evaluation_of_same_index_returns_error() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = dummy_index::<Prefix16>(&mut rng);
        let crypto = provider.crypto();

        let _first = pprf.evaluate(crypto, ciphersuite, &index).unwrap();
        let second = pprf
            .evaluate(crypto, ciphersuite, &index)
            .expect_err("Evaluation on same input should fail");

        assert_eq!(second, PprfError::PuncturedInput);
    }

    #[openmls_test]
    fn different_indices_produce_different_results() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index1 = dummy_index::<Prefix16>(&mut rng);
        // Flip the first bit (the one consumed at the root) so the two indices
        // are guaranteed to select different leaves. Two independently drawn
        // random indices could collide and make this test flaky.
        let mut index2 = index1.clone();
        index2[0] ^= 0x80;
        let crypto = provider.crypto();

        let leaf1 = pprf.evaluate(crypto, ciphersuite, &index1).unwrap();
        let leaf2 = pprf.evaluate(crypto, ciphersuite, &index2).unwrap();

        assert_ne!(leaf1.as_slice(), leaf2.as_slice());
    }

    #[openmls_test]
    fn rejects_out_of_bounds_index() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = random_vec(&mut rng, Prefix16::MAX_INPUT_LEN + 1); // Out of bounds

        let crypto = provider.crypto();

        let result = pprf.evaluate(crypto, ciphersuite, &index);
        assert!(matches!(result, Err(PprfError::IndexOutOfBounds)));
    }

    #[openmls_test]
    fn pprf_serialization() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = random_vec(&mut rng, Prefix16::MAX_INPUT_LEN);

        let crypto = provider.crypto();

        let result = pprf.evaluate(crypto, ciphersuite, &index);
        assert!(result.is_ok());

        let serialized = serde_json::to_string(&pprf).unwrap();
        println!("Serialized: {}", serialized);
        let deserialized: Pprf<Prefix16> = serde_json::from_str(&serialized).unwrap();

        assert_eq!(pprf, deserialized);
    }
}