1use std::borrow::Borrow;
4
5use openmls_traits::{random::OpenMlsRand, storage::StorageProvider as StorageProviderTrait};
6use serde::{Deserialize, Serialize};
7use tls_codec::{Serialize as TlsSerializeTrait, VLBytes};
8
9use super::*;
10use crate::{
11 group::{GroupEpoch, GroupId},
12 schedule::psk::store::ResumptionPskStore,
13 storage::{OpenMlsProvider, StorageProvider},
14};
15
16#[derive(
29 Clone,
30 Copy,
31 Debug,
32 PartialEq,
33 Eq,
34 PartialOrd,
35 Ord,
36 Hash,
37 Deserialize,
38 Serialize,
39 TlsDeserialize,
40 TlsDeserializeBytes,
41 TlsSerialize,
42 TlsSize,
43)]
44#[repr(u8)]
45pub enum ResumptionPskUsage {
46 Application = 1,
48 Reinit = 2,
52 Branch = 3,
56}
57
58#[derive(
60 Debug,
61 PartialEq,
62 Eq,
63 PartialOrd,
64 Ord,
65 Clone,
66 Hash,
67 Deserialize,
68 Serialize,
69 TlsDeserialize,
70 TlsDeserializeBytes,
71 TlsSerialize,
72 TlsSize,
73)]
74pub struct ExternalPsk {
75 psk_id: VLBytes,
76}
77
78impl ExternalPsk {
79 pub fn new(psk_id: Vec<u8>) -> Self {
81 Self {
82 psk_id: psk_id.into(),
83 }
84 }
85
86 pub fn psk_id(&self) -> &[u8] {
88 self.psk_id.as_slice()
89 }
90}
91
92#[derive(Serialize, Deserialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)]
95pub(crate) struct PskBundle {
96 secret: Secret,
97}
98
99#[derive(
101 Clone,
102 Debug,
103 PartialEq,
104 Eq,
105 PartialOrd,
106 Ord,
107 Deserialize,
108 Serialize,
109 TlsDeserialize,
110 TlsDeserializeBytes,
111 TlsSerialize,
112 TlsSize,
113)]
114pub struct ResumptionPsk {
115 pub(crate) usage: ResumptionPskUsage,
116 pub(crate) psk_group_id: GroupId,
117 pub(crate) psk_epoch: GroupEpoch,
118}
119
120impl ResumptionPsk {
121 pub fn new(usage: ResumptionPskUsage, psk_group_id: GroupId, psk_epoch: GroupEpoch) -> Self {
123 Self {
124 usage,
125 psk_group_id,
126 psk_epoch,
127 }
128 }
129
130 pub fn usage(&self) -> ResumptionPskUsage {
132 self.usage
133 }
134
135 pub fn psk_group_id(&self) -> &GroupId {
137 &self.psk_group_id
138 }
139
140 pub fn psk_epoch(&self) -> GroupEpoch {
142 self.psk_epoch
143 }
144}
145
146#[derive(
148 Clone,
149 Debug,
150 PartialEq,
151 Eq,
152 PartialOrd,
153 Ord,
154 Deserialize,
155 Serialize,
156 TlsDeserialize,
157 TlsDeserializeBytes,
158 TlsSerialize,
159 TlsSize,
160)]
161#[repr(u8)]
162pub enum Psk {
163 #[tls_codec(discriminant = 1)]
165 External(ExternalPsk),
166 #[tls_codec(discriminant = 2)]
168 Resumption(ResumptionPsk),
169}
170
/// Numeric PSK type tags.
///
/// The values match the `tls_codec` discriminants used for the corresponding
/// variants of [`Psk`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PskType {
    /// Tag for [`Psk::External`].
    External = 1,
    /// Tag for [`Psk::Resumption`].
    Resumption = 2,
}
188
189#[derive(
209 Clone,
210 Debug,
211 PartialEq,
212 Eq,
213 PartialOrd,
214 Ord,
215 Deserialize,
216 Serialize,
217 TlsDeserialize,
218 TlsDeserializeBytes,
219 TlsSerialize,
220 TlsSize,
221)]
222pub struct PreSharedKeyId {
223 pub(crate) psk: Psk,
224 pub(crate) psk_nonce: VLBytes,
225}
226
227impl PreSharedKeyId {
228 pub fn new(
230 ciphersuite: Ciphersuite,
231 rand: &impl OpenMlsRand,
232 psk: Psk,
233 ) -> Result<Self, CryptoError> {
234 let psk_nonce = rand
235 .random_vec(ciphersuite.hash_length())
236 .map_err(|_| CryptoError::InsufficientRandomness)?
237 .into();
238
239 Ok(Self { psk, psk_nonce })
240 }
241
242 pub fn external(psk_id: Vec<u8>, psk_nonce: Vec<u8>) -> Self {
244 let psk = Psk::External(ExternalPsk::new(psk_id));
245
246 Self {
247 psk,
248 psk_nonce: psk_nonce.into(),
249 }
250 }
251
252 pub fn resumption(
254 usage: ResumptionPskUsage,
255 psk_group_id: GroupId,
256 psk_epoch: GroupEpoch,
257 psk_nonce: Vec<u8>,
258 ) -> Self {
259 let psk = Psk::Resumption(ResumptionPsk::new(usage, psk_group_id, psk_epoch));
260
261 Self {
262 psk,
263 psk_nonce: psk_nonce.into(),
264 }
265 }
266
267 pub fn psk(&self) -> &Psk {
269 &self.psk
270 }
271
272 pub fn psk_nonce(&self) -> &[u8] {
274 self.psk_nonce.as_slice()
275 }
276
277 pub fn store<Provider: OpenMlsProvider>(
283 &self,
284 provider: &Provider,
285 psk: &[u8],
286 ) -> Result<(), PskError> {
287 let psk_bundle = {
288 let secret = Secret::from_slice(psk);
289
290 PskBundle { secret }
291 };
292
293 provider
294 .storage()
295 .write_psk(&self.psk, &psk_bundle)
296 .map_err(|_| PskError::Storage)
297 }
298
299 pub(crate) fn validate_in_proposal(self, ciphersuite: Ciphersuite) -> Result<Self, PskError> {
302 match self.psk() {
304 Psk::Resumption(resumption_psk) => {
305 if resumption_psk.usage != ResumptionPskUsage::Application {
308 return Err(PskError::UsageMismatch {
309 allowed: vec![ResumptionPskUsage::Application],
310 got: resumption_psk.usage,
311 });
312 }
313 }
314 Psk::External(_) => {}
315 };
316
317 {
320 let expected_nonce_length = ciphersuite.hash_length();
321 let got_nonce_length = self.psk_nonce().len();
322
323 if expected_nonce_length != got_nonce_length {
324 return Err(PskError::NonceLengthMismatch {
325 expected: expected_nonce_length,
326 got: got_nonce_length,
327 });
328 }
329 }
330
331 Ok(self)
332 }
333
334 pub(crate) fn validate_in_welcome(
335 psk_ids: &[PreSharedKeyId],
336 ciphersuite: Ciphersuite,
337 ) -> Result<(), PskError> {
338 let mut contains_branch_psk = false;
339 let mut contains_reinit_psk = false;
340 for id in psk_ids {
341 match id.psk() {
343 Psk::Resumption(resumption_psk) => match resumption_psk.usage {
344 ResumptionPskUsage::Application => {
345 return Err(PskError::UsageMismatch {
346 allowed: vec![ResumptionPskUsage::Reinit, ResumptionPskUsage::Branch],
347 got: resumption_psk.usage,
348 });
349 }
350 ResumptionPskUsage::Reinit => {
351 if contains_reinit_psk {
352 return Err(PskError::UsageDuplicate {
353 usage: ResumptionPskUsage::Reinit,
354 });
355 }
356 if contains_branch_psk {
357 return Err(PskError::UsageConflict {
358 first: ResumptionPskUsage::Reinit,
359 second: ResumptionPskUsage::Branch,
360 });
361 }
362 contains_reinit_psk = true;
363 }
364 ResumptionPskUsage::Branch => {
365 if contains_branch_psk {
366 return Err(PskError::UsageDuplicate {
367 usage: ResumptionPskUsage::Branch,
368 });
369 }
370 if contains_reinit_psk {
371 return Err(PskError::UsageConflict {
372 first: ResumptionPskUsage::Branch,
373 second: ResumptionPskUsage::Reinit,
374 });
375 }
376 contains_branch_psk = true;
377 }
378 },
379 Psk::External(_) => {}
380 };
381
382 {
383 let expected_nonce_length = ciphersuite.hash_length();
384 let got_nonce_length = id.psk_nonce().len();
385
386 if expected_nonce_length != got_nonce_length {
387 return Err(PskError::NonceLengthMismatch {
388 expected: expected_nonce_length,
389 got: got_nonce_length,
390 });
391 }
392 }
393 }
394 Ok(())
395 }
396}
397
#[cfg(test)]
impl PreSharedKeyId {
    /// Test helper: builds a PSK ID with an explicit nonce instead of a random one.
    pub(crate) fn new_with_nonce(psk: Psk, psk_nonce: Vec<u8>) -> Self {
        let psk_nonce = psk_nonce.into();
        Self { psk, psk_nonce }
    }
}
407
408#[derive(TlsSerialize, TlsSize)]
420pub(crate) struct PskLabel<'a> {
421 pub(crate) id: &'a PreSharedKeyId,
422 pub(crate) index: u16,
423 pub(crate) count: u16,
424}
425
426impl<'a> PskLabel<'a> {
427 fn new(id: &'a PreSharedKeyId, index: u16, count: u16) -> Self {
429 Self { id, index, count }
430 }
431}
432
433#[derive(Clone)]
436pub struct PskSecret {
437 secret: Secret,
438}
439
440impl PskSecret {
441 pub(crate) fn new(
452 crypto: &impl OpenMlsCrypto,
453 ciphersuite: Ciphersuite,
454 psks: Vec<(impl Borrow<PreSharedKeyId>, Secret)>,
455 ) -> Result<Self, PskError> {
456 let num_psks = u16::try_from(psks.len()).map_err(|_| PskError::TooManyKeys)?;
458
459 let mut psk_secret = Secret::zero(ciphersuite);
463
464 for (index, (psk_id, psk)) in psks.into_iter().enumerate() {
465 let psk_extracted = {
467 let zero_secret = Secret::zero(ciphersuite);
468 zero_secret
469 .hkdf_extract(crypto, ciphersuite, &psk)
470 .map_err(LibraryError::unexpected_crypto_error)?
471 };
472
473 let psk_input = {
475 let psk_label = PskLabel::new(psk_id.borrow(), index as u16, num_psks)
476 .tls_serialize_detached()
477 .map_err(LibraryError::missing_bound_check)?;
478
479 psk_extracted
480 .kdf_expand_label(
481 crypto,
482 ciphersuite,
483 "derived psk",
484 &psk_label,
485 ciphersuite.hash_length(),
486 )
487 .map_err(LibraryError::unexpected_crypto_error)?
488 };
489
490 psk_secret = psk_input
492 .hkdf_extract(crypto, ciphersuite, &psk_secret)
493 .map_err(LibraryError::unexpected_crypto_error)?;
494 }
495
496 Ok(Self { secret: psk_secret })
497 }
498
499 pub(crate) fn secret(&self) -> &Secret {
501 &self.secret
502 }
503
504 #[cfg(any(feature = "test-utils", feature = "crypto-debug", test))]
505 pub(crate) fn as_slice(&self) -> &[u8] {
506 self.secret.as_slice()
507 }
508}
509
/// Wraps an arbitrary secret as a `PskSecret` (test utilities only).
#[cfg(any(feature = "test-utils", test))]
impl From<Secret> for PskSecret {
    fn from(secret: Secret) -> Self {
        PskSecret { secret }
    }
}
516
517pub(crate) fn load_psks<'p, Storage: StorageProvider>(
518 storage: &Storage,
519 resumption_psk_store: &ResumptionPskStore,
520 psk_ids: &'p [PreSharedKeyId],
521) -> Result<Vec<(&'p PreSharedKeyId, Secret)>, PskError> {
522 let mut psk_bundles = Vec::new();
523
524 for psk_id in psk_ids.iter() {
525 log_crypto!(trace, "PSK store {:?}", resumption_psk_store);
526
527 match &psk_id.psk {
528 Psk::Resumption(resumption) => {
529 if let Some(psk_bundle) = resumption_psk_store.get(resumption.psk_epoch()) {
530 psk_bundles.push((psk_id, psk_bundle.secret.clone()));
531 } else {
532 return Err(PskError::KeyNotFound);
533 }
534 }
535 Psk::External(_) => {
536 let psk_bundle: Option<PskBundle> = storage
537 .psk(psk_id.psk())
538 .map_err(|_| PskError::KeyNotFound)?;
539 if let Some(psk_bundle) = psk_bundle {
540 psk_bundles.push((psk_id, psk_bundle.secret));
541 } else {
542 return Err(PskError::KeyNotFound);
543 }
544 }
545 }
546 }
547
548 Ok(psk_bundles)
549}
550
551pub mod store {
553 use serde::{Deserialize, Serialize};
554
555 use crate::{group::GroupEpoch, schedule::ResumptionPskSecret};
556
557 #[derive(Debug, Serialize, Deserialize)]
561 #[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
562 pub(crate) struct ResumptionPskStore {
563 max_number_of_secrets: usize,
564 resumption_psk: Vec<(GroupEpoch, ResumptionPskSecret)>,
565 cursor: usize,
566 }
567
568 impl ResumptionPskStore {
569 pub(crate) fn new(max_number_of_secrets: usize) -> Self {
571 Self {
572 max_number_of_secrets,
573 resumption_psk: vec![],
574 cursor: 0,
575 }
576 }
577
578 pub(crate) fn add(&mut self, epoch: GroupEpoch, resumption_psk: ResumptionPskSecret) {
580 if self.max_number_of_secrets == 0 {
581 return;
582 }
583 let item = (epoch, resumption_psk);
584 if self.resumption_psk.len() < self.max_number_of_secrets {
585 self.resumption_psk.push(item);
586 self.cursor += 1;
587 } else {
588 self.cursor += 1;
589 self.cursor %= self.resumption_psk.len();
590 self.resumption_psk[self.cursor] = item;
591 }
592 }
593
594 pub(crate) fn get(&self, epoch: GroupEpoch) -> Option<&ResumptionPskSecret> {
597 self.resumption_psk
598 .iter()
599 .find(|&(e, _s)| e == &epoch)
600 .map(|(_e, s)| s)
601 }
602 }
603
604 #[cfg(test)]
605 impl ResumptionPskStore {
606 pub(crate) fn cursor(&self) -> usize {
607 self.cursor
608 }
609 }
610}