1use std::{
25 fmt::Debug,
26 io::{Read, Write},
27};
28
29use serde::{Deserialize, Serialize};
30
31mod application_id_extension;
33mod codec;
34mod external_pub_extension;
35mod external_sender_extension;
36mod last_resort;
37mod ratchet_tree_extension;
38mod required_capabilities;
39use errors::*;
40
41pub mod errors;
43
44pub use application_id_extension::ApplicationIdExtension;
46pub use external_pub_extension::ExternalPubExtension;
47pub use external_sender_extension::{
48 ExternalSender, ExternalSendersExtension, SenderExtensionIndex,
49};
50pub use last_resort::LastResortExtension;
51pub use ratchet_tree_extension::RatchetTreeExtension;
52pub use required_capabilities::RequiredCapabilitiesExtension;
53use tls_codec::{
54 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
55 Size, TlsSize,
56};
57
58#[cfg(test)]
59mod tests;
60
/// Identifies the type of an [`Extension`].
///
/// Each known type has a fixed `u16` wire value (see the `From<u16>` and
/// `From<ExtensionType> for u16` conversions). Any other value is kept verbatim
/// in [`ExtensionType::Unknown`] so it can round-trip.
///
/// Note: the derived `Ord`/`PartialOrd` compare variants by declaration order
/// (and `Unknown` values by their inner `u16`), not by wire value.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)]
pub enum ExtensionType {
    /// Application identifier extension (wire value `1`).
    ApplicationId,

    /// Ratchet tree extension (wire value `2`).
    RatchetTree,

    /// Required capabilities extension (wire value `3`).
    RequiredCapabilities,

    /// External public key extension (wire value `4`).
    ExternalPub,

    /// External senders extension (wire value `5`).
    ExternalSenders,

    /// Last resort extension (wire value `10`).
    LastResort,

