openmls/schedule/
psk.rs

1//! # Preshared keys.
2
3use std::borrow::Borrow;
4
5use openmls_traits::{random::OpenMlsRand, storage::StorageProvider as StorageProviderTrait};
6use serde::{Deserialize, Serialize};
7use tls_codec::{Serialize as TlsSerializeTrait, VLBytes};
8
9use super::*;
10use crate::{
11    group::{GroupEpoch, GroupId},
12    schedule::psk::store::ResumptionPskStore,
13    storage::{OpenMlsProvider, StorageProvider},
14};
15
/// Resumption PSK usage.
///
/// States the context in which a resumption PSK is used. The variant
/// discriminants are the values of the `ResumptionPSKUsage` enum from the MLS
/// specification quoted below; `reserved(0)` has no variant here.
///
/// ```c
/// // draft-ietf-mls-protocol-19
/// enum {
///   reserved(0),
///   application(1),
///   reinit(2),
///   branch(3),
///   (255)
/// } ResumptionPSKUsage;
/// ```
#[derive(
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Deserialize,
    Serialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
#[repr(u8)]
pub enum ResumptionPskUsage {
    /// Resumption PSK with usage `application`.
    ///
    /// This is the only resumption usage accepted when a PSK ID is validated
    /// as part of a proposal.
    Application = 1,
    /// Resumption PSK used for group reinitialization.
    ///
    /// Note: "Resumption PSKs with usage `reinit` MUST NOT be used in other contexts (than reinitialization)."
    Reinit = 2,
    /// Resumption PSK used for subgroup branching.
    ///
    /// Note: "Resumption PSKs with usage `branch` MUST NOT be used in other contexts (than subgroup branching)."
    Branch = 3,
}
57
/// External PSK.
///
/// Identifies a PSK provided by the application through its opaque PSK ID.
/// Only the identifier is held here, not the key material.
#[derive(
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Clone,
    Hash,
    Deserialize,
    Serialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct ExternalPsk {
    // Opaque identifier of the PSK (`opaque psk_id<V>` in `PreSharedKeyID`).
    psk_id: VLBytes,
}
77
78impl ExternalPsk {
79    /// Create a new `ExternalPsk` from a PSK ID
80    pub fn new(psk_id: Vec<u8>) -> Self {
81        Self {
82            psk_id: psk_id.into(),
83        }
84    }
85
86    /// Return the PSK ID
87    pub fn psk_id(&self) -> &[u8] {
88        self.psk_id.as_slice()
89    }
90}
91
/// Holds the secret value of a PSK.
///
/// This is the value that is written to and read from storage, keyed by the
/// corresponding [`Psk`]. It only carries the secret; the identifying public
/// part lives in the [`Psk`] used as the storage key.
#[derive(Serialize, Deserialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)]
pub(crate) struct PskBundle {
    // The PSK's secret key material.
    secret: Secret,
}
98
/// Resumption PSK.
///
/// Identifies a PSK derived from the MLS key schedule by the usage it is
/// intended for, the group it belongs to, and the epoch it was taken from.
/// The fields mirror the `resumption` case of `PreSharedKeyID`.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Deserialize,
    Serialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
    Hash,
)]
pub struct ResumptionPsk {
    // Context the resumption PSK is used in.
    pub(crate) usage: ResumptionPskUsage,
    // ID of the group the PSK refers to.
    pub(crate) psk_group_id: GroupId,
    // Epoch the PSK refers to.
    pub(crate) psk_epoch: GroupEpoch,
}
120
121impl ResumptionPsk {
122    /// Create a new `ResumptionPsk`
123    pub fn new(usage: ResumptionPskUsage, psk_group_id: GroupId, psk_epoch: GroupEpoch) -> Self {
124        Self {
125            usage,
126            psk_group_id,
127            psk_epoch,
128        }
129    }
130
131    /// Return the usage
132    pub fn usage(&self) -> ResumptionPskUsage {
133        self.usage
134    }
135
136    /// Return the `GroupId`
137    pub fn psk_group_id(&self) -> &GroupId {
138        &self.psk_group_id
139    }
140
141    /// Return the `GroupEpoch`
142    pub fn psk_epoch(&self) -> GroupEpoch {
143        self.psk_epoch
144    }
145}
146
/// The different PSK types.
///
/// Each variant carries the type-specific identifying data of a PSK. The
/// `tls_codec` discriminants (1 and 2) are the `external` and `resumption`
/// values of the `PSKType` enum in the MLS specification.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Deserialize,
    Serialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
    Hash,
)]
#[repr(u8)]
pub enum Psk {
    /// An external PSK provided by the application.
    #[tls_codec(discriminant = 1)]
    External(ExternalPsk),
    /// A resumption PSK derived from the MLS key schedule.
    #[tls_codec(discriminant = 2)]
    Resumption(ResumptionPsk),
}
172
/// The type of a PSK, without any type-specific data.
///
/// The discriminants are the values of the `PSKType` enum from the MLS
/// specification quoted below and match the `tls_codec` discriminants of
/// [`Psk`]; `reserved(0)` has no variant here.
///
/// ```c
/// // draft-ietf-mls-protocol-19
/// enum {
///   reserved(0),
///   external(1),
///   resumption(2),
///   (255)
/// } PSKType;
/// ```
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum PskType {
    /// An external PSK.
    External = 1,
    /// A resumption PSK.
    Resumption = 2,
}
190
/// A `PreSharedKeyID` is used to uniquely identify the PSKs that get injected
/// in the key schedule.
///
/// It combines the PSK type and its type-specific data ([`Psk`]) with a nonce.
///
/// ```c
/// // draft-ietf-mls-protocol-19
/// struct {
///   PSKType psktype;
///   select (PreSharedKeyID.psktype) {
///     case external:
///       opaque psk_id<V>;
///
///     case resumption:
///       ResumptionPSKUsage usage;
///       opaque psk_group_id<V>;
///       uint64 psk_epoch;
///   };
///   opaque psk_nonce<V>;
/// } PreSharedKeyID;
/// ```
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Deserialize,
    Serialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
    Hash,
)]
pub struct PreSharedKeyId {
    // PSK type together with its type-specific identifying data.
    pub(crate) psk: Psk,
    // Nonce; validation requires its length to equal the ciphersuite's hash
    // length.
    pub(crate) psk_nonce: VLBytes,
}
229
230impl PreSharedKeyId {
231    /// Construct a `PreSharedKeyID` with a random nonce.
232    pub fn new(
233        ciphersuite: Ciphersuite,
234        rand: &impl OpenMlsRand,
235        psk: Psk,
236    ) -> Result<Self, CryptoError> {
237        let psk_nonce = rand
238            .random_vec(ciphersuite.hash_length())
239            .map_err(|_| CryptoError::InsufficientRandomness)?
240            .into();
241
242        Ok(Self { psk, psk_nonce })
243    }
244
245    /// Construct an external `PreSharedKeyID`.
246    pub fn external(psk_id: Vec<u8>, psk_nonce: Vec<u8>) -> Self {
247        let psk = Psk::External(ExternalPsk::new(psk_id));
248
249        Self {
250            psk,
251            psk_nonce: psk_nonce.into(),
252        }
253    }
254
255    /// Construct a resumption `PreSharedKeyID`.
256    pub fn resumption(
257        usage: ResumptionPskUsage,
258        psk_group_id: GroupId,
259        psk_epoch: GroupEpoch,
260        psk_nonce: Vec<u8>,
261    ) -> Self {
262        let psk = Psk::Resumption(ResumptionPsk::new(usage, psk_group_id, psk_epoch));
263
264        Self {
265            psk,
266            psk_nonce: psk_nonce.into(),
267        }
268    }
269
270    /// Return the PSK.
271    pub fn psk(&self) -> &Psk {
272        &self.psk
273    }
274
275    /// Return the PSK nonce.
276    pub fn psk_nonce(&self) -> &[u8] {
277        self.psk_nonce.as_slice()
278    }
279
280    // ----- Key Store -----------------------------------------------------------------------------
281
282    /// Save this `PreSharedKeyId` in the keystore.
283    ///
284    /// Note: The nonce is not saved as it must be unique for each time it's being applied.
285    pub fn store<Provider: OpenMlsProvider>(
286        &self,
287        provider: &Provider,
288        psk: &[u8],
289    ) -> Result<(), PskError> {
290        let psk_bundle = {
291            let secret = Secret::from_slice(psk);
292
293            PskBundle { secret }
294        };
295
296        provider
297            .storage()
298            .write_psk(&self.psk, &psk_bundle)
299            .map_err(|_| PskError::Storage)
300    }
301
302    // ----- Validation ----------------------------------------------------------------------------
303
304    pub(crate) fn validate_in_proposal(self, ciphersuite: Ciphersuite) -> Result<(), PskError> {
305        // ValSem402
306        match self.psk() {
307            Psk::Resumption(resumption_psk) => {
308                // https://validation.openmls.tech/#valn0801
309                // https://validation.openmls.tech/#valn0802
310                if resumption_psk.usage != ResumptionPskUsage::Application {
311                    return Err(PskError::UsageMismatch {
312                        allowed: vec![ResumptionPskUsage::Application],
313                        got: resumption_psk.usage,
314                    });
315                }
316            }
317            Psk::External(_) => {}
318        };
319
320        // ValSem401
321        // https://validation.openmls.tech/#valn0803
322        {
323            let expected_nonce_length = ciphersuite.hash_length();
324            let got_nonce_length = self.psk_nonce().len();
325
326            if expected_nonce_length != got_nonce_length {
327                return Err(PskError::NonceLengthMismatch {
328                    expected: expected_nonce_length,
329                    got: got_nonce_length,
330                });
331            }
332        }
333
334        Ok(())
335    }
336
337    pub(crate) fn validate_in_welcome(
338        psk_ids: &[PreSharedKeyId],
339        ciphersuite: Ciphersuite,
340    ) -> Result<(), PskError> {
341        let mut contains_branch_psk = false;
342        let mut contains_reinit_psk = false;
343        for id in psk_ids {
344            // https://validation.openmls.tech/#valn1401
345            match id.psk() {
346                Psk::Resumption(resumption_psk) => match resumption_psk.usage {
347                    ResumptionPskUsage::Application => {
348                        return Err(PskError::UsageMismatch {
349                            allowed: vec![ResumptionPskUsage::Reinit, ResumptionPskUsage::Branch],
350                            got: resumption_psk.usage,
351                        });
352                    }
353                    ResumptionPskUsage::Reinit => {
354                        if contains_reinit_psk {
355                            return Err(PskError::UsageDuplicate {
356                                usage: ResumptionPskUsage::Reinit,
357                            });
358                        }
359                        if contains_branch_psk {
360                            return Err(PskError::UsageConflict {
361                                first: ResumptionPskUsage::Reinit,
362                                second: ResumptionPskUsage::Branch,
363                            });
364                        }
365                        contains_reinit_psk = true;
366                    }
367                    ResumptionPskUsage::Branch => {
368                        if contains_branch_psk {
369                            return Err(PskError::UsageDuplicate {
370                                usage: ResumptionPskUsage::Branch,
371                            });
372                        }
373                        if contains_reinit_psk {
374                            return Err(PskError::UsageConflict {
375                                first: ResumptionPskUsage::Branch,
376                                second: ResumptionPskUsage::Reinit,
377                            });
378                        }
379                        contains_branch_psk = true;
380                    }
381                },
382                Psk::External(_) => {}
383            };
384
385            {
386                let expected_nonce_length = ciphersuite.hash_length();
387                let got_nonce_length = id.psk_nonce().len();
388
389                if expected_nonce_length != got_nonce_length {
390                    return Err(PskError::NonceLengthMismatch {
391                        expected: expected_nonce_length,
392                        got: got_nonce_length,
393                    });
394                }
395            }
396        }
397        Ok(())
398    }
399}
400
#[cfg(test)]
impl PreSharedKeyId {
    /// Test-only constructor that pairs `psk` with a caller-chosen nonce.
    pub(crate) fn new_with_nonce(psk: Psk, psk_nonce: Vec<u8>) -> Self {
        let psk_nonce = psk_nonce.into();

        Self { psk, psk_nonce }
    }
}
410
/// `PskLabel` is used in the final concatenation of PSKs before they are
/// injected in the key schedule.
///
/// It is serialized and used as the context when deriving the per-PSK input
/// with the "derived psk" label.
///
/// ```c
/// // draft-ietf-mls-protocol-19
/// struct {
///     PreSharedKeyID id;
///     uint16 index;
///     uint16 count;
/// } PSKLabel;
/// ```
#[derive(TlsSerialize, TlsSize)]
pub(crate) struct PskLabel<'a> {
    // The ID of the PSK this label is for.
    pub(crate) id: &'a PreSharedKeyId,
    // Zero-based position of the PSK in the list being processed.
    pub(crate) index: u16,
    // Total number of PSKs in the list being processed.
    pub(crate) count: u16,
}
428
429impl<'a> PskLabel<'a> {
430    /// Create a new `PskLabel`
431    fn new(id: &'a PreSharedKeyId, index: u16, count: u16) -> Self {
432        Self { id, index, count }
433    }
434}
435
/// This contains the `psk-secret` calculated from the PSKs contained in a
/// Commit or a PreSharedKey proposal.
///
/// The value is the result of chaining all PSKs together; see
/// `PskSecret::new` for the derivation.
#[derive(Clone)]
pub struct PskSecret {
    // The combined `psk_secret` value.
    secret: Secret,
}
442
443impl PskSecret {
444    /// Create a new `PskSecret` from PSK IDs and PSKs
445    ///
446    /// ```text
447    /// psk_extracted_[i] = KDF.Extract(0, psk_[i])
448    /// psk_input_[i] = ExpandWithLabel(psk_extracted_[i], "derived psk", PSKLabel, KDF.Nh)
449    ///
450    /// psk_secret_[0] = 0
451    /// psk_secret_[i] = KDF.Extract(psk_input[i-1], psk_secret_[i-1])
452    /// psk_secret     = psk_secret[n]
453    /// ```
454    pub(crate) fn new(
455        crypto: &impl OpenMlsCrypto,
456        ciphersuite: Ciphersuite,
457        psks: Vec<(impl Borrow<PreSharedKeyId>, Secret)>,
458    ) -> Result<Self, PskError> {
459        // Check that we don't have too many PSKs
460        let num_psks = u16::try_from(psks.len()).map_err(|_| PskError::TooManyKeys)?;
461
462        // Following comments are from `draft-ietf-mls-protocol-19`.
463        //
464        // psk_secret_[0] = 0
465        let mut psk_secret = Secret::zero(ciphersuite);
466
467        for (index, (psk_id, psk)) in psks.into_iter().enumerate() {
468            // psk_extracted_[i] = KDF.Extract(0, psk_[i])
469            let psk_extracted = {
470                let zero_secret = Secret::zero(ciphersuite);
471                zero_secret
472                    .hkdf_extract(crypto, ciphersuite, &psk)
473                    .map_err(LibraryError::unexpected_crypto_error)?
474            };
475
476            // psk_input_[i] = ExpandWithLabel( psk_extracted_[i], "derived psk", PSKLabel, KDF.Nh)
477            let psk_input = {
478                let psk_label = PskLabel::new(psk_id.borrow(), index as u16, num_psks)
479                    .tls_serialize_detached()
480                    .map_err(LibraryError::missing_bound_check)?;
481
482                psk_extracted
483                    .kdf_expand_label(
484                        crypto,
485                        ciphersuite,
486                        "derived psk",
487                        &psk_label,
488                        ciphersuite.hash_length(),
489                    )
490                    .map_err(LibraryError::unexpected_crypto_error)?
491            };
492
493            // psk_secret_[i] = KDF.Extract(psk_input_[i-1], psk_secret_[i-1])
494            psk_secret = psk_input
495                .hkdf_extract(crypto, ciphersuite, &psk_secret)
496                .map_err(LibraryError::unexpected_crypto_error)?;
497        }
498
499        Ok(Self { secret: psk_secret })
500    }
501
502    /// Return the inner secret
503    pub(crate) fn secret(&self) -> &Secret {
504        &self.secret
505    }
506
507    #[cfg(any(feature = "test-utils", feature = "crypto-debug", test))]
508    pub(crate) fn as_slice(&self) -> &[u8] {
509        self.secret.as_slice()
510    }
511}
512
#[cfg(any(feature = "test-utils", test))]
impl From<Secret> for PskSecret {
    /// Wrap an arbitrary `Secret` as a `PskSecret` (test builds only).
    fn from(secret: Secret) -> Self {
        PskSecret { secret }
    }
}
519
520pub(crate) fn load_psks<'p, Storage: StorageProvider>(
521    storage: &Storage,
522    resumption_psk_store: &ResumptionPskStore,
523    psk_ids: &'p [PreSharedKeyId],
524) -> Result<Vec<(&'p PreSharedKeyId, Secret)>, PskError> {
525    let mut psk_bundles = Vec::new();
526
527    for psk_id in psk_ids.iter() {
528        log_crypto!(trace, "PSK store {:?}", resumption_psk_store);
529
530        match &psk_id.psk {
531            Psk::Resumption(resumption) => {
532                if let Some(psk_bundle) = resumption_psk_store.get(resumption.psk_epoch()) {
533                    psk_bundles.push((psk_id, psk_bundle.secret.clone()));
534                } else {
535                    return Err(PskError::KeyNotFound);
536                }
537            }
538            Psk::External(_) => {
539                let psk_bundle: Option<PskBundle> = storage
540                    .psk(psk_id.psk())
541                    .map_err(|_| PskError::KeyNotFound)?;
542                if let Some(psk_bundle) = psk_bundle {
543                    psk_bundles.push((psk_id, psk_bundle.secret));
544                } else {
545                    return Err(PskError::KeyNotFound);
546                }
547            }
548        }
549    }
550
551    Ok(psk_bundles)
552}
553
554/// This module contains a store that can hold a rollover list of resumption PSKs.
555pub mod store {
556    use serde::{Deserialize, Serialize};
557
558    use crate::{group::GroupEpoch, schedule::ResumptionPskSecret};
559
560    /// Resumption PSK store.
561    ///
562    /// This is where the resumption PSKs are kept in a rollover list.
563    #[derive(Debug, Serialize, Deserialize)]
564    #[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
565    pub(crate) struct ResumptionPskStore {
566        max_number_of_secrets: usize,
567        resumption_psk: Vec<(GroupEpoch, ResumptionPskSecret)>,
568        cursor: usize,
569    }
570
571    impl ResumptionPskStore {
572        /// Creates a new store with a given maximum size of `number_of_secrets`.
573        pub(crate) fn new(max_number_of_secrets: usize) -> Self {
574            Self {
575                max_number_of_secrets,
576                resumption_psk: vec![],
577                cursor: 0,
578            }
579        }
580
581        /// Adds a new entry to the store.
582        pub(crate) fn add(&mut self, epoch: GroupEpoch, resumption_psk: ResumptionPskSecret) {
583            if self.max_number_of_secrets == 0 {
584                return;
585            }
586            let item = (epoch, resumption_psk);
587            if self.resumption_psk.len() < self.max_number_of_secrets {
588                self.resumption_psk.push(item);
589                self.cursor += 1;
590            } else {
591                self.cursor += 1;
592                self.cursor %= self.resumption_psk.len();
593                self.resumption_psk[self.cursor] = item;
594            }
595        }
596
597        /// Searches an entry for a given epoch number and if found, returns the
598        /// corresponding resumption psk.
599        pub(crate) fn get(&self, epoch: GroupEpoch) -> Option<&ResumptionPskSecret> {
600            self.resumption_psk
601                .iter()
602                .find(|&(e, _s)| e == &epoch)
603                .map(|(_e, s)| s)
604        }
605    }
606
607    #[cfg(test)]
608    impl ResumptionPskStore {
609        pub(crate) fn cursor(&self) -> usize {
610            self.cursor
611        }
612    }
613}