1use std::borrow::Borrow;
4
5use openmls_traits::{random::OpenMlsRand, storage::StorageProvider as StorageProviderTrait};
6use serde::{Deserialize, Serialize};
7use tls_codec::{Serialize as TlsSerializeTrait, VLBytes};
8
9use super::*;
10use crate::{
11 group::{GroupEpoch, GroupId},
12 schedule::psk::store::ResumptionPskStore,
13 storage::{OpenMlsProvider, StorageProvider},
14};
15
16#[derive(
29 Clone,
30 Copy,
31 Debug,
32 PartialEq,
33 Eq,
34 PartialOrd,
35 Ord,
36 Hash,
37 Deserialize,
38 Serialize,
39 TlsDeserialize,
40 TlsDeserializeBytes,
41 TlsSerialize,
42 TlsSize,
43)]
44#[repr(u8)]
45pub enum ResumptionPskUsage {
46 Application = 1,
48 Reinit = 2,
52 Branch = 3,
56}
57
58#[derive(
60 Debug,
61 PartialEq,
62 Eq,
63 PartialOrd,
64 Ord,
65 Clone,
66 Hash,
67 Deserialize,
68 Serialize,
69 TlsDeserialize,
70 TlsDeserializeBytes,
71 TlsSerialize,
72 TlsSize,
73)]
74pub struct ExternalPsk {
75 psk_id: VLBytes,
76}
77
78impl ExternalPsk {
79 pub fn new(psk_id: Vec<u8>) -> Self {
81 Self {
82 psk_id: psk_id.into(),
83 }
84 }
85
86 pub fn psk_id(&self) -> &[u8] {
88 self.psk_id.as_slice()
89 }
90}
91
92#[derive(Serialize, Deserialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)]
95pub(crate) struct PskBundle {
96 secret: Secret,
97}
98
99#[derive(
101 Clone,
102 Debug,
103 PartialEq,
104 Eq,
105 PartialOrd,
106 Ord,
107 Deserialize,
108 Serialize,
109 TlsDeserialize,
110 TlsDeserializeBytes,
111 TlsSerialize,
112 TlsSize,
113 Hash,
114)]
115pub struct ResumptionPsk {
116 pub(crate) usage: ResumptionPskUsage,
117 pub(crate) psk_group_id: GroupId,
118 pub(crate) psk_epoch: GroupEpoch,
119}
120
121impl ResumptionPsk {
122 pub fn new(usage: ResumptionPskUsage, psk_group_id: GroupId, psk_epoch: GroupEpoch) -> Self {
124 Self {
125 usage,
126 psk_group_id,
127 psk_epoch,
128 }
129 }
130
131 pub fn usage(&self) -> ResumptionPskUsage {
133 self.usage
134 }
135
136 pub fn psk_group_id(&self) -> &GroupId {
138 &self.psk_group_id
139 }
140
141 pub fn psk_epoch(&self) -> GroupEpoch {
143 self.psk_epoch
144 }
145}
146
147#[derive(
149 Clone,
150 Debug,
151 PartialEq,
152 Eq,
153 PartialOrd,
154 Ord,
155 Deserialize,
156 Serialize,
157 TlsDeserialize,
158 TlsDeserializeBytes,
159 TlsSerialize,
160 TlsSize,
161 Hash,
162)]
163#[repr(u8)]
164pub enum Psk {
165 #[tls_codec(discriminant = 1)]
167 External(ExternalPsk),
168 #[tls_codec(discriminant = 2)]
170 Resumption(ResumptionPsk),
171}
172
/// Kind tag for a [`Psk`], using the same numeric values as the `Psk` TLS discriminants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PskType {
    /// External PSK.
    External = 1,
    /// Resumption PSK.
    Resumption = 2,
}
190
191#[derive(
211 Clone,
212 Debug,
213 PartialEq,
214 Eq,
215 PartialOrd,
216 Ord,
217 Deserialize,
218 Serialize,
219 TlsDeserialize,
220 TlsDeserializeBytes,
221 TlsSerialize,
222 TlsSize,
223 Hash,
224)]
225pub struct PreSharedKeyId {
226 pub(crate) psk: Psk,
227 pub(crate) psk_nonce: VLBytes,
228}
229
230impl PreSharedKeyId {
231 pub fn new(
233 ciphersuite: Ciphersuite,
234 rand: &impl OpenMlsRand,
235 psk: Psk,
236 ) -> Result<Self, CryptoError> {
237 let psk_nonce = rand
238 .random_vec(ciphersuite.hash_length())
239 .map_err(|_| CryptoError::InsufficientRandomness)?
240 .into();
241
242 Ok(Self { psk, psk_nonce })
243 }
244
245 pub fn external(psk_id: Vec<u8>, psk_nonce: Vec<u8>) -> Self {
247 let psk = Psk::External(ExternalPsk::new(psk_id));
248
249 Self {
250 psk,
251 psk_nonce: psk_nonce.into(),
252 }
253 }
254
255 pub fn resumption(
257 usage: ResumptionPskUsage,
258 psk_group_id: GroupId,
259 psk_epoch: GroupEpoch,
260 psk_nonce: Vec<u8>,
261 ) -> Self {
262 let psk = Psk::Resumption(ResumptionPsk::new(usage, psk_group_id, psk_epoch));
263
264 Self {
265 psk,
266 psk_nonce: psk_nonce.into(),
267 }
268 }
269
270 pub fn psk(&self) -> &Psk {
272 &self.psk
273 }
274
275 pub fn psk_nonce(&self) -> &[u8] {
277 self.psk_nonce.as_slice()
278 }
279
280 pub fn store<Provider: OpenMlsProvider>(
286 &self,
287 provider: &Provider,
288 psk: &[u8],
289 ) -> Result<(), PskError> {
290 let psk_bundle = {
291 let secret = Secret::from_slice(psk);
292
293 PskBundle { secret }
294 };
295
296 provider
297 .storage()
298 .write_psk(&self.psk, &psk_bundle)
299 .map_err(|_| PskError::Storage)
300 }
301
302 pub(crate) fn validate_in_proposal(self, ciphersuite: Ciphersuite) -> Result<(), PskError> {
305 match self.psk() {
307 Psk::Resumption(resumption_psk) => {
308 if resumption_psk.usage != ResumptionPskUsage::Application {
311 return Err(PskError::UsageMismatch {
312 allowed: vec![ResumptionPskUsage::Application],
313 got: resumption_psk.usage,
314 });
315 }
316 }
317 Psk::External(_) => {}
318 };
319
320 {
323 let expected_nonce_length = ciphersuite.hash_length();
324 let got_nonce_length = self.psk_nonce().len();
325
326 if expected_nonce_length != got_nonce_length {
327 return Err(PskError::NonceLengthMismatch {
328 expected: expected_nonce_length,
329 got: got_nonce_length,
330 });
331 }
332 }
333
334 Ok(())
335 }
336
337 pub(crate) fn validate_in_welcome(
338 psk_ids: &[PreSharedKeyId],
339 ciphersuite: Ciphersuite,
340 ) -> Result<(), PskError> {
341 let mut contains_branch_psk = false;
342 let mut contains_reinit_psk = false;
343 for id in psk_ids {
344 match id.psk() {
346 Psk::Resumption(resumption_psk) => match resumption_psk.usage {
347 ResumptionPskUsage::Application => {
348 return Err(PskError::UsageMismatch {
349 allowed: vec![ResumptionPskUsage::Reinit, ResumptionPskUsage::Branch],
350 got: resumption_psk.usage,
351 });
352 }
353 ResumptionPskUsage::Reinit => {
354 if contains_reinit_psk {
355 return Err(PskError::UsageDuplicate {
356 usage: ResumptionPskUsage::Reinit,
357 });
358 }
359 if contains_branch_psk {
360 return Err(PskError::UsageConflict {
361 first: ResumptionPskUsage::Reinit,
362 second: ResumptionPskUsage::Branch,
363 });
364 }
365 contains_reinit_psk = true;
366 }
367 ResumptionPskUsage::Branch => {
368 if contains_branch_psk {
369 return Err(PskError::UsageDuplicate {
370 usage: ResumptionPskUsage::Branch,
371 });
372 }
373 if contains_reinit_psk {
374 return Err(PskError::UsageConflict {
375 first: ResumptionPskUsage::Branch,
376 second: ResumptionPskUsage::Reinit,
377 });
378 }
379 contains_branch_psk = true;
380 }
381 },
382 Psk::External(_) => {}
383 };
384
385 {
386 let expected_nonce_length = ciphersuite.hash_length();
387 let got_nonce_length = id.psk_nonce().len();
388
389 if expected_nonce_length != got_nonce_length {
390 return Err(PskError::NonceLengthMismatch {
391 expected: expected_nonce_length,
392 got: got_nonce_length,
393 });
394 }
395 }
396 }
397 Ok(())
398 }
399}
400
#[cfg(test)]
impl PreSharedKeyId {
    /// Test helper: builds an id from an explicit `psk` and `psk_nonce` without validation.
    pub(crate) fn new_with_nonce(psk: Psk, psk_nonce: Vec<u8>) -> Self {
        let psk_nonce = VLBytes::from(psk_nonce);
        PreSharedKeyId { psk, psk_nonce }
    }
}
410
411#[derive(TlsSerialize, TlsSize)]
423pub(crate) struct PskLabel<'a> {
424 pub(crate) id: &'a PreSharedKeyId,
425 pub(crate) index: u16,
426 pub(crate) count: u16,
427}
428
429impl<'a> PskLabel<'a> {
430 fn new(id: &'a PreSharedKeyId, index: u16, count: u16) -> Self {
432 Self { id, index, count }
433 }
434}
435
436#[derive(Clone)]
439pub struct PskSecret {
440 secret: Secret,
441}
442
443impl PskSecret {
444 pub(crate) fn new(
455 crypto: &impl OpenMlsCrypto,
456 ciphersuite: Ciphersuite,
457 psks: Vec<(impl Borrow<PreSharedKeyId>, Secret)>,
458 ) -> Result<Self, PskError> {
459 let num_psks = u16::try_from(psks.len()).map_err(|_| PskError::TooManyKeys)?;
461
462 let mut psk_secret = Secret::zero(ciphersuite);
466
467 for (index, (psk_id, psk)) in psks.into_iter().enumerate() {
468 let psk_extracted = {
470 let zero_secret = Secret::zero(ciphersuite);
471 zero_secret
472 .hkdf_extract(crypto, ciphersuite, &psk)
473 .map_err(LibraryError::unexpected_crypto_error)?
474 };
475
476 let psk_input = {
478 let psk_label = PskLabel::new(psk_id.borrow(), index as u16, num_psks)
479 .tls_serialize_detached()
480 .map_err(LibraryError::missing_bound_check)?;
481
482 psk_extracted
483 .kdf_expand_label(
484 crypto,
485 ciphersuite,
486 "derived psk",
487 &psk_label,
488 ciphersuite.hash_length(),
489 )
490 .map_err(LibraryError::unexpected_crypto_error)?
491 };
492
493 psk_secret = psk_input
495 .hkdf_extract(crypto, ciphersuite, &psk_secret)
496 .map_err(LibraryError::unexpected_crypto_error)?;
497 }
498
499 Ok(Self { secret: psk_secret })
500 }
501
502 pub(crate) fn secret(&self) -> &Secret {
504 &self.secret
505 }
506
507 #[cfg(any(feature = "test-utils", feature = "crypto-debug", test))]
508 pub(crate) fn as_slice(&self) -> &[u8] {
509 self.secret.as_slice()
510 }
511}
512
/// Test-only conversion wrapping an arbitrary [`Secret`] as a [`PskSecret`].
#[cfg(any(feature = "test-utils", test))]
impl From<Secret> for PskSecret {
    fn from(secret: Secret) -> Self {
        PskSecret { secret }
    }
}
519
520pub(crate) fn load_psks<'p, Storage: StorageProvider>(
521 storage: &Storage,
522 resumption_psk_store: &ResumptionPskStore,
523 psk_ids: &'p [PreSharedKeyId],
524) -> Result<Vec<(&'p PreSharedKeyId, Secret)>, PskError> {
525 let mut psk_bundles = Vec::new();
526
527 for psk_id in psk_ids.iter() {
528 log_crypto!(trace, "PSK store {:?}", resumption_psk_store);
529
530 match &psk_id.psk {
531 Psk::Resumption(resumption) => {
532 if let Some(psk_bundle) = resumption_psk_store.get(resumption.psk_epoch()) {
533 psk_bundles.push((psk_id, psk_bundle.secret.clone()));
534 } else {
535 return Err(PskError::KeyNotFound);
536 }
537 }
538 Psk::External(_) => {
539 let psk_bundle: Option<PskBundle> = storage
540 .psk(psk_id.psk())
541 .map_err(|_| PskError::KeyNotFound)?;
542 if let Some(psk_bundle) = psk_bundle {
543 psk_bundles.push((psk_id, psk_bundle.secret));
544 } else {
545 return Err(PskError::KeyNotFound);
546 }
547 }
548 }
549 }
550
551 Ok(psk_bundles)
552}
553
554pub mod store {
556 use serde::{Deserialize, Serialize};
557
558 use crate::{group::GroupEpoch, schedule::ResumptionPskSecret};
559
560 #[derive(Debug, Serialize, Deserialize)]
564 #[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
565 pub(crate) struct ResumptionPskStore {
566 max_number_of_secrets: usize,
567 resumption_psk: Vec<(GroupEpoch, ResumptionPskSecret)>,
568 cursor: usize,
569 }
570
571 impl ResumptionPskStore {
572 pub(crate) fn new(max_number_of_secrets: usize) -> Self {
574 Self {
575 max_number_of_secrets,
576 resumption_psk: vec![],
577 cursor: 0,
578 }
579 }
580
581 pub(crate) fn add(&mut self, epoch: GroupEpoch, resumption_psk: ResumptionPskSecret) {
583 if self.max_number_of_secrets == 0 {
584 return;
585 }
586 let item = (epoch, resumption_psk);
587 if self.resumption_psk.len() < self.max_number_of_secrets {
588 self.resumption_psk.push(item);
589 self.cursor += 1;
590 } else {
591 self.cursor += 1;
592 self.cursor %= self.resumption_psk.len();
593 self.resumption_psk[self.cursor] = item;
594 }
595 }
596
597 pub(crate) fn get(&self, epoch: GroupEpoch) -> Option<&ResumptionPskSecret> {
600 self.resumption_psk
601 .iter()
602 .find(|&(e, _s)| e == &epoch)
603 .map(|(_e, s)| s)
604 }
605 }
606
607 #[cfg(test)]
608 impl ResumptionPskStore {
609 pub(crate) fn cursor(&self) -> usize {
610 self.cursor
611 }
612 }
613}