openmls/schedule/
psk.rs

1//! # Preshared keys.
2
3use std::borrow::Borrow;
4
5use openmls_traits::{random::OpenMlsRand, storage::StorageProvider as StorageProviderTrait};
6use serde::{Deserialize, Serialize};
7use tls_codec::{Serialize as TlsSerializeTrait, VLBytes};
8
9use super::*;
10use crate::{
11    group::{GroupEpoch, GroupId},
12    schedule::psk::store::ResumptionPskStore,
13    storage::{OpenMlsProvider, StorageProvider},
14};
15
16/// Resumption PSK usage.
17///
18/// ```c
19/// // draft-ietf-mls-protocol-19
20/// enum {
21///   reserved(0),
22///   application(1),
23///   reinit(2),
24///   branch(3),
25///   (255)
26/// } ResumptionPSKUsage;
27/// ```
28#[derive(
29    Clone,
30    Copy,
31    Debug,
32    PartialEq,
33    Eq,
34    PartialOrd,
35    Ord,
36    Hash,
37    Deserialize,
38    Serialize,
39    TlsDeserialize,
40    TlsDeserializeBytes,
41    TlsSerialize,
42    TlsSize,
43)]
44#[repr(u8)]
45pub enum ResumptionPskUsage {
46    /// Application.
47    Application = 1,
48    /// Resumption PSK used for group reinitialization.
49    ///
50    /// Note: "Resumption PSKs with usage `reinit` MUST NOT be used in other contexts (than reinitialization)."
51    Reinit = 2,
52    /// Resumption PSK used for subgroup branching.
53    ///
54    /// Note: "Resumption PSKs with usage `branch` MUST NOT be used in other contexts (than subgroup branching)."
55    Branch = 3,
56}
57
58/// External PSK.
59#[derive(
60    Debug,
61    PartialEq,
62    Eq,
63    PartialOrd,
64    Ord,
65    Clone,
66    Hash,
67    Deserialize,
68    Serialize,
69    TlsDeserialize,
70    TlsDeserializeBytes,
71    TlsSerialize,
72    TlsSize,
73)]
74pub struct ExternalPsk {
75    psk_id: VLBytes,
76}
77
78impl ExternalPsk {
79    /// Create a new `ExternalPsk` from a PSK ID
80    pub fn new(psk_id: Vec<u8>) -> Self {
81        Self {
82            psk_id: psk_id.into(),
83        }
84    }
85
86    /// Return the PSK ID
87    pub fn psk_id(&self) -> &[u8] {
88        self.psk_id.as_slice()
89    }
90}
91
92/// Contains the secret part of the PSK as well as the
93/// public part that is used as a marker for injection into the key schedule.
94#[derive(Serialize, Deserialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)]
95pub(crate) struct PskBundle {
96    secret: Secret,
97}
98
99/// Resumption PSK.
100#[derive(
101    Clone,
102    Debug,
103    PartialEq,
104    Eq,
105    PartialOrd,
106    Ord,
107    Deserialize,
108    Serialize,
109    TlsDeserialize,
110    TlsDeserializeBytes,
111    TlsSerialize,
112    TlsSize,
113)]
114pub struct ResumptionPsk {
115    pub(crate) usage: ResumptionPskUsage,
116    pub(crate) psk_group_id: GroupId,
117    pub(crate) psk_epoch: GroupEpoch,
118}
119
120impl ResumptionPsk {
121    /// Create a new `ResumptionPsk`
122    pub fn new(usage: ResumptionPskUsage, psk_group_id: GroupId, psk_epoch: GroupEpoch) -> Self {
123        Self {
124            usage,
125            psk_group_id,
126            psk_epoch,
127        }
128    }
129
130    /// Return the usage
131    pub fn usage(&self) -> ResumptionPskUsage {
132        self.usage
133    }
134
135    /// Return the `GroupId`
136    pub fn psk_group_id(&self) -> &GroupId {
137        &self.psk_group_id
138    }
139
140    /// Return the `GroupEpoch`
141    pub fn psk_epoch(&self) -> GroupEpoch {
142        self.psk_epoch
143    }
144}
145
146/// The different PSK types.
147#[derive(
148    Clone,
149    Debug,
150    PartialEq,
151    Eq,
152    PartialOrd,
153    Ord,
154    Deserialize,
155    Serialize,
156    TlsDeserialize,
157    TlsDeserializeBytes,
158    TlsSerialize,
159    TlsSize,
160)]
161#[repr(u8)]
162pub enum Psk {
163    /// An external PSK provided by the application.
164    #[tls_codec(discriminant = 1)]
165    External(ExternalPsk),
166    /// A resumption PSK derived from the MLS key schedule.
167    #[tls_codec(discriminant = 2)]
168    Resumption(ResumptionPsk),
169}
170
/// The kind of a PSK, as carried in the `psktype` field of a `PreSharedKeyID`.
///
/// ```c
/// // draft-ietf-mls-protocol-19
/// enum {
///   reserved(0),
///   external(1),
///   resumption(2),
///   (255)
/// } PSKType;
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum PskType {
    /// External PSK (`external(1)`).
    External = 1,
    /// Resumption PSK (`resumption(2)`).
    Resumption = 2,
}
188
189/// A `PreSharedKeyID` is used to uniquely identify the PSKs that get injected
190/// in the key schedule.
191///
192/// ```c
193/// // draft-ietf-mls-protocol-19
194/// struct {
195///   PSKType psktype;
196///   select (PreSharedKeyID.psktype) {
197///     case external:
198///       opaque psk_id<V>;
199///
200///     case resumption:
201///       ResumptionPSKUsage usage;
202///       opaque psk_group_id<V>;
203///       uint64 psk_epoch;
204///   };
205///   opaque psk_nonce<V>;
206/// } PreSharedKeyID;
207/// ```
208#[derive(
209    Clone,
210    Debug,
211    PartialEq,
212    Eq,
213    PartialOrd,
214    Ord,
215    Deserialize,
216    Serialize,
217    TlsDeserialize,
218    TlsDeserializeBytes,
219    TlsSerialize,
220    TlsSize,
221)]
222pub struct PreSharedKeyId {
223    pub(crate) psk: Psk,
224    pub(crate) psk_nonce: VLBytes,
225}
226
227impl PreSharedKeyId {
228    /// Construct a `PreSharedKeyID` with a random nonce.
229    pub fn new(
230        ciphersuite: Ciphersuite,
231        rand: &impl OpenMlsRand,
232        psk: Psk,
233    ) -> Result<Self, CryptoError> {
234        let psk_nonce = rand
235            .random_vec(ciphersuite.hash_length())
236            .map_err(|_| CryptoError::InsufficientRandomness)?
237            .into();
238
239        Ok(Self { psk, psk_nonce })
240    }
241
242    /// Construct an external `PreSharedKeyID`.
243    pub fn external(psk_id: Vec<u8>, psk_nonce: Vec<u8>) -> Self {
244        let psk = Psk::External(ExternalPsk::new(psk_id));
245
246        Self {
247            psk,
248            psk_nonce: psk_nonce.into(),
249        }
250    }
251
252    /// Construct a resumption `PreSharedKeyID`.
253    pub fn resumption(
254        usage: ResumptionPskUsage,
255        psk_group_id: GroupId,
256        psk_epoch: GroupEpoch,
257        psk_nonce: Vec<u8>,
258    ) -> Self {
259        let psk = Psk::Resumption(ResumptionPsk::new(usage, psk_group_id, psk_epoch));
260
261        Self {
262            psk,
263            psk_nonce: psk_nonce.into(),
264        }
265    }
266
267    /// Return the PSK.
268    pub fn psk(&self) -> &Psk {
269        &self.psk
270    }
271
272    /// Return the PSK nonce.
273    pub fn psk_nonce(&self) -> &[u8] {
274        self.psk_nonce.as_slice()
275    }
276
277    // ----- Key Store -----------------------------------------------------------------------------
278
279    /// Save this `PreSharedKeyId` in the keystore.
280    ///
281    /// Note: The nonce is not saved as it must be unique for each time it's being applied.
282    pub fn store<Provider: OpenMlsProvider>(
283        &self,
284        provider: &Provider,
285        psk: &[u8],
286    ) -> Result<(), PskError> {
287        let psk_bundle = {
288            let secret = Secret::from_slice(psk);
289
290            PskBundle { secret }
291        };
292
293        provider
294            .storage()
295            .write_psk(&self.psk, &psk_bundle)
296            .map_err(|_| PskError::Storage)
297    }
298
299    // ----- Validation ----------------------------------------------------------------------------
300
301    pub(crate) fn validate_in_proposal(self, ciphersuite: Ciphersuite) -> Result<Self, PskError> {
302        // ValSem402
303        match self.psk() {
304            Psk::Resumption(resumption_psk) => {
305                if resumption_psk.usage != ResumptionPskUsage::Application {
306                    return Err(PskError::UsageMismatch {
307                        allowed: vec![ResumptionPskUsage::Application],
308                        got: resumption_psk.usage,
309                    });
310                }
311            }
312            Psk::External(_) => {}
313        };
314
315        // ValSem401
316        {
317            let expected_nonce_length = ciphersuite.hash_length();
318            let got_nonce_length = self.psk_nonce().len();
319
320            if expected_nonce_length != got_nonce_length {
321                return Err(PskError::NonceLengthMismatch {
322                    expected: expected_nonce_length,
323                    got: got_nonce_length,
324                });
325            }
326        }
327
328        Ok(self)
329    }
330}
331
#[cfg(test)]
impl PreSharedKeyId {
    /// Test helper: builds a `PreSharedKeyId` from an arbitrary `Psk` and nonce.
    pub(crate) fn new_with_nonce(psk: Psk, psk_nonce: Vec<u8>) -> Self {
        let psk_nonce = psk_nonce.into();
        Self { psk, psk_nonce }
    }
}
341
342/// `PskLabel` is used in the final concatentation of PSKs before they are
343/// injected in the key schedule.
344///
345/// ```c
346/// // draft-ietf-mls-protocol-19
347/// struct {
348///     PreSharedKeyID id;
349///     uint16 index;
350///     uint16 count;
351/// } PSKLabel;
352/// ```
353#[derive(TlsSerialize, TlsSize)]
354pub(crate) struct PskLabel<'a> {
355    pub(crate) id: &'a PreSharedKeyId,
356    pub(crate) index: u16,
357    pub(crate) count: u16,
358}
359
360impl<'a> PskLabel<'a> {
361    /// Create a new `PskLabel`
362    fn new(id: &'a PreSharedKeyId, index: u16, count: u16) -> Self {
363        Self { id, index, count }
364    }
365}
366
367/// This contains the `psk-secret` calculated from the PSKs contained in a
368/// Commit or a PreSharedKey proposal.
369#[derive(Clone)]
370pub struct PskSecret {
371    secret: Secret,
372}
373
374impl PskSecret {
375    /// Create a new `PskSecret` from PSK IDs and PSKs
376    ///
377    /// ```text
378    /// psk_extracted_[i] = KDF.Extract(0, psk_[i])
379    /// psk_input_[i] = ExpandWithLabel(psk_extracted_[i], "derived psk", PSKLabel, KDF.Nh)
380    ///
381    /// psk_secret_[0] = 0
382    /// psk_secret_[i] = KDF.Extract(psk_input[i-1], psk_secret_[i-1])
383    /// psk_secret     = psk_secret[n]
384    /// ```
385    pub(crate) fn new(
386        crypto: &impl OpenMlsCrypto,
387        ciphersuite: Ciphersuite,
388        psks: Vec<(impl Borrow<PreSharedKeyId>, Secret)>,
389    ) -> Result<Self, PskError> {
390        // Check that we don't have too many PSKs
391        let num_psks = u16::try_from(psks.len()).map_err(|_| PskError::TooManyKeys)?;
392
393        // Following comments are from `draft-ietf-mls-protocol-19`.
394        //
395        // psk_secret_[0] = 0
396        let mut psk_secret = Secret::zero(ciphersuite);
397
398        for (index, (psk_id, psk)) in psks.into_iter().enumerate() {
399            // psk_extracted_[i] = KDF.Extract(0, psk_[i])
400            let psk_extracted = {
401                let zero_secret = Secret::zero(ciphersuite);
402                zero_secret
403                    .hkdf_extract(crypto, ciphersuite, &psk)
404                    .map_err(LibraryError::unexpected_crypto_error)?
405            };
406
407            // psk_input_[i] = ExpandWithLabel( psk_extracted_[i], "derived psk", PSKLabel, KDF.Nh)
408            let psk_input = {
409                let psk_label = PskLabel::new(psk_id.borrow(), index as u16, num_psks)
410                    .tls_serialize_detached()
411                    .map_err(LibraryError::missing_bound_check)?;
412
413                psk_extracted
414                    .kdf_expand_label(
415                        crypto,
416                        ciphersuite,
417                        "derived psk",
418                        &psk_label,
419                        ciphersuite.hash_length(),
420                    )
421                    .map_err(LibraryError::unexpected_crypto_error)?
422            };
423
424            // psk_secret_[i] = KDF.Extract(psk_input_[i-1], psk_secret_[i-1])
425            psk_secret = psk_input
426                .hkdf_extract(crypto, ciphersuite, &psk_secret)
427                .map_err(LibraryError::unexpected_crypto_error)?;
428        }
429
430        Ok(Self { secret: psk_secret })
431    }
432
433    /// Return the inner secret
434    pub(crate) fn secret(&self) -> &Secret {
435        &self.secret
436    }
437
438    #[cfg(any(feature = "test-utils", feature = "crypto-debug", test))]
439    pub(crate) fn as_slice(&self) -> &[u8] {
440        self.secret.as_slice()
441    }
442}
443
#[cfg(any(feature = "test-utils", test))]
impl From<Secret> for PskSecret {
    /// Wraps an already-derived secret; only available in test builds.
    fn from(secret: Secret) -> Self {
        PskSecret { secret }
    }
}
450
451pub(crate) fn load_psks<'p, Storage: StorageProvider>(
452    storage: &Storage,
453    resumption_psk_store: &ResumptionPskStore,
454    psk_ids: &'p [PreSharedKeyId],
455) -> Result<Vec<(&'p PreSharedKeyId, Secret)>, PskError> {
456    let mut psk_bundles = Vec::new();
457
458    for psk_id in psk_ids.iter() {
459        log_crypto!(trace, "PSK store {:?}", resumption_psk_store);
460
461        match &psk_id.psk {
462            Psk::Resumption(resumption) => {
463                if let Some(psk_bundle) = resumption_psk_store.get(resumption.psk_epoch()) {
464                    psk_bundles.push((psk_id, psk_bundle.secret.clone()));
465                } else {
466                    return Err(PskError::KeyNotFound);
467                }
468            }
469            Psk::External(_) => {
470                let psk_bundle: Option<PskBundle> = storage
471                    .psk(psk_id.psk())
472                    .map_err(|_| PskError::KeyNotFound)?;
473                if let Some(psk_bundle) = psk_bundle {
474                    psk_bundles.push((psk_id, psk_bundle.secret));
475                } else {
476                    return Err(PskError::KeyNotFound);
477                }
478            }
479        }
480    }
481
482    Ok(psk_bundles)
483}
484
485/// This module contains a store that can hold a rollover list of resumption PSKs.
486pub mod store {
487    use serde::{Deserialize, Serialize};
488
489    use crate::{group::GroupEpoch, schedule::ResumptionPskSecret};
490
491    /// Resumption PSK store.
492    ///
493    /// This is where the resumption PSKs are kept in a rollover list.
494    #[derive(Debug, Serialize, Deserialize)]
495    #[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
496    pub(crate) struct ResumptionPskStore {
497        max_number_of_secrets: usize,
498        resumption_psk: Vec<(GroupEpoch, ResumptionPskSecret)>,
499        cursor: usize,
500    }
501
502    impl ResumptionPskStore {
503        /// Creates a new store with a given maximum size of `number_of_secrets`.
504        pub(crate) fn new(max_number_of_secrets: usize) -> Self {
505            Self {
506                max_number_of_secrets,
507                resumption_psk: vec![],
508                cursor: 0,
509            }
510        }
511
512        /// Adds a new entry to the store.
513        pub(crate) fn add(&mut self, epoch: GroupEpoch, resumption_psk: ResumptionPskSecret) {
514            if self.max_number_of_secrets == 0 {
515                return;
516            }
517            let item = (epoch, resumption_psk);
518            if self.resumption_psk.len() < self.max_number_of_secrets {
519                self.resumption_psk.push(item);
520                self.cursor += 1;
521            } else {
522                self.cursor += 1;
523                self.cursor %= self.resumption_psk.len();
524                self.resumption_psk[self.cursor] = item;
525            }
526        }
527
528        /// Searches an entry for a given epoch number and if found, returns the
529        /// corresponding resumption psk.
530        pub(crate) fn get(&self, epoch: GroupEpoch) -> Option<&ResumptionPskSecret> {
531            self.resumption_psk
532                .iter()
533                .find(|&(e, _s)| e == &epoch)
534                .map(|(_e, s)| s)
535        }
536    }
537
538    #[cfg(test)]
539    impl ResumptionPskStore {
540        pub(crate) fn cursor(&self) -> usize {
541            self.cursor
542        }
543    }
544}