openmls/schedule/
psk.rs

1//! # Preshared keys.
2
3use std::borrow::Borrow;
4
5use openmls_traits::{random::OpenMlsRand, storage::StorageProvider as StorageProviderTrait};
6use serde::{Deserialize, Serialize};
7use tls_codec::{Serialize as TlsSerializeTrait, VLBytes};
8
9use super::*;
10use crate::{
11    group::{GroupEpoch, GroupId},
12    schedule::psk::store::ResumptionPskStore,
13    storage::{OpenMlsProvider, StorageProvider},
14};
15
16/// Resumption PSK usage.
17///
18/// ```c
19/// // draft-ietf-mls-protocol-19
20/// enum {
21///   reserved(0),
22///   application(1),
23///   reinit(2),
24///   branch(3),
25///   (255)
26/// } ResumptionPSKUsage;
27/// ```
28#[derive(
29    Clone,
30    Copy,
31    Debug,
32    PartialEq,
33    Eq,
34    PartialOrd,
35    Ord,
36    Hash,
37    Deserialize,
38    Serialize,
39    TlsDeserialize,
40    TlsDeserializeBytes,
41    TlsSerialize,
42    TlsSize,
43)]
44#[repr(u8)]
45pub enum ResumptionPskUsage {
46    /// Application.
47    Application = 1,
48    /// Resumption PSK used for group reinitialization.
49    ///
50    /// Note: "Resumption PSKs with usage `reinit` MUST NOT be used in other contexts (than reinitialization)."
51    Reinit = 2,
52    /// Resumption PSK used for subgroup branching.
53    ///
54    /// Note: "Resumption PSKs with usage `branch` MUST NOT be used in other contexts (than subgroup branching)."
55    Branch = 3,
56}
57
58/// External PSK.
59#[derive(
60    Debug,
61    PartialEq,
62    Eq,
63    PartialOrd,
64    Ord,
65    Clone,
66    Hash,
67    Deserialize,
68    Serialize,
69    TlsDeserialize,
70    TlsDeserializeBytes,
71    TlsSerialize,
72    TlsSize,
73)]
74pub struct ExternalPsk {
75    psk_id: VLBytes,
76}
77
78impl ExternalPsk {
79    /// Create a new `ExternalPsk` from a PSK ID
80    pub fn new(psk_id: Vec<u8>) -> Self {
81        Self {
82            psk_id: psk_id.into(),
83        }
84    }
85
86    /// Return the PSK ID
87    pub fn psk_id(&self) -> &[u8] {
88        self.psk_id.as_slice()
89    }
90}
91
92/// Contains the secret part of the PSK as well as the
93/// public part that is used as a marker for injection into the key schedule.
94#[derive(Serialize, Deserialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)]
95pub(crate) struct PskBundle {
96    secret: Secret,
97}
98
99/// Resumption PSK.
100#[derive(
101    Clone,
102    Debug,
103    PartialEq,
104    Eq,
105    PartialOrd,
106    Ord,
107    Deserialize,
108    Serialize,
109    TlsDeserialize,
110    TlsDeserializeBytes,
111    TlsSerialize,
112    TlsSize,
113)]
114pub struct ResumptionPsk {
115    pub(crate) usage: ResumptionPskUsage,
116    pub(crate) psk_group_id: GroupId,
117    pub(crate) psk_epoch: GroupEpoch,
118}
119
120impl ResumptionPsk {
121    /// Create a new `ResumptionPsk`
122    pub fn new(usage: ResumptionPskUsage, psk_group_id: GroupId, psk_epoch: GroupEpoch) -> Self {
123        Self {
124            usage,
125            psk_group_id,
126            psk_epoch,
127        }
128    }
129
130    /// Return the usage
131    pub fn usage(&self) -> ResumptionPskUsage {
132        self.usage
133    }
134
135    /// Return the `GroupId`
136    pub fn psk_group_id(&self) -> &GroupId {
137        &self.psk_group_id
138    }
139
140    /// Return the `GroupEpoch`
141    pub fn psk_epoch(&self) -> GroupEpoch {
142        self.psk_epoch
143    }
144}
145
146/// The different PSK types.
147#[derive(
148    Clone,
149    Debug,
150    PartialEq,
151    Eq,
152    PartialOrd,
153    Ord,
154    Deserialize,
155    Serialize,
156    TlsDeserialize,
157    TlsDeserializeBytes,
158    TlsSerialize,
159    TlsSize,
160)]
161#[repr(u8)]
162pub enum Psk {
163    /// An external PSK provided by the application.
164    #[tls_codec(discriminant = 1)]
165    External(ExternalPsk),
166    /// A resumption PSK derived from the MLS key schedule.
167    #[tls_codec(discriminant = 2)]
168    Resumption(ResumptionPsk),
169}
170
/// Type tag of a PSK.
///
/// Each variant's numeric value is the value of the matching entry in the
/// wire enum below. `reserved(0)` has no variant here.
///
/// ```c
/// // draft-ietf-mls-protocol-19
/// enum {
///   reserved(0),
///   external(1),
///   resumption(2),
///   (255)
/// } PSKType;
/// ```
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PskType {
    /// An external PSK.
    External = 1,
    /// A resumption PSK.
    Resumption = 2,
}
188
189/// A `PreSharedKeyID` is used to uniquely identify the PSKs that get injected
190/// in the key schedule.
191///
192/// ```c
193/// // draft-ietf-mls-protocol-19
194/// struct {
195///   PSKType psktype;
196///   select (PreSharedKeyID.psktype) {
197///     case external:
198///       opaque psk_id<V>;
199///
200///     case resumption:
201///       ResumptionPSKUsage usage;
202///       opaque psk_group_id<V>;
203///       uint64 psk_epoch;
204///   };
205///   opaque psk_nonce<V>;
206/// } PreSharedKeyID;
207/// ```
208#[derive(
209    Clone,
210    Debug,
211    PartialEq,
212    Eq,
213    PartialOrd,
214    Ord,
215    Deserialize,
216    Serialize,
217    TlsDeserialize,
218    TlsDeserializeBytes,
219    TlsSerialize,
220    TlsSize,
221)]
222pub struct PreSharedKeyId {
223    pub(crate) psk: Psk,
224    pub(crate) psk_nonce: VLBytes,
225}
226
227impl PreSharedKeyId {
228    /// Construct a `PreSharedKeyID` with a random nonce.
229    pub fn new(
230        ciphersuite: Ciphersuite,
231        rand: &impl OpenMlsRand,
232        psk: Psk,
233    ) -> Result<Self, CryptoError> {
234        let psk_nonce = rand
235            .random_vec(ciphersuite.hash_length())
236            .map_err(|_| CryptoError::InsufficientRandomness)?
237            .into();
238
239        Ok(Self { psk, psk_nonce })
240    }
241
242    /// Construct an external `PreSharedKeyID`.
243    pub fn external(psk_id: Vec<u8>, psk_nonce: Vec<u8>) -> Self {
244        let psk = Psk::External(ExternalPsk::new(psk_id));
245
246        Self {
247            psk,
248            psk_nonce: psk_nonce.into(),
249        }
250    }
251
252    /// Construct a resumption `PreSharedKeyID`.
253    pub fn resumption(
254        usage: ResumptionPskUsage,
255        psk_group_id: GroupId,
256        psk_epoch: GroupEpoch,
257        psk_nonce: Vec<u8>,
258    ) -> Self {
259        let psk = Psk::Resumption(ResumptionPsk::new(usage, psk_group_id, psk_epoch));
260
261        Self {
262            psk,
263            psk_nonce: psk_nonce.into(),
264        }
265    }
266
267    /// Return the PSK.
268    pub fn psk(&self) -> &Psk {
269        &self.psk
270    }
271
272    /// Return the PSK nonce.
273    pub fn psk_nonce(&self) -> &[u8] {
274        self.psk_nonce.as_slice()
275    }
276
277    // ----- Key Store -----------------------------------------------------------------------------
278
279    /// Save this `PreSharedKeyId` in the keystore.
280    ///
281    /// Note: The nonce is not saved as it must be unique for each time it's being applied.
282    pub fn store<Provider: OpenMlsProvider>(
283        &self,
284        provider: &Provider,
285        psk: &[u8],
286    ) -> Result<(), PskError> {
287        let psk_bundle = {
288            let secret = Secret::from_slice(psk);
289
290            PskBundle { secret }
291        };
292
293        provider
294            .storage()
295            .write_psk(&self.psk, &psk_bundle)
296            .map_err(|_| PskError::Storage)
297    }
298
299    // ----- Validation ----------------------------------------------------------------------------
300
301    pub(crate) fn validate_in_proposal(self, ciphersuite: Ciphersuite) -> Result<Self, PskError> {
302        // ValSem402
303        match self.psk() {
304            Psk::Resumption(resumption_psk) => {
305                // https://validation.openmls.tech/#valn0801
306                // https://validation.openmls.tech/#valn0802
307                if resumption_psk.usage != ResumptionPskUsage::Application {
308                    return Err(PskError::UsageMismatch {
309                        allowed: vec![ResumptionPskUsage::Application],
310                        got: resumption_psk.usage,
311                    });
312                }
313            }
314            Psk::External(_) => {}
315        };
316
317        // ValSem401
318        // https://validation.openmls.tech/#valn0803
319        {
320            let expected_nonce_length = ciphersuite.hash_length();
321            let got_nonce_length = self.psk_nonce().len();
322
323            if expected_nonce_length != got_nonce_length {
324                return Err(PskError::NonceLengthMismatch {
325                    expected: expected_nonce_length,
326                    got: got_nonce_length,
327                });
328            }
329        }
330
331        Ok(self)
332    }
333
334    pub(crate) fn validate_in_welcome(
335        psk_ids: &[PreSharedKeyId],
336        ciphersuite: Ciphersuite,
337    ) -> Result<(), PskError> {
338        let mut contains_branch_psk = false;
339        let mut contains_reinit_psk = false;
340        for id in psk_ids {
341            // https://validation.openmls.tech/#valn1401
342            match id.psk() {
343                Psk::Resumption(resumption_psk) => match resumption_psk.usage {
344                    ResumptionPskUsage::Application => {
345                        return Err(PskError::UsageMismatch {
346                            allowed: vec![ResumptionPskUsage::Reinit, ResumptionPskUsage::Branch],
347                            got: resumption_psk.usage,
348                        });
349                    }
350                    ResumptionPskUsage::Reinit => {
351                        if contains_reinit_psk {
352                            return Err(PskError::UsageDuplicate {
353                                usage: ResumptionPskUsage::Reinit,
354                            });
355                        }
356                        if contains_branch_psk {
357                            return Err(PskError::UsageConflict {
358                                first: ResumptionPskUsage::Reinit,
359                                second: ResumptionPskUsage::Branch,
360                            });
361                        }
362                        contains_reinit_psk = true;
363                    }
364                    ResumptionPskUsage::Branch => {
365                        if contains_branch_psk {
366                            return Err(PskError::UsageDuplicate {
367                                usage: ResumptionPskUsage::Branch,
368                            });
369                        }
370                        if contains_reinit_psk {
371                            return Err(PskError::UsageConflict {
372                                first: ResumptionPskUsage::Branch,
373                                second: ResumptionPskUsage::Reinit,
374                            });
375                        }
376                        contains_branch_psk = true;
377                    }
378                },
379                Psk::External(_) => {}
380            };
381
382            {
383                let expected_nonce_length = ciphersuite.hash_length();
384                let got_nonce_length = id.psk_nonce().len();
385
386                if expected_nonce_length != got_nonce_length {
387                    return Err(PskError::NonceLengthMismatch {
388                        expected: expected_nonce_length,
389                        got: got_nonce_length,
390                    });
391                }
392            }
393        }
394        Ok(())
395    }
396}
397
#[cfg(test)]
impl PreSharedKeyId {
    /// Test-only constructor that pairs `psk` with a caller-chosen nonce.
    pub(crate) fn new_with_nonce(psk: Psk, psk_nonce: Vec<u8>) -> Self {
        let psk_nonce: VLBytes = psk_nonce.into();
        PreSharedKeyId { psk, psk_nonce }
    }
}
407
408/// `PskLabel` is used in the final concatentation of PSKs before they are
409/// injected in the key schedule.
410///
411/// ```c
412/// // draft-ietf-mls-protocol-19
413/// struct {
414///     PreSharedKeyID id;
415///     uint16 index;
416///     uint16 count;
417/// } PSKLabel;
418/// ```
419#[derive(TlsSerialize, TlsSize)]
420pub(crate) struct PskLabel<'a> {
421    pub(crate) id: &'a PreSharedKeyId,
422    pub(crate) index: u16,
423    pub(crate) count: u16,
424}
425
426impl<'a> PskLabel<'a> {
427    /// Create a new `PskLabel`
428    fn new(id: &'a PreSharedKeyId, index: u16, count: u16) -> Self {
429        Self { id, index, count }
430    }
431}
432
433/// This contains the `psk-secret` calculated from the PSKs contained in a
434/// Commit or a PreSharedKey proposal.
435#[derive(Clone)]
436pub struct PskSecret {
437    secret: Secret,
438}
439
440impl PskSecret {
441    /// Create a new `PskSecret` from PSK IDs and PSKs
442    ///
443    /// ```text
444    /// psk_extracted_[i] = KDF.Extract(0, psk_[i])
445    /// psk_input_[i] = ExpandWithLabel(psk_extracted_[i], "derived psk", PSKLabel, KDF.Nh)
446    ///
447    /// psk_secret_[0] = 0
448    /// psk_secret_[i] = KDF.Extract(psk_input[i-1], psk_secret_[i-1])
449    /// psk_secret     = psk_secret[n]
450    /// ```
451    pub(crate) fn new(
452        crypto: &impl OpenMlsCrypto,
453        ciphersuite: Ciphersuite,
454        psks: Vec<(impl Borrow<PreSharedKeyId>, Secret)>,
455    ) -> Result<Self, PskError> {
456        // Check that we don't have too many PSKs
457        let num_psks = u16::try_from(psks.len()).map_err(|_| PskError::TooManyKeys)?;
458
459        // Following comments are from `draft-ietf-mls-protocol-19`.
460        //
461        // psk_secret_[0] = 0
462        let mut psk_secret = Secret::zero(ciphersuite);
463
464        for (index, (psk_id, psk)) in psks.into_iter().enumerate() {
465            // psk_extracted_[i] = KDF.Extract(0, psk_[i])
466            let psk_extracted = {
467                let zero_secret = Secret::zero(ciphersuite);
468                zero_secret
469                    .hkdf_extract(crypto, ciphersuite, &psk)
470                    .map_err(LibraryError::unexpected_crypto_error)?
471            };
472
473            // psk_input_[i] = ExpandWithLabel( psk_extracted_[i], "derived psk", PSKLabel, KDF.Nh)
474            let psk_input = {
475                let psk_label = PskLabel::new(psk_id.borrow(), index as u16, num_psks)
476                    .tls_serialize_detached()
477                    .map_err(LibraryError::missing_bound_check)?;
478
479                psk_extracted
480                    .kdf_expand_label(
481                        crypto,
482                        ciphersuite,
483                        "derived psk",
484                        &psk_label,
485                        ciphersuite.hash_length(),
486                    )
487                    .map_err(LibraryError::unexpected_crypto_error)?
488            };
489
490            // psk_secret_[i] = KDF.Extract(psk_input_[i-1], psk_secret_[i-1])
491            psk_secret = psk_input
492                .hkdf_extract(crypto, ciphersuite, &psk_secret)
493                .map_err(LibraryError::unexpected_crypto_error)?;
494        }
495
496        Ok(Self { secret: psk_secret })
497    }
498
499    /// Return the inner secret
500    pub(crate) fn secret(&self) -> &Secret {
501        &self.secret
502    }
503
504    #[cfg(any(feature = "test-utils", feature = "crypto-debug", test))]
505    pub(crate) fn as_slice(&self) -> &[u8] {
506        self.secret.as_slice()
507    }
508}
509
#[cfg(any(feature = "test-utils", test))]
impl From<Secret> for PskSecret {
    /// Wrap an existing `Secret` as a `PskSecret` without any derivation.
    fn from(secret: Secret) -> Self {
        PskSecret { secret }
    }
}
516
517pub(crate) fn load_psks<'p, Storage: StorageProvider>(
518    storage: &Storage,
519    resumption_psk_store: &ResumptionPskStore,
520    psk_ids: &'p [PreSharedKeyId],
521) -> Result<Vec<(&'p PreSharedKeyId, Secret)>, PskError> {
522    let mut psk_bundles = Vec::new();
523
524    for psk_id in psk_ids.iter() {
525        log_crypto!(trace, "PSK store {:?}", resumption_psk_store);
526
527        match &psk_id.psk {
528            Psk::Resumption(resumption) => {
529                if let Some(psk_bundle) = resumption_psk_store.get(resumption.psk_epoch()) {
530                    psk_bundles.push((psk_id, psk_bundle.secret.clone()));
531                } else {
532                    return Err(PskError::KeyNotFound);
533                }
534            }
535            Psk::External(_) => {
536                let psk_bundle: Option<PskBundle> = storage
537                    .psk(psk_id.psk())
538                    .map_err(|_| PskError::KeyNotFound)?;
539                if let Some(psk_bundle) = psk_bundle {
540                    psk_bundles.push((psk_id, psk_bundle.secret));
541                } else {
542                    return Err(PskError::KeyNotFound);
543                }
544            }
545        }
546    }
547
548    Ok(psk_bundles)
549}
550
551/// This module contains a store that can hold a rollover list of resumption PSKs.
552pub mod store {
553    use serde::{Deserialize, Serialize};
554
555    use crate::{group::GroupEpoch, schedule::ResumptionPskSecret};
556
557    /// Resumption PSK store.
558    ///
559    /// This is where the resumption PSKs are kept in a rollover list.
560    #[derive(Debug, Serialize, Deserialize)]
561    #[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
562    pub(crate) struct ResumptionPskStore {
563        max_number_of_secrets: usize,
564        resumption_psk: Vec<(GroupEpoch, ResumptionPskSecret)>,
565        cursor: usize,
566    }
567
568    impl ResumptionPskStore {
569        /// Creates a new store with a given maximum size of `number_of_secrets`.
570        pub(crate) fn new(max_number_of_secrets: usize) -> Self {
571            Self {
572                max_number_of_secrets,
573                resumption_psk: vec![],
574                cursor: 0,
575            }
576        }
577
578        /// Adds a new entry to the store.
579        pub(crate) fn add(&mut self, epoch: GroupEpoch, resumption_psk: ResumptionPskSecret) {
580            if self.max_number_of_secrets == 0 {
581                return;
582            }
583            let item = (epoch, resumption_psk);
584            if self.resumption_psk.len() < self.max_number_of_secrets {
585                self.resumption_psk.push(item);
586                self.cursor += 1;
587            } else {
588                self.cursor += 1;
589                self.cursor %= self.resumption_psk.len();
590                self.resumption_psk[self.cursor] = item;
591            }
592        }
593
594        /// Searches an entry for a given epoch number and if found, returns the
595        /// corresponding resumption psk.
596        pub(crate) fn get(&self, epoch: GroupEpoch) -> Option<&ResumptionPskSecret> {
597            self.resumption_psk
598                .iter()
599                .find(|&(e, _s)| e == &epoch)
600                .map(|(_e, s)| s)
601        }
602    }
603
604    #[cfg(test)]
605    impl ResumptionPskStore {
606        pub(crate) fn cursor(&self) -> usize {
607            self.cursor
608        }
609    }
610}