// openmls/schedule/pprf/mod.rs

//! # Puncturable Pseudorandom Function (PPRF) Implementation
//!
//! This module implements a PPRF on top of the same binary tree structure as
//! the secret tree. Unlike the secret tree, this implementation is generic over
//! the size of the tree, and it is designed to remain efficient even for large
//! trees: evaluating an input derives only the path to its leaf, caches the
//! sibling nodes along that path, and discards the path itself, which
//! punctures the PPRF at that input.
7
8use std::collections::HashMap;
9
10use openmls_traits::{
11    crypto::OpenMlsCrypto,
12    types::{Ciphersuite, CryptoError},
13};
14use serde::{Deserialize, Deserializer, Serialize, Serializer};
15use thiserror::Error;
16use zeroize::ZeroizeOnDrop;
17
18use crate::{
19    binary_tree::array_representation::TreeSize, ciphersuite::Secret,
20    tree::secret_tree::derive_child_secrets,
21};
22
23use input::AsIndexBytes;
24use prefix::Prefix;
25
26pub use prefix::Prefix16;
27
28mod input;
29mod prefix;
30
31/// Error evaluating the PPRF at the given input.
32#[derive(Debug, Clone, Error, PartialEq)]
33pub enum PprfError {
34    /// Index out of bounds.
35    #[error("Index out of bounds")]
36    IndexOutOfBounds,
37    /// Evaluating on punctured input.
38    #[error("Evaluating on punctured input")]
39    PuncturedInput,
40    /// Error deriving child node.
41    #[error("Error deriving child node: {0}")]
42    ChildDerivationError(#[from] CryptoError),
43}
44
45/// A Node in the PPRF tree that contains the node's secret.
46#[derive(Debug, Serialize, Deserialize, Clone, ZeroizeOnDrop)]
47#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
48#[serde(transparent)]
49struct PprfNode(#[serde(with = "serde_bytes")] Vec<u8>);
50
51impl From<Secret> for PprfNode {
52    fn from(secret: Secret) -> Self {
53        Self(secret.as_slice().to_vec())
54    }
55}
56
57impl From<PprfNode> for Secret {
58    fn from(node: PprfNode) -> Self {
59        Secret::from_slice(&node.0)
60    }
61}
62
63impl PprfNode {
64    /// Derives the left and right child nodes from the current node.
65    fn derive_children(
66        &self,
67        crypto: &impl OpenMlsCrypto,
68        ciphersuite: Ciphersuite,
69    ) -> Result<(Self, Self), CryptoError> {
70        let own_secret = Secret::from_slice(&self.0);
71        let (left_secret, right_secret) = derive_child_secrets(&own_secret, crypto, ciphersuite)?;
72        Ok((left_secret.into(), right_secret.into()))
73    }
74}
75
76/// The PPRF containing the tree of nodes, where each node contains a secret. It
77/// can be evaluated at a given input only once. The struct will grow in size
78/// with each evaluation.
79///
80/// The struct is generic over the prefix, which determines how individual nodes
81/// are indexed. As prefixes are stored alongside each node, small prefixes help
82/// keep the overall tree small.
83#[derive(Debug, Serialize, Deserialize, Clone)]
84#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
85pub(crate) struct Pprf<P: Prefix> {
86    #[serde(
87        serialize_with = "serialize_hashmap",
88        deserialize_with = "deserialize_hashmap"
89    )]
90    nodes: HashMap<P, PprfNode>, // Mapping of prefix and depth to node
91    width: usize,
92}
93
/// Returns bit number `bit_index` of `index`, counting bits big-endian:
/// bit 0 is the most significant bit of `index[0]`, bit 8 the most
/// significant bit of `index[1]`, and so on.
///
/// # Panics
///
/// Panics if `bit_index / 8` is not a valid index into `index`.
fn get_bit(index: &[u8], bit_index: usize) -> bool {
    let (byte_pos, bit_pos) = (bit_index / 8, bit_index % 8);
    let mask = 0x80u8 >> bit_pos;
    index[byte_pos] & mask != 0
}
100
101impl<P: Prefix> Pprf<P> {
102    /// Create a new PPRF with the given secret and size.
103    pub(super) fn new_with_size(secret: Secret, size: TreeSize) -> Self {
104        let width = size.leaf_count() as usize;
105        Pprf {
106            // The width of the tree in bytes.
107            width,
108            nodes: [(P::new(), PprfNode(secret.as_slice().to_vec()))].into(),
109        }
110    }
111
112    #[cfg(test)]
113    pub(super) fn new_for_test(secret: Secret) -> Self {
114        let width = secret.as_slice().len();
115        Pprf {
116            // The width of the tree in bytes.
117            width,
118            nodes: [(P::new(), secret.into())].into(),
119        }
120    }
121
122    /// Evaluates the PPRF at the given input.
123    pub(super) fn evaluate<Input: AsIndexBytes>(
124        &mut self,
125        crypto: &impl OpenMlsCrypto,
126        ciphersuite: Ciphersuite,
127        input: &Input,
128    ) -> Result<Secret, PprfError> {
129        let input = input.as_index_bytes();
130        if input.len() > P::MAX_INPUT_LEN {
131            return Err(PprfError::IndexOutOfBounds);
132        }
133
134        // We interpret the input as a bit string indexing the leaf in our tree.
135        let leaf_index = input;
136
137        let mut prefix = P::new();
138        let mut current_node;
139        let mut depth = 0;
140
141        // Step 1: Find the deepest existing node in the cache
142        loop {
143            if let Some(node) = self.nodes.remove(&prefix) {
144                if depth == P::MAX_DEPTH {
145                    return Ok(node.into());
146                } // already at leaf
147                current_node = node;
148                break;
149            }
150
151            // If we reach the max depth and we didn't find a node, then
152            // the PPRF was already punctured at this index.
153            if depth == P::MAX_DEPTH {
154                return Err(PprfError::PuncturedInput);
155            }
156
157            let bit = get_bit(&leaf_index, depth);
158            prefix.push_bit(bit);
159            depth += 1;
160        }
161
162        // Step 2: Derive and walk the rest of the path
163        for d in depth..P::MAX_DEPTH {
164            let (left, right) = current_node.derive_children(crypto, ciphersuite).unwrap();
165            let bit = get_bit(&leaf_index, d);
166
167            let (next_node, copath_node) = if bit { (right, left) } else { (left, right) };
168
169            let mut copath_prefix = prefix.clone();
170            copath_prefix.push_bit(!bit);
171            let node_at_copath_prefix = self.nodes.insert(copath_prefix.clone(), copath_node);
172            debug_assert!(node_at_copath_prefix.is_none());
173
174            current_node = next_node;
175            prefix.push_bit(bit);
176        }
177
178        Ok(current_node.into())
179    }
180}
181
182fn serialize_hashmap<'a, T, U, V, S>(v: &'a V, serializer: S) -> Result<S::Ok, S::Error>
183where
184    T: Serialize,
185    U: Serialize,
186    &'a V: IntoIterator<Item = (T, U)> + 'a,
187    S: Serializer,
188{
189    let vec = v.into_iter().collect::<Vec<_>>();
190    vec.serialize(serializer)
191}
192
193fn deserialize_hashmap<'de, T, U, D>(deserializer: D) -> Result<HashMap<T, U>, D::Error>
194where
195    T: Eq + std::hash::Hash + Deserialize<'de>,
196    U: Deserialize<'de>,
197    D: Deserializer<'de>,
198{
199    Ok(Vec::<(T, U)>::deserialize(deserializer)?
200        .into_iter()
201        .collect::<HashMap<T, U>>())
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use openmls_test::openmls_test;
    use rand::{
        rngs::{OsRng, StdRng},
        Rng, SeedableRng,
    };

    fn random_vec(rng: &mut impl Rng, len: usize) -> Vec<u8> {
        let mut bytes = vec![0u8; len];
        rng.fill_bytes(&mut bytes);
        bytes
    }

    fn dummy_secret(rng: &mut impl Rng, ciphersuite: Ciphersuite) -> Secret {
        Secret::from_slice(&random_vec(rng, ciphersuite.hash_length()))
    }

    fn dummy_index<P: Prefix>(rng: &mut impl Rng) -> Vec<u8> {
        random_vec(rng, P::MAX_INPUT_LEN)
    }

    /// Creates an RNG from a fresh random seed and prints the seed so that a
    /// failing run can be reproduced.
    fn seeded_rng() -> StdRng {
        let seed: [u8; 32] = OsRng.gen();
        println!("Seed: {:?}", seed);
        StdRng::from_seed(seed)
    }

    /// Returns true if both indices select the same leaf of a `P` tree, i.e.
    /// they agree on the first `P::MAX_DEPTH` bits.
    fn same_leaf<P: Prefix>(a: &[u8], b: &[u8]) -> bool {
        (0..P::MAX_DEPTH).all(|i| get_bit(a, i) == get_bit(b, i))
    }

    #[openmls_test]
    fn evaluates_single_path() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = dummy_index::<Prefix16>(&mut rng);
        let crypto = provider.crypto();

        let leaf = pprf
            .evaluate(crypto, ciphersuite, &index)
            .expect("Evaluation on a fresh PPRF should succeed");
        // Derived secrets are as long as the ciphersuite's hash output; a
        // hard-coded 32 would fail for SHA-384/SHA-512 ciphersuites.
        assert_eq!(leaf.as_slice().len(), ciphersuite.hash_length());
    }

    #[openmls_test]
    fn re_evaluation_of_same_index_returns_error() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = dummy_index::<Prefix16>(&mut rng);
        let crypto = provider.crypto();

        let _first = pprf.evaluate(crypto, ciphersuite, &index).unwrap();
        let second = pprf
            .evaluate(crypto, ciphersuite, &index)
            .expect_err("Evaluation on same input should fail");

        assert!(matches!(second, PprfError::PuncturedInput));
    }

    #[openmls_test]
    fn different_indices_produce_different_results() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index1 = dummy_index::<Prefix16>(&mut rng);
        // Redraw until the second index selects a different leaf. Otherwise
        // the second evaluation would legitimately fail with `PuncturedInput`,
        // making the test flaky.
        let index2 = loop {
            let candidate = dummy_index::<Prefix16>(&mut rng);
            if !same_leaf::<Prefix16>(&index1, &candidate) {
                break candidate;
            }
        };
        let crypto = provider.crypto();

        let leaf1 = pprf.evaluate(crypto, ciphersuite, &index1).unwrap();
        let leaf2 = pprf.evaluate(crypto, ciphersuite, &index2).unwrap();

        assert_ne!(leaf1.as_slice(), leaf2.as_slice());
    }

    #[openmls_test]
    fn rejects_out_of_bounds_index() {
        let provider = &Provider::default();
        let mut rng = seeded_rng();
        let root_secret = dummy_secret(&mut rng, ciphersuite);
        let mut pprf = Pprf::<Prefix16>::new_for_test(root_secret);
        let index = random_vec(&mut rng, Prefix16::MAX_INPUT_LEN + 1); // Out of bounds

        let crypto = provider.crypto();

        let result = pprf.evaluate(crypto, ciphersuite, &index);
        assert!(matches!(result, Err(PprfError::IndexOutOfBounds)));
    }
}