Skip to main content

openmls/schedule/pprf/
mod.rs

1//! # Puncturable Pseudorandom Function (PPRF) Implementation
2//!
3//! This module implements a PPRF using the same binary tree structure as the
4//! secret tree. In contrast to the secret tree, this implementation is generic
5//! over the size of the tree. Additionally, it is designed to be efficient even
6//! for larger sizes.
7
8use std::{borrow::Cow, collections::HashMap, fmt, marker::PhantomData};
9
10use openmls_traits::{
11    crypto::OpenMlsCrypto,
12    types::{Ciphersuite, CryptoError},
13};
14use serde::{
15    de::{self, Visitor},
16    ser::SerializeSeq,
17    Deserialize, Deserializer, Serialize, Serializer,
18};
19use thiserror::Error;
20use zeroize::Zeroize;
21
22use crate::{
23    binary_tree::array_representation::TreeSize, ciphersuite::Secret,
24    tree::secret_tree::derive_child_secrets,
25};
26
27use input::AsIndexBytes;
28use prefix::Prefix;
29
30pub use prefix::{Prefix16, Prefix256};
31
32mod input;
33mod prefix;
34
35/// Error evaluating the PPRF at the given input.
36#[derive(Debug, Clone, Error, PartialEq)]
37pub enum PprfError {
38    /// Index out of bounds.
39    #[error("Index out of bounds")]
40    IndexOutOfBounds,
41    /// Evaluating on punctured input.
42    #[error("Evaluating on punctured input")]
43    PuncturedInput,
44    /// Error deriving child node.
45    #[error("Error deriving child node: {0}")]
46    ChildDerivationError(#[from] CryptoError),
47    /// Prefix exceeded its maximum depth.
48    #[error("Prefix length exceeds maximum depth")]
49    PrefixMaxDepthExceeded,
50}
51
52/// A Node in the PPRF tree that contains the node's secret.
53///
54/// Implements transparent and custom serde for efficiently storing and reading secret bytes.
55#[derive(Debug, Clone)]
56#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
57struct PprfNode(Secret);
58
59impl Serialize for PprfNode {
60    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
61        serde_bytes::Bytes::new(self.0.as_slice()).serialize(serializer)
62    }
63}
64
65impl<'de> Deserialize<'de> for PprfNode {
66    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
67        // Some serialization formats do not support 0-copy deserialization of bytes, so we need to
68        // support owned bytes as well.
69        let bytes: Cow<[u8]> = serde_bytes::Deserialize::deserialize(deserializer)?;
70        let secret = Secret::from_slice(bytes.as_ref());
71        if let Cow::Owned(mut bytes) = bytes {
72            bytes.zeroize() // Zeroize owned bytes to prevent leaking secrets.
73        }
74        Ok(Self(secret))
75    }
76}
77
78impl PprfNode {
79    /// Derives the left and right child nodes from the current node.
80    fn derive_children(
81        &self,
82        crypto: &impl OpenMlsCrypto,
83        ciphersuite: Ciphersuite,
84    ) -> Result<(Self, Self), CryptoError> {
85        let own_secret = &self.0;
86        let (left_secret, right_secret) = derive_child_secrets(own_secret, crypto, ciphersuite)?;
87        Ok((Self(left_secret), Self(right_secret)))
88    }
89}
90
91/// The PPRF containing the tree of nodes, where each node contains a secret. It
92/// can be evaluated at a given input only once. The struct will grow in size
93/// with each evaluation.
94///
95/// The struct is generic over the prefix, which determines how individual nodes
96/// are indexed. As prefixes are stored alongside each node, small prefixes help
97/// keep the overall tree small.
98#[derive(Debug, Serialize, Deserialize, Clone)]
99#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
100pub(crate) struct Pprf<P: Prefix> {
101    #[serde(
102        serialize_with = "serialize_hashmap",
103        deserialize_with = "deserialize_hashmap"
104    )]
105    nodes: HashMap<P, PprfNode>, // Mapping of prefix and depth to node
106    width: usize,
107}
108
/// Returns whether bit number `bit_index` of `index` is set.
///
/// Bits are numbered most-significant first (big-endian): bit 0 is the top
/// bit of `index[0]`, bit 8 is the top bit of `index[1]`, and so on.
///
/// # Panics
///
/// Panics if `bit_index / 8 >= index.len()`.
fn get_bit(index: &[u8], bit_index: usize) -> bool {
    let (byte_pos, bit_in_byte) = (bit_index / 8, bit_index % 8);
    let mask = 0x80u8 >> bit_in_byte;
    (index[byte_pos] & mask) != 0
}
115
116impl<P: Prefix> Pprf<P> {
117    /// Create a new PPRF with the given secret and size.
118    pub(crate) fn new_with_size(secret: Secret, size: TreeSize) -> Self {
119        let width = size.leaf_count() as usize;
120        Pprf {
121            // The width of the tree in bytes.
122            width,
123            nodes: [(P::new(), PprfNode(secret))].into(),
124        }
125    }
126
127    #[cfg(test)]
128    pub(super) fn new_for_test(secret: Secret) -> Self {
129        let width = secret.as_slice().len();
130        Pprf {
131            // The width of the tree in bytes.
132            width,
133            nodes: [(P::new(), PprfNode(secret))].into(),
134        }
135    }
136
137    /// Evaluates the PPRF at the given input.
138    pub(crate) fn evaluate<Input: AsIndexBytes>(
139        &mut self,
140        crypto: &impl OpenMlsCrypto,
141        ciphersuite: Ciphersuite,
142        input: &Input,
143    ) -> Result<Secret, PprfError> {
144        let input = input.as_index_bytes();
145        if input.len() > P::MAX_INPUT_LEN {
146            return Err(PprfError::IndexOutOfBounds);
147        }
148
149        // We interpret the input as a bit string indexing the leaf in our tree.
150        let leaf_index = input;
151
152        let mut prefix = P::new();
153        let mut current_node;
154        let mut depth = 0;
155
156        // Step 1: Find the deepest existing node in the cache
157        loop {
158            if let Some(node) = self.nodes.remove(&prefix) {
159                if depth == P::MAX_DEPTH {
160                    return Ok(node.0);
161                } // already at leaf
162                current_node = node;
163                break;
164            }
165
166            // If we reach the max depth and we didn't find a node, then
167            // the PPRF was already punctured at this index.
168            if depth == P::MAX_DEPTH {
169                return Err(PprfError::PuncturedInput);
170            }
171
172            let bit = get_bit(&leaf_index, depth);
173            prefix.push_bit(bit)?;
174            depth += 1;
175        }
176
177        // Step 2: Derive and walk the rest of the path
178        for d in depth..P::MAX_DEPTH {
179            let (left, right) = current_node.derive_children(crypto, ciphersuite)?;
180            let bit = get_bit(&leaf_index, d);
181
182            let (next_node, copath_node) = if bit { (right, left) } else { (left, right) };
183
184            let mut copath_prefix = prefix.clone();
185            copath_prefix.push_bit(!bit)?;
186            let node_at_copath_prefix = self.nodes.insert(copath_prefix.clone(), copath_node);
187            debug_assert!(node_at_copath_prefix.is_none());
188
189            current_node = next_node;
190            prefix.push_bit(bit)?;
191        }
192
193        Ok(current_node.0)
194    }
195}
196
197fn serialize_hashmap<T, U, S>(map: &HashMap<T, U>, serializer: S) -> Result<S::Ok, S::Error>
198where
199    T: Serialize,
200    U: Serialize,
201    S: Serializer,
202{
203    let mut seq = serializer.serialize_seq(Some(map.len()))?;
204    for (k, v) in map {
205        seq.serialize_element(&(k, v))?;
206    }
207    seq.end()
208}
209
210fn deserialize_hashmap<'de, T, U, D>(deserializer: D) -> Result<HashMap<T, U>, D::Error>
211where
212    T: Eq + std::hash::Hash + Deserialize<'de>,
213    U: Deserialize<'de>,
214    D: Deserializer<'de>,
215{
216    struct TupleSeqVisitor<T, U> {
217        marker: PhantomData<(T, U)>,
218    }
219
220    impl<'de, T, U> Visitor<'de> for TupleSeqVisitor<T, U>
221    where
222        T: Eq + std::hash::Hash + Deserialize<'de>,
223        U: Deserialize<'de>,
224    {
225        type Value = HashMap<T, U>;
226
227        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
228            formatter.write_str("a sequence of tuples")
229        }
230
231        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
232        where
233            A: de::SeqAccess<'de>,
234        {
235            let mut map = HashMap::with_capacity(seq.size_hint().unwrap_or(0));
236            while let Some((k, v)) = seq.next_element::<(T, U)>()? {
237                map.insert(k, v);
238            }
239            Ok(map)
240        }
241    }
242
243    deserializer.deserialize_seq(TupleSeqVisitor {
244        marker: PhantomData,
245    })
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use openmls_test::openmls_test;
    use rand::{rngs::StdRng, Rng, RngExt, SeedableRng};

    /// Returns a freshly seeded RNG, printing the seed so that a failing run
    /// can be reproduced.
    fn seeded_rng() -> StdRng {
        let seed: [u8; 32] = rand::rng().random();
        println!("Seed: {:?}", seed);
        StdRng::from_seed(seed)
    }

    /// Returns `len` random bytes drawn from `rng`.
    fn random_vec(rng: &mut impl Rng, len: usize) -> Vec<u8> {
        let mut bytes = vec![0u8; len];
        rng.fill_bytes(&mut bytes);
        bytes
    }

    /// Returns a random secret of the ciphersuite's hash length.
    fn dummy_secret(rng: &mut impl Rng, ciphersuite: Ciphersuite) -> Secret {
        let bytes = random_vec(rng, ciphersuite.hash_length());
        Secret::from_slice(&bytes)
    }

    /// Returns a random input of the maximum length accepted by `P`.
    fn dummy_index<P: Prefix>(rng: &mut impl Rng) -> Vec<u8> {
        random_vec(rng, P::MAX_INPUT_LEN)
    }

    /// Builds a test PPRF whose root secret is drawn from `rng`.
    fn dummy_pprf(rng: &mut impl Rng, ciphersuite: Ciphersuite) -> Pprf<Prefix16> {
        Pprf::new_for_test(dummy_secret(rng, ciphersuite))
    }

    #[openmls_test]
    fn evaluates_single_path() {
        let provider = &Provider::default();
        let crypto = provider.crypto();
        let mut rng = seeded_rng();
        let mut pprf = dummy_pprf(&mut rng, ciphersuite);
        let index = dummy_index::<Prefix16>(&mut rng);

        let leaf = pprf
            .evaluate(crypto, ciphersuite, &index)
            .expect("evaluating a fresh index should succeed");
        assert_eq!(leaf.as_slice().len(), ciphersuite.hash_length());
    }

    #[openmls_test]
    fn re_evaluation_of_same_index_returns_error() {
        let provider = &Provider::default();
        let crypto = provider.crypto();
        let mut rng = seeded_rng();
        let mut pprf = dummy_pprf(&mut rng, ciphersuite);
        let index = dummy_index::<Prefix16>(&mut rng);

        let _first = pprf.evaluate(crypto, ciphersuite, &index).unwrap();
        let second = pprf
            .evaluate(crypto, ciphersuite, &index)
            .expect_err("Evaluation on same input should fail");

        assert!(matches!(second, PprfError::PuncturedInput));
    }

    #[openmls_test]
    fn different_indices_produce_different_results() {
        let provider = &Provider::default();
        let crypto = provider.crypto();
        let mut rng = seeded_rng();
        let mut pprf = dummy_pprf(&mut rng, ciphersuite);
        let index1 = dummy_index::<Prefix16>(&mut rng);
        // Inputs are short (`Prefix16::MAX_INPUT_LEN` bytes), so two random
        // ones can collide; resample so the second evaluation never hits an
        // already punctured input by chance.
        let mut index2 = dummy_index::<Prefix16>(&mut rng);
        while index2 == index1 {
            index2 = dummy_index::<Prefix16>(&mut rng);
        }

        let leaf1 = pprf.evaluate(crypto, ciphersuite, &index1).unwrap();
        let leaf2 = pprf.evaluate(crypto, ciphersuite, &index2).unwrap();

        assert_ne!(leaf1.as_slice(), leaf2.as_slice());
    }

    #[openmls_test]
    fn rejects_out_of_bounds_index() {
        let provider = &Provider::default();
        let crypto = provider.crypto();
        let mut rng = seeded_rng();
        let mut pprf = dummy_pprf(&mut rng, ciphersuite);
        // One byte longer than the prefix type can address.
        let index = random_vec(&mut rng, Prefix16::MAX_INPUT_LEN + 1);

        let result = pprf.evaluate(crypto, ciphersuite, &index);
        assert!(matches!(result, Err(PprfError::IndexOutOfBounds)));
    }

    #[openmls_test]
    fn pprf_serialization() {
        let provider = &Provider::default();
        let crypto = provider.crypto();
        let mut rng = seeded_rng();
        let mut pprf = dummy_pprf(&mut rng, ciphersuite);
        let index = random_vec(&mut rng, Prefix16::MAX_INPUT_LEN);

        // Evaluate once so the serialized state contains copath nodes.
        let result = pprf.evaluate(crypto, ciphersuite, &index);
        assert!(result.is_ok());

        let serialized = serde_json::to_string(&pprf).unwrap();
        println!("Serialized: {}", serialized);
        let deserialized: Pprf<Prefix16> = serde_json::from_str(&serialized).unwrap();

        assert_eq!(pprf, deserialized);
    }
}