    /// Any extension type not listed above, carrying its raw `u16` wire value.
    Unknown(u16),
}
105
106impl ExtensionType {
107 pub(crate) fn is_default(self) -> bool {
109 match self {
110 ExtensionType::ApplicationId
111 | ExtensionType::RatchetTree
112 | ExtensionType::RequiredCapabilities
113 | ExtensionType::ExternalPub
114 | ExtensionType::ExternalSenders => true,
115 ExtensionType::LastResort | ExtensionType::Unknown(_) => false,
116 }
117 }
118
119 pub(crate) fn is_valid_in_leaf_node(self) -> Option<bool> {
123 match self {
124 ExtensionType::LastResort
125 | ExtensionType::RatchetTree
126 | ExtensionType::RequiredCapabilities
127 | ExtensionType::ExternalPub
128 | ExtensionType::ExternalSenders => Some(false),
129 ExtensionType::ApplicationId => Some(true),
130 ExtensionType::Unknown(_) => None,
131 }
132 }
133}
134
135impl Size for ExtensionType {
136 fn tls_serialized_len(&self) -> usize {
137 2
138 }
139}
140
141impl TlsDeserializeTrait for ExtensionType {
142 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
143 where
144 Self: Sized,
145 {
146 let mut extension_type = [0u8; 2];
147 bytes.read_exact(&mut extension_type)?;
148
149 Ok(ExtensionType::from(u16::from_be_bytes(extension_type)))
150 }
151}
152
153impl DeserializeBytes for ExtensionType {
154 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
155 where
156 Self: Sized,
157 {
158 let mut bytes_ref = bytes;
159 let extension_type = ExtensionType::tls_deserialize(&mut bytes_ref)?;
160 let remainder = &bytes[extension_type.tls_serialized_len()..];
161 Ok((extension_type, remainder))
162 }
163}
164
165impl TlsSerializeTrait for ExtensionType {
166 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
167 writer.write_all(&u16::from(*self).to_be_bytes())?;
168
169 Ok(2)
170 }
171}
172
173impl From<u16> for ExtensionType {
174 fn from(a: u16) -> Self {
175 match a {
176 1 => ExtensionType::ApplicationId,
177 2 => ExtensionType::RatchetTree,
178 3 => ExtensionType::RequiredCapabilities,
179 4 => ExtensionType::ExternalPub,
180 5 => ExtensionType::ExternalSenders,
181 10 => ExtensionType::LastResort,
182 unknown => ExtensionType::Unknown(unknown),
183 }
184 }
185}
186
187impl From<ExtensionType> for u16 {
188 fn from(value: ExtensionType) -> Self {
189 match value {
190 ExtensionType::ApplicationId => 1,
191 ExtensionType::RatchetTree => 2,
192 ExtensionType::RequiredCapabilities => 3,
193 ExtensionType::ExternalPub => 4,
194 ExtensionType::ExternalSenders => 5,
195 ExtensionType::LastResort => 10,
196 ExtensionType::Unknown(unknown) => unknown,
197 }
198 }
199}
200
/// A single extension value, tagged by its [`ExtensionType`].
///
/// Each known variant wraps its typed payload. Anything else is kept as
/// [`Extension::Unknown`] with the raw type id and undecoded data, so it can be
/// re-serialized unchanged. The TLS codec for this type is not in this file;
/// it presumably lives in the `codec` submodule.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Extension {
    /// An [`ApplicationIdExtension`] ([`ExtensionType::ApplicationId`]).
    ApplicationId(ApplicationIdExtension),

    /// A [`RatchetTreeExtension`] ([`ExtensionType::RatchetTree`]).
    RatchetTree(RatchetTreeExtension),

    /// A [`RequiredCapabilitiesExtension`] ([`ExtensionType::RequiredCapabilities`]).
    RequiredCapabilities(RequiredCapabilitiesExtension),

    /// An [`ExternalPubExtension`] ([`ExtensionType::ExternalPub`]).
    ExternalPub(ExternalPubExtension),

    /// An [`ExternalSendersExtension`] ([`ExtensionType::ExternalSenders`]).
    ExternalSenders(ExternalSendersExtension),

    /// A [`LastResortExtension`] ([`ExtensionType::LastResort`]).
    LastResort(LastResortExtension),

    /// An extension of an unrecognised type: the raw `u16` type id and the
    /// undecoded payload.
    Unknown(u16, UnknownExtension),
}
238
/// Opaque payload of an extension whose type is not recognised.
///
/// The bytes are stored undecoded so the extension can be re-encoded as-is.
#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize)]
pub struct UnknownExtension(pub Vec<u8>);
242
/// A collection of [`Extension`]s with at most one extension per [`ExtensionType`].
///
/// Uniqueness is enforced by [`Extensions::add`] and by the
/// `TryFrom<Vec<Extension>>` conversion that TLS decoding goes through.
/// Entries are kept in insertion order. `Default` yields an empty collection.
///
/// NOTE(review): the derived serde `Deserialize` does not go through `TryFrom`.
/// Serde input with duplicate extension types is therefore accepted as-is;
/// confirm whether that is intended.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TlsSize)]
pub struct Extensions {
    // Invariant: no two entries share the same `extension_type()`.
    unique: Vec<Extension>,
}
248
249impl TlsSerializeTrait for Extensions {
250 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
251 self.unique.tls_serialize(writer)
252 }
253}
254
255impl TlsDeserializeTrait for Extensions {
256 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
257 where
258 Self: Sized,
259 {
260 let candidate: Vec<Extension> = Vec::tls_deserialize(bytes)?;
261 Extensions::try_from(candidate)
262 .map_err(|_| Error::DecodingError("Found duplicate extensions".into()))
263 }
264}
265
266impl DeserializeBytes for Extensions {
267 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
268 where
269 Self: Sized,
270 {
271 let mut bytes_ref = bytes;
272 let extensions = Extensions::tls_deserialize(&mut bytes_ref)?;
273 let remainder = &bytes[extensions.tls_serialized_len()..];
274 Ok((extensions, remainder))
275 }
276}
277
278impl Extensions {
279 pub fn empty() -> Self {
281 Self { unique: vec![] }
282 }
283
284 pub fn single(extension: Extension) -> Self {
286 Self {
287 unique: vec![extension],
288 }
289 }
290
291 pub fn from_vec(extensions: Vec<Extension>) -> Result<Self, InvalidExtensionError> {
296 extensions.try_into()
297 }
298
299 pub fn iter(&self) -> impl Iterator<Item = &Extension> {
301 self.unique.iter()
302 }
303
304 pub fn add(&mut self, extension: Extension) -> Result<(), InvalidExtensionError> {
309 if self.contains(extension.extension_type()) {
310 return Err(InvalidExtensionError::Duplicate);
311 }
312
313 self.unique.push(extension);
314
315 Ok(())
316 }
317
318 pub fn add_or_replace(&mut self, extension: Extension) -> Option<Extension> {
322 let replaced = self.remove(extension.extension_type());
323 self.unique.push(extension);
324 replaced
325 }
326
327 pub fn remove(&mut self, extension_type: ExtensionType) -> Option<Extension> {
332 if let Some(pos) = self
333 .unique
334 .iter()
335 .position(|ext| ext.extension_type() == extension_type)
336 {
337 Some(self.unique.remove(pos))
338 } else {
339 None
340 }
341 }
342
343 pub fn contains(&self, extension_type: ExtensionType) -> bool {
346 self.unique
347 .iter()
348 .any(|ext| ext.extension_type() == extension_type)
349 }
350}
351
352impl TryFrom<Vec<Extension>> for Extensions {
353 type Error = InvalidExtensionError;
354
355 fn try_from(candidate: Vec<Extension>) -> Result<Self, Self::Error> {
356 let mut unique: Vec<Extension> = Vec::new();
357
358 for extension in candidate.into_iter() {
359 if unique
360 .iter()
361 .any(|ext| ext.extension_type() == extension.extension_type())
362 {
363 return Err(InvalidExtensionError::Duplicate);
364 } else {
365 unique.push(extension);
366 }
367 }
368
369 Ok(Self { unique })
370 }
371}
372
373impl Extensions {
374 fn find_by_type(&self, extension_type: ExtensionType) -> Option<&Extension> {
375 self.unique
376 .iter()
377 .find(|ext| ext.extension_type() == extension_type)
378 }
379
380 pub fn application_id(&self) -> Option<&ApplicationIdExtension> {
382 self.find_by_type(ExtensionType::ApplicationId)
383 .and_then(|e| match e {
384 Extension::ApplicationId(e) => Some(e),
385 _ => None,
386 })
387 }
388
389 pub fn ratchet_tree(&self) -> Option<&RatchetTreeExtension> {
391 self.find_by_type(ExtensionType::RatchetTree)
392 .and_then(|e| match e {
393 Extension::RatchetTree(e) => Some(e),
394 _ => None,
395 })
396 }
397
398 pub fn required_capabilities(&self) -> Option<&RequiredCapabilitiesExtension> {
401 self.find_by_type(ExtensionType::RequiredCapabilities)
402 .and_then(|e| match e {
403 Extension::RequiredCapabilities(e) => Some(e),
404 _ => None,
405 })
406 }
407
408 pub fn external_pub(&self) -> Option<&ExternalPubExtension> {
410 self.find_by_type(ExtensionType::ExternalPub)
411 .and_then(|e| match e {
412 Extension::ExternalPub(e) => Some(e),
413 _ => None,
414 })
415 }
416
417 pub fn external_senders(&self) -> Option<&ExternalSendersExtension> {
419 self.find_by_type(ExtensionType::ExternalSenders)
420 .and_then(|e| match e {
421 Extension::ExternalSenders(e) => Some(e),
422 _ => None,
423 })
424 }
425
426 pub fn unknown(&self, extension_type_id: u16) -> Option<&UnknownExtension> {
428 let extension_type: ExtensionType = extension_type_id.into();
429
430 match extension_type {
431 ExtensionType::Unknown(_) => self.find_by_type(extension_type).and_then(|e| match e {
432 Extension::Unknown(_, e) => Some(e),
433 _ => None,
434 }),
435 _ => None,
436 }
437 }
438}
439
440impl Extension {
441 pub fn as_application_id_extension(&self) -> Result<&ApplicationIdExtension, ExtensionError> {
445 match self {
446 Self::ApplicationId(e) => Ok(e),
447 _ => Err(ExtensionError::InvalidExtensionType(
448 "This is not an ApplicationIdExtension".into(),
449 )),
450 }
451 }
452
453 pub fn as_ratchet_tree_extension(&self) -> Result<&RatchetTreeExtension, ExtensionError> {
457 match self {
458 Self::RatchetTree(rte) => Ok(rte),
459 _ => Err(ExtensionError::InvalidExtensionType(
460 "This is not a RatchetTreeExtension".into(),
461 )),
462 }
463 }
464
465 pub fn as_required_capabilities_extension(
469 &self,
470 ) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
471 match self {
472 Self::RequiredCapabilities(e) => Ok(e),
473 _ => Err(ExtensionError::InvalidExtensionType(
474 "This is not a RequiredCapabilitiesExtension".into(),
475 )),
476 }
477 }
478
479 pub fn as_external_pub_extension(&self) -> Result<&ExternalPubExtension, ExtensionError> {
483 match self {
484 Self::ExternalPub(e) => Ok(e),
485 _ => Err(ExtensionError::InvalidExtensionType(
486 "This is not an ExternalPubExtension".into(),
487 )),
488 }
489 }
490
491 pub fn as_external_senders_extension(
495 &self,
496 ) -> Result<&ExternalSendersExtension, ExtensionError> {
497 match self {
498 Self::ExternalSenders(e) => Ok(e),
499 _ => Err(ExtensionError::InvalidExtensionType(
500 "This is not an ExternalSendersExtension".into(),
501 )),
502 }
503 }
504
505 #[inline]
507 pub const fn extension_type(&self) -> ExtensionType {
508 match self {
509 Extension::ApplicationId(_) => ExtensionType::ApplicationId,
510 Extension::RatchetTree(_) => ExtensionType::RatchetTree,
511 Extension::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
512 Extension::ExternalPub(_) => ExtensionType::ExternalPub,
513 Extension::ExternalSenders(_) => ExtensionType::ExternalSenders,
514 Extension::LastResort(_) => ExtensionType::LastResort,
515 Extension::Unknown(kind, _) => ExtensionType::Unknown(*kind),
516 }
517 }
518}
519
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use tls_codec::{Deserialize, Serialize, VLBytes};

    use crate::{ciphersuite::HpkePublicKey, extensions::*};

    /// `add` accepts the first extension of a type and rejects a second one.
    #[test]
    fn add() {
        let mut extensions = Extensions::default();
        extensions
            .add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default(),
            ))
            .unwrap();
        assert!(extensions
            .add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default()
            ))
            .is_err());
    }

    /// `add` and `TryFrom<Vec<Extension>>` agree on which inputs contain duplicates.
    #[test]
    fn add_try_from() {
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        let tests = [
            (vec![], true),
            (vec![ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_y.clone()], true),
            (vec![ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_x.clone(), ext_y.clone()], true),
            (vec![ext_y.clone(), ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_y.clone(), ext_x, ext_y], false),
        ];

        for (test, should_work) in tests.into_iter() {
            // Building incrementally via `add` must fail exactly when the list
            // contains a duplicate type.
            {
                let mut extensions = Extensions::default();

                let mut works = true;
                for ext in test.iter() {
                    match extensions.add(ext.clone()) {
                        Ok(_) => {}
                        Err(InvalidExtensionError::Duplicate) => {
                            works = false;
                        }
                        _ => panic!("This should have never happened."),
                    }
                }

                // `{:?}` only borrows, so no clone of `test` is needed here.
                println!("{:?}, {:?}", test, should_work);
                assert_eq!(works, should_work);
            }

            // Converting the whole list at once must agree with `add`.
            if should_work {
                assert!(Extensions::try_from(test).is_ok());
            } else {
                assert!(Extensions::try_from(test).is_err());
            }
        }
    }

    /// Every permutation of distinct extensions survives a TLS round trip with
    /// its order intact.
    #[test]
    fn ensure_ordering() {
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::ExternalPub(ExternalPubExtension::new(HpkePublicKey::new(vec![])));
        let ext_z = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        // `permutations` is already an iterator; no need to collect it first.
        for candidate in [ext_x, ext_y, ext_z].into_iter().permutations(3) {
            let candidate: Extensions = Extensions::try_from(candidate).unwrap();
            let bytes = candidate.tls_serialize_detached().unwrap();
            let got = Extensions::tls_deserialize(&mut bytes.as_slice()).unwrap();
            assert_eq!(candidate, got);
        }
    }

    /// Unknown extension types decode to `Extension::Unknown` and re-encode to
    /// the identical bytes.
    #[test]
    fn that_unknown_extensions_are_de_serialized_correctly() {
        let extension_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF100, 0xFFFF];
        let extension_datas = [vec![], vec![0], vec![1, 2, 3]];

        for extension_type in extension_types.into_iter() {
            for extension_data in extension_datas.iter() {
                // Wire form: 2-byte big-endian type id followed by the
                // variable-length-encoded payload.
                let test = {
                    let mut buf = extension_type.to_be_bytes().to_vec();
                    buf.append(
                        &mut VLBytes::new(extension_data.clone())
                            .tls_serialize_detached()
                            .unwrap(),
                    );
                    buf
                };

                let got = Extension::tls_deserialize_exact(&test).unwrap();

                match got {
                    Extension::Unknown(got_extension_type, ref got_extension_data) => {
                        assert_eq!(extension_type, got_extension_type);
                        assert_eq!(extension_data, &got_extension_data.0);
                    }
                    other => panic!("Expected `Extension::Unknown`, got {:?}", other),
                }

                let got_serialized = got.tls_serialize_detached().unwrap();
                assert_eq!(test, got_serialized);
            }
        }
    }
}