1use std::borrow::Borrow;
4
5use openmls_traits::{random::OpenMlsRand, storage::StorageProvider as StorageProviderTrait};
6use serde::{Deserialize, Serialize};
7use tls_codec::{Serialize as TlsSerializeTrait, VLBytes};
8
9use super::*;
10use crate::{
11 group::{GroupEpoch, GroupId},
12 schedule::psk::store::ResumptionPskStore,
13 storage::{OpenMlsProvider, StorageProvider},
14};
15
16#[derive(
29 Clone,
30 Copy,
31 Debug,
32 PartialEq,
33 Eq,
34 PartialOrd,
35 Ord,
36 Hash,
37 Deserialize,
38 Serialize,
39 TlsDeserialize,
40 TlsDeserializeBytes,
41 TlsSerialize,
42 TlsSize,
43)]
44#[repr(u8)]
45pub enum ResumptionPskUsage {
46 Application = 1,
48 Reinit = 2,
52 Branch = 3,
56}
57
58#[derive(
60 Debug,
61 PartialEq,
62 Eq,
63 PartialOrd,
64 Ord,
65 Clone,
66 Hash,
67 Deserialize,
68 Serialize,
69 TlsDeserialize,
70 TlsDeserializeBytes,
71 TlsSerialize,
72 TlsSize,
73)]
74pub struct ExternalPsk {
75 psk_id: VLBytes,
76}
77
78impl ExternalPsk {
79 pub fn new(psk_id: Vec<u8>) -> Self {
81 Self {
82 psk_id: psk_id.into(),
83 }
84 }
85
86 pub fn psk_id(&self) -> &[u8] {
88 self.psk_id.as_slice()
89 }
90}
91
92#[derive(Serialize, Deserialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)]
95pub(crate) struct PskBundle {
96 secret: Secret,
97}
98
99#[derive(
101 Clone,
102 Debug,
103 PartialEq,
104 Eq,
105 PartialOrd,
106 Ord,
107 Deserialize,
108 Serialize,
109 TlsDeserialize,
110 TlsDeserializeBytes,
111 TlsSerialize,
112 TlsSize,
113)]
114pub struct ResumptionPsk {
115 pub(crate) usage: ResumptionPskUsage,
116 pub(crate) psk_group_id: GroupId,
117 pub(crate) psk_epoch: GroupEpoch,
118}
119
120impl ResumptionPsk {
121 pub fn new(usage: ResumptionPskUsage, psk_group_id: GroupId, psk_epoch: GroupEpoch) -> Self {
123 Self {
124 usage,
125 psk_group_id,
126 psk_epoch,
127 }
128 }
129
130 pub fn usage(&self) -> ResumptionPskUsage {
132 self.usage
133 }
134
135 pub fn psk_group_id(&self) -> &GroupId {
137 &self.psk_group_id
138 }
139
140 pub fn psk_epoch(&self) -> GroupEpoch {
142 self.psk_epoch
143 }
144}
145
146#[derive(
148 Clone,
149 Debug,
150 PartialEq,
151 Eq,
152 PartialOrd,
153 Ord,
154 Deserialize,
155 Serialize,
156 TlsDeserialize,
157 TlsDeserializeBytes,
158 TlsSerialize,
159 TlsSize,
160)]
161#[repr(u8)]
162pub enum Psk {
163 #[tls_codec(discriminant = 1)]
165 External(ExternalPsk),
166 #[tls_codec(discriminant = 2)]
168 Resumption(ResumptionPsk),
169}
170
/// Kind of a PSK, without its payload. The values mirror the discriminants of [`Psk`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PskType {
    /// External PSK.
    External = 1,
    /// Resumption PSK.
    Resumption = 2,
}
188
189#[derive(
209 Clone,
210 Debug,
211 PartialEq,
212 Eq,
213 PartialOrd,
214 Ord,
215 Deserialize,
216 Serialize,
217 TlsDeserialize,
218 TlsDeserializeBytes,
219 TlsSerialize,
220 TlsSize,
221)]
222pub struct PreSharedKeyId {
223 pub(crate) psk: Psk,
224 pub(crate) psk_nonce: VLBytes,
225}
226
227impl PreSharedKeyId {
228 pub fn new(
230 ciphersuite: Ciphersuite,
231 rand: &impl OpenMlsRand,
232 psk: Psk,
233 ) -> Result<Self, CryptoError> {
234 let psk_nonce = rand
235 .random_vec(ciphersuite.hash_length())
236 .map_err(|_| CryptoError::InsufficientRandomness)?
237 .into();
238
239 Ok(Self { psk, psk_nonce })
240 }
241
242 pub fn external(psk_id: Vec<u8>, psk_nonce: Vec<u8>) -> Self {
244 let psk = Psk::External(ExternalPsk::new(psk_id));
245
246 Self {
247 psk,
248 psk_nonce: psk_nonce.into(),
249 }
250 }
251
252 pub fn resumption(
254 usage: ResumptionPskUsage,
255 psk_group_id: GroupId,
256 psk_epoch: GroupEpoch,
257 psk_nonce: Vec<u8>,
258 ) -> Self {
259 let psk = Psk::Resumption(ResumptionPsk::new(usage, psk_group_id, psk_epoch));
260
261 Self {
262 psk,
263 psk_nonce: psk_nonce.into(),
264 }
265 }
266
267 pub fn psk(&self) -> &Psk {
269 &self.psk
270 }
271
272 pub fn psk_nonce(&self) -> &[u8] {
274 self.psk_nonce.as_slice()
275 }
276
277 pub fn store<Provider: OpenMlsProvider>(
283 &self,
284 provider: &Provider,
285 psk: &[u8],
286 ) -> Result<(), PskError> {
287 let psk_bundle = {
288 let secret = Secret::from_slice(psk);
289
290 PskBundle { secret }
291 };
292
293 provider
294 .storage()
295 .write_psk(&self.psk, &psk_bundle)
296 .map_err(|_| PskError::Storage)
297 }
298
299 pub(crate) fn validate_in_proposal(self, ciphersuite: Ciphersuite) -> Result<Self, PskError> {
302 match self.psk() {
304 Psk::Resumption(resumption_psk) => {
305 if resumption_psk.usage != ResumptionPskUsage::Application {
306 return Err(PskError::UsageMismatch {
307 allowed: vec![ResumptionPskUsage::Application],
308 got: resumption_psk.usage,
309 });
310 }
311 }
312 Psk::External(_) => {}
313 };
314
315 {
317 let expected_nonce_length = ciphersuite.hash_length();
318 let got_nonce_length = self.psk_nonce().len();
319
320 if expected_nonce_length != got_nonce_length {
321 return Err(PskError::NonceLengthMismatch {
322 expected: expected_nonce_length,
323 got: got_nonce_length,
324 });
325 }
326 }
327
328 Ok(self)
329 }
330}
331
#[cfg(test)]
impl PreSharedKeyId {
    /// Test-only constructor that takes an explicit nonce instead of sampling one.
    pub(crate) fn new_with_nonce(psk: Psk, psk_nonce: Vec<u8>) -> Self {
        let psk_nonce = VLBytes::from(psk_nonce);
        Self { psk, psk_nonce }
    }
}
341
342#[derive(TlsSerialize, TlsSize)]
354pub(crate) struct PskLabel<'a> {
355 pub(crate) id: &'a PreSharedKeyId,
356 pub(crate) index: u16,
357 pub(crate) count: u16,
358}
359
360impl<'a> PskLabel<'a> {
361 fn new(id: &'a PreSharedKeyId, index: u16, count: u16) -> Self {
363 Self { id, index, count }
364 }
365}
366
367#[derive(Clone)]
370pub struct PskSecret {
371 secret: Secret,
372}
373
374impl PskSecret {
375 pub(crate) fn new(
386 crypto: &impl OpenMlsCrypto,
387 ciphersuite: Ciphersuite,
388 psks: Vec<(impl Borrow<PreSharedKeyId>, Secret)>,
389 ) -> Result<Self, PskError> {
390 let num_psks = u16::try_from(psks.len()).map_err(|_| PskError::TooManyKeys)?;
392
393 let mut psk_secret = Secret::zero(ciphersuite);
397
398 for (index, (psk_id, psk)) in psks.into_iter().enumerate() {
399 let psk_extracted = {
401 let zero_secret = Secret::zero(ciphersuite);
402 zero_secret
403 .hkdf_extract(crypto, ciphersuite, &psk)
404 .map_err(LibraryError::unexpected_crypto_error)?
405 };
406
407 let psk_input = {
409 let psk_label = PskLabel::new(psk_id.borrow(), index as u16, num_psks)
410 .tls_serialize_detached()
411 .map_err(LibraryError::missing_bound_check)?;
412
413 psk_extracted
414 .kdf_expand_label(
415 crypto,
416 ciphersuite,
417 "derived psk",
418 &psk_label,
419 ciphersuite.hash_length(),
420 )
421 .map_err(LibraryError::unexpected_crypto_error)?
422 };
423
424 psk_secret = psk_input
426 .hkdf_extract(crypto, ciphersuite, &psk_secret)
427 .map_err(LibraryError::unexpected_crypto_error)?;
428 }
429
430 Ok(Self { secret: psk_secret })
431 }
432
433 pub(crate) fn secret(&self) -> &Secret {
435 &self.secret
436 }
437
438 #[cfg(any(feature = "test-utils", feature = "crypto-debug", test))]
439 pub(crate) fn as_slice(&self) -> &[u8] {
440 self.secret.as_slice()
441 }
442}
443
#[cfg(any(feature = "test-utils", test))]
impl From<Secret> for PskSecret {
    /// Wraps an arbitrary secret as a PSK secret, bypassing the derivation
    /// (test builds only).
    fn from(secret: Secret) -> PskSecret {
        PskSecret { secret }
    }
}
450
451pub(crate) fn load_psks<'p, Storage: StorageProvider>(
452 storage: &Storage,
453 resumption_psk_store: &ResumptionPskStore,
454 psk_ids: &'p [PreSharedKeyId],
455) -> Result<Vec<(&'p PreSharedKeyId, Secret)>, PskError> {
456 let mut psk_bundles = Vec::new();
457
458 for psk_id in psk_ids.iter() {
459 log_crypto!(trace, "PSK store {:?}", resumption_psk_store);
460
461 match &psk_id.psk {
462 Psk::Resumption(resumption) => {
463 if let Some(psk_bundle) = resumption_psk_store.get(resumption.psk_epoch()) {
464 psk_bundles.push((psk_id, psk_bundle.secret.clone()));
465 } else {
466 return Err(PskError::KeyNotFound);
467 }
468 }
469 Psk::External(_) => {
470 let psk_bundle: Option<PskBundle> = storage
471 .psk(psk_id.psk())
472 .map_err(|_| PskError::KeyNotFound)?;
473 if let Some(psk_bundle) = psk_bundle {
474 psk_bundles.push((psk_id, psk_bundle.secret));
475 } else {
476 return Err(PskError::KeyNotFound);
477 }
478 }
479 }
480 }
481
482 Ok(psk_bundles)
483}
484
485pub mod store {
487 use serde::{Deserialize, Serialize};
488
489 use crate::{group::GroupEpoch, schedule::ResumptionPskSecret};
490
491 #[derive(Debug, Serialize, Deserialize)]
495 #[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
496 pub(crate) struct ResumptionPskStore {
497 max_number_of_secrets: usize,
498 resumption_psk: Vec<(GroupEpoch, ResumptionPskSecret)>,
499 cursor: usize,
500 }
501
502 impl ResumptionPskStore {
503 pub(crate) fn new(max_number_of_secrets: usize) -> Self {
505 Self {
506 max_number_of_secrets,
507 resumption_psk: vec![],
508 cursor: 0,
509 }
510 }
511
512 pub(crate) fn add(&mut self, epoch: GroupEpoch, resumption_psk: ResumptionPskSecret) {
514 if self.max_number_of_secrets == 0 {
515 return;
516 }
517 let item = (epoch, resumption_psk);
518 if self.resumption_psk.len() < self.max_number_of_secrets {
519 self.resumption_psk.push(item);
520 self.cursor += 1;
521 } else {
522 self.cursor += 1;
523 self.cursor %= self.resumption_psk.len();
524 self.resumption_psk[self.cursor] = item;
525 }
526 }
527
528 pub(crate) fn get(&self, epoch: GroupEpoch) -> Option<&ResumptionPskSecret> {
531 self.resumption_psk
532 .iter()
533 .find(|&(e, _s)| e == &epoch)
534 .map(|(_e, s)| s)
535 }
536 }
537
538 #[cfg(test)]
539 impl ResumptionPskStore {
540 pub(crate) fn cursor(&self) -> usize {
541 self.cursor
542 }
543 }
544}