1use std::{
25 convert::Infallible,
26 fmt::Debug,
27 io::{Read, Write},
28 marker::PhantomData,
29};
30
31use serde::{Deserialize, Serialize};
32
33#[cfg(feature = "extensions-draft-08")]
35mod app_data_dict_extension;
36mod application_id_extension;
37mod codec;
38mod external_pub_extension;
39mod external_sender_extension;
40mod last_resort;
41mod ratchet_tree_extension;
42mod required_capabilities;
43use errors::*;
44
45pub mod errors;
47
48#[cfg(feature = "extensions-draft-08")]
50pub use app_data_dict_extension::{AppDataDictionary, AppDataDictionaryExtension};
51pub use application_id_extension::ApplicationIdExtension;
52pub use external_pub_extension::ExternalPubExtension;
53pub use external_sender_extension::{
54 ExternalSender, ExternalSendersExtension, SenderExtensionIndex,
55};
56pub use last_resort::LastResortExtension;
57pub use ratchet_tree_extension::RatchetTreeExtension;
58pub use required_capabilities::RequiredCapabilitiesExtension;
59
60use tls_codec::{
61 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
62 Size, TlsDeserialize, TlsSerialize, TlsSize,
63};
64
65use crate::{
66 group::GroupContext, key_packages::KeyPackage, messages::group_info::GroupInfo,
67 treesync::LeafNode,
68};
69
70#[cfg(test)]
71mod tests;
72
/// The type of an MLS extension, i.e. the 16-bit `extension_type` value on the wire.
///
/// The wire value of every variant is defined by the `From<u16>` and
/// `From<ExtensionType> for u16` conversions in this module.
///
/// Note: the variant declaration order determines the derived `PartialOrd`/`Ord`
/// ordering, so reordering variants changes how extension types compare.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)]
pub enum ExtensionType {
    /// Application id extension (wire value 1).
    ApplicationId,

    /// Ratchet tree extension (wire value 2).
    RatchetTree,

    /// Required capabilities extension (wire value 3).
    RequiredCapabilities,

    /// External public key extension (wire value 4).
    ExternalPub,

    /// External senders extension (wire value 5).
    ExternalSenders,

    /// Last resort extension (wire value 10).
    LastResort,

    /// Application data dictionary extension (wire value 6); only present with the
    /// `extensions-draft-08` feature.
    #[cfg(feature = "extensions-draft-08")]
    AppDataDictionary,

    /// A raw value for which `crate::grease::is_grease_value` returns `true`.
    Grease(u16),

    /// Any other raw value not recognised by this module.
    Unknown(u16),
}
124
125impl ExtensionType {
126 pub(crate) fn is_default(self) -> bool {
128 match self {
129 ExtensionType::ApplicationId
130 | ExtensionType::RatchetTree
131 | ExtensionType::RequiredCapabilities
132 | ExtensionType::ExternalPub
133 | ExtensionType::ExternalSenders => true,
134 ExtensionType::LastResort | ExtensionType::Grease(_) | ExtensionType::Unknown(_) => {
135 false
136 }
137 #[cfg(feature = "extensions-draft-08")]
138 ExtensionType::AppDataDictionary => false,
139 }
140 }
141
142 pub(crate) fn is_valid_in_leaf_node(self) -> bool {
147 match self {
148 ExtensionType::Grease(_)
149 | ExtensionType::LastResort
150 | ExtensionType::RatchetTree
151 | ExtensionType::RequiredCapabilities
152 | ExtensionType::ExternalPub
153 | ExtensionType::ExternalSenders => false,
154 ExtensionType::Unknown(_) | ExtensionType::ApplicationId => true,
155 #[cfg(feature = "extensions-draft-08")]
156 ExtensionType::AppDataDictionary => true,
157 }
158 }
159 pub(crate) fn is_valid_in_group_info(self) -> Option<bool> {
160 match self {
161 ExtensionType::Grease(_)
162 | ExtensionType::LastResort
163 | ExtensionType::RequiredCapabilities
164 | ExtensionType::ExternalSenders
165 | ExtensionType::ApplicationId => Some(false),
166 ExtensionType::RatchetTree | ExtensionType::ExternalPub => Some(true),
167 ExtensionType::Unknown(_) => None,
168 #[cfg(feature = "extensions-draft-08")]
169 ExtensionType::AppDataDictionary => Some(true),
170 }
171 }
172
173 pub(crate) fn is_valid_in_group_context(self) -> bool {
174 match self {
175 ExtensionType::RequiredCapabilities
176 | ExtensionType::ExternalSenders
177 | ExtensionType::Unknown(_) => true,
178 #[cfg(feature = "extensions-draft-08")]
179 ExtensionType::AppDataDictionary => true,
180 _ => false,
181 }
182 }
183
184 pub fn is_grease(&self) -> bool {
189 matches!(self, ExtensionType::Grease(_))
190 }
191}
192
193impl Size for ExtensionType {
194 fn tls_serialized_len(&self) -> usize {
195 2
196 }
197}
198
199impl TlsDeserializeTrait for ExtensionType {
200 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
201 where
202 Self: Sized,
203 {
204 let mut extension_type = [0u8; 2];
205 bytes.read_exact(&mut extension_type)?;
206
207 Ok(ExtensionType::from(u16::from_be_bytes(extension_type)))
208 }
209}
210
211impl DeserializeBytes for ExtensionType {
212 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
213 where
214 Self: Sized,
215 {
216 let mut bytes_ref = bytes;
217 let extension_type = ExtensionType::tls_deserialize(&mut bytes_ref)?;
218 let remainder = &bytes[extension_type.tls_serialized_len()..];
219 Ok((extension_type, remainder))
220 }
221}
222
223impl TlsSerializeTrait for ExtensionType {
224 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
225 writer.write_all(&u16::from(*self).to_be_bytes())?;
226
227 Ok(2)
228 }
229}
230
231impl From<u16> for ExtensionType {
232 fn from(a: u16) -> Self {
233 match a {
234 1 => ExtensionType::ApplicationId,
235 2 => ExtensionType::RatchetTree,
236 3 => ExtensionType::RequiredCapabilities,
237 4 => ExtensionType::ExternalPub,
238 5 => ExtensionType::ExternalSenders,
239 #[cfg(feature = "extensions-draft-08")]
240 6 => ExtensionType::AppDataDictionary,
241 10 => ExtensionType::LastResort,
242 unknown if crate::grease::is_grease_value(unknown) => ExtensionType::Grease(unknown),
243 unknown => ExtensionType::Unknown(unknown),
244 }
245 }
246}
247
248impl From<ExtensionType> for u16 {
249 fn from(value: ExtensionType) -> Self {
250 match value {
251 ExtensionType::ApplicationId => 1,
252 ExtensionType::RatchetTree => 2,
253 ExtensionType::RequiredCapabilities => 3,
254 ExtensionType::ExternalPub => 4,
255 ExtensionType::ExternalSenders => 5,
256 #[cfg(feature = "extensions-draft-08")]
257 ExtensionType::AppDataDictionary => 6,
258 ExtensionType::LastResort => 10,
259 ExtensionType::Grease(value) => value,
260 ExtensionType::Unknown(unknown) => unknown,
261 }
262 }
263}
264
/// A single MLS extension together with its decoded payload.
///
/// Each recognised extension type has a dedicated variant; any other type is kept as
/// [`Extension::Unknown`] with its raw type id and opaque payload. Use
/// [`Extension::extension_type`] to get the corresponding [`ExtensionType`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Extension {
    /// An [`ApplicationIdExtension`].
    ApplicationId(ApplicationIdExtension),

    /// A [`RatchetTreeExtension`].
    RatchetTree(RatchetTreeExtension),

    /// A [`RequiredCapabilitiesExtension`].
    RequiredCapabilities(RequiredCapabilitiesExtension),

    /// An [`ExternalPubExtension`].
    ExternalPub(ExternalPubExtension),

    /// An [`ExternalSendersExtension`].
    ExternalSenders(ExternalSendersExtension),

    /// An [`AppDataDictionaryExtension`] (only with the `extensions-draft-08` feature).
    #[cfg(feature = "extensions-draft-08")]
    AppDataDictionary(AppDataDictionaryExtension),

    /// A [`LastResortExtension`].
    LastResort(LastResortExtension),

    /// An extension of a type without a dedicated variant: the raw type id and the
    /// opaque payload.
    Unknown(u16, UnknownExtension),
}
306
/// The opaque payload bytes of an extension whose type has no dedicated
/// [`Extension`] variant.
#[derive(
    PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSize, TlsSerialize, TlsDeserialize,
)]
pub struct UnknownExtension(pub Vec<u8>);
312
/// An ordered list of extensions in which each [`ExtensionType`] occurs at most once.
///
/// `T` names the object the list belongs to. Its [`ExtensionValidator`] impl decides
/// which extension types the checked constructors and `add` accept. Insertion order is
/// preserved and is the order used for TLS serialization.
///
/// NOTE(review): the derived serde `Deserialize` fills `unique` directly, without the
/// uniqueness and type checks of the checked constructors. Confirm this is intended.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Extensions<T> {
    unique: Vec<Extension>,
    // Type-level tag only; not serialized.
    #[serde(skip)]
    _object: core::marker::PhantomData<T>,
}
320
/// Marker type for an extension list that is not tied to a particular MLS object.
/// Its [`ExtensionValidator`] impl accepts every extension type.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, TlsSize, TlsSerialize, TlsDeserialize)]
pub struct AnyObject;
324
325impl<T> Default for Extensions<T> {
326 fn default() -> Self {
327 Self {
328 unique: vec![],
329 _object: PhantomData,
330 }
331 }
332}
333
334impl<T> Size for Extensions<T> {
335 fn tls_serialized_len(&self) -> usize {
336 Vec::tls_serialized_len(&self.unique)
337 }
338}
339
340impl<T> TlsSerializeTrait for Extensions<T> {
341 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
342 self.unique.tls_serialize(writer)
343 }
344}
345
346impl<T: ExtensionValidator> TlsDeserializeTrait for Extensions<T>
347where
348 InvalidExtensionError: From<T::Error>,
349{
350 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
351 where
352 Self: Sized,
353 {
354 let candidate: Vec<Extension> = Vec::tls_deserialize(bytes)?;
355 Extensions::<T>::try_from(candidate)
356 .map_err(|_| Error::DecodingError("Found duplicate extensions".into()))
357 }
358}
359
360impl<T: ExtensionValidator> DeserializeBytes for Extensions<T>
361where
362 InvalidExtensionError: From<T::Error>,
363{
364 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
365 where
366 Self: Sized,
367 {
368 let mut bytes_ref = bytes;
369 let extensions = Extensions::<T>::tls_deserialize(&mut bytes_ref)?;
370 let remainder = &bytes[extensions.tls_serialized_len()..];
371 Ok((extensions, remainder))
372 }
373}
374
375impl<T: ExtensionValidator> Extensions<T> {
376 pub fn empty() -> Self {
378 Self {
379 unique: vec![],
380 _object: PhantomData,
381 }
382 }
383
384 pub fn iter(&self) -> impl Iterator<Item = &Extension> {
386 self.unique.iter()
387 }
388
389 pub fn remove(&mut self, extension_type: ExtensionType) -> Option<Extension> {
394 if let Some(pos) = self
395 .unique
396 .iter()
397 .position(|ext| ext.extension_type() == extension_type)
398 {
399 Some(self.unique.remove(pos))
400 } else {
401 None
402 }
403 }
404
405 pub fn contains(&self, extension_type: ExtensionType) -> bool {
408 self.unique
409 .iter()
410 .any(|ext| ext.extension_type() == extension_type)
411 }
412}
413
414impl<T> Extensions<T>
415where
416 T: ExtensionValidator,
417 InvalidExtensionError: From<T::Error>,
418{
419 pub fn single(extension: Extension) -> Result<Self, InvalidExtensionError> {
421 T::validate_extension_type(&extension)?;
422 Ok(Self {
423 unique: vec![extension],
424 _object: PhantomData,
425 })
426 }
427
428 pub fn from_vec(extensions: Vec<Extension>) -> Result<Self, InvalidExtensionError> {
433 extensions.try_into()
434 }
435
436 pub fn validate<'a>(
438 extensions: impl Iterator<Item = &'a Extension>,
439 ) -> Result<(), InvalidExtensionError> {
440 for ext in extensions {
441 T::validate_extension_type(ext)?;
442 }
443 Ok(())
444 }
445
446 pub fn add(&mut self, extension: Extension) -> Result<(), InvalidExtensionError> {
451 T::validate_extension_type(&extension)?;
452 if self.contains(extension.extension_type()) {
453 return Err(InvalidExtensionError::Duplicate);
454 }
455
456 self.unique.push(extension);
457
458 Ok(())
459 }
460
461 pub fn add_or_replace(
465 &mut self,
466 extension: Extension,
467 ) -> Result<Option<Extension>, InvalidExtensionError> {
468 T::validate_extension_type(&extension)?;
469 let replaced = self.remove(extension.extension_type());
470 self.unique.push(extension);
471 Ok(replaced)
472 }
473}
474
/// Rules deciding which extension types may appear in the extension list of the
/// implementing MLS object type.
pub trait ExtensionValidator {
    /// Error returned when an extension type is not allowed.
    type Error;

    /// Returns `Ok(())` if `ext` may appear in an extension list belonging to `Self`.
    fn validate_extension_type(ext: &Extension) -> Result<(), Self::Error>;
}
483
484impl ExtensionValidator for AnyObject {
485 type Error = Infallible;
486
487 fn validate_extension_type(_ext: &Extension) -> Result<(), Infallible> {
488 Ok(())
489 }
490}
491
492impl<T: ExtensionValidator> TryFrom<Vec<Extension>> for Extensions<T>
493where
494 InvalidExtensionError: From<T::Error>,
495{
496 type Error = InvalidExtensionError;
497
498 fn try_from(candidate: Vec<Extension>) -> Result<Self, Self::Error> {
499 let mut unique: Vec<Extension> = Vec::new();
500 for extension in candidate.into_iter() {
501 T::validate_extension_type(&extension)?;
502
503 if unique
504 .iter()
505 .any(|ext| ext.extension_type() == extension.extension_type())
506 {
507 return Err(InvalidExtensionError::Duplicate);
508 } else {
509 unique.push(extension);
510 }
511 }
512
513 Ok(Self {
514 unique,
515 _object: PhantomData,
516 })
517 }
518}
519
520impl ExtensionValidator for GroupInfo {
522 type Error = ExtensionTypeNotValidInGroupInfoError;
523
524 fn validate_extension_type(
525 ext: &Extension,
526 ) -> Result<(), ExtensionTypeNotValidInGroupInfoError> {
527 if ext.extension_type().is_valid_in_group_info() == Some(true)
528 || ext.extension_type().is_valid_in_group_info().is_none()
529 {
530 Ok(())
531 } else {
532 Err(ExtensionTypeNotValidInGroupInfoError(ext.extension_type()))
533 }
534 }
535}
536
537impl ExtensionValidator for GroupContext {
539 type Error = ExtensionTypeNotValidInGroupContextError;
540
541 fn validate_extension_type(
542 ext: &Extension,
543 ) -> Result<(), ExtensionTypeNotValidInGroupContextError> {
544 if ext.extension_type().is_valid_in_group_context() {
545 Ok(())
546 } else {
547 Err(ExtensionTypeNotValidInGroupContextError(
548 ext.extension_type(),
549 ))
550 }
551 }
552}
553
554impl ExtensionValidator for KeyPackage {
556 type Error = ExtensionTypeNotValidInKeyPackageError;
557
558 fn validate_extension_type(
559 ext: &Extension,
560 ) -> Result<(), ExtensionTypeNotValidInKeyPackageError> {
561 if ext.extension_type() == ExtensionType::LastResort
562 || matches!(ext.extension_type(), ExtensionType::Unknown(_))
563 {
564 Ok(())
565 } else {
566 Err(ExtensionTypeNotValidInKeyPackageError(ext.extension_type()))
567 }
568 }
569}
570
571impl ExtensionValidator for LeafNode {
573 type Error = ExtensionTypeNotValidInLeafNodeError;
574
575 fn validate_extension_type(
576 ext: &Extension,
577 ) -> Result<(), ExtensionTypeNotValidInLeafNodeError> {
578 if ext.extension_type().is_valid_in_leaf_node() {
579 Ok(())
580 } else {
581 Err(ExtensionTypeNotValidInLeafNodeError(ext.extension_type()))
582 }
583 }
584}
585
586impl<T> Extensions<T> {
587 fn find_by_type(&self, extension_type: ExtensionType) -> Option<&Extension> {
588 self.unique
589 .iter()
590 .find(|ext| ext.extension_type() == extension_type)
591 }
592
593 pub fn application_id(&self) -> Option<&ApplicationIdExtension> {
595 self.find_by_type(ExtensionType::ApplicationId)
596 .and_then(|e| match e {
597 Extension::ApplicationId(e) => Some(e),
598 _ => None,
599 })
600 }
601
602 pub fn ratchet_tree(&self) -> Option<&RatchetTreeExtension> {
604 self.find_by_type(ExtensionType::RatchetTree)
605 .and_then(|e| match e {
606 Extension::RatchetTree(e) => Some(e),
607 _ => None,
608 })
609 }
610
611 pub fn required_capabilities(&self) -> Option<&RequiredCapabilitiesExtension> {
614 self.find_by_type(ExtensionType::RequiredCapabilities)
615 .and_then(|e| match e {
616 Extension::RequiredCapabilities(e) => Some(e),
617 _ => None,
618 })
619 }
620
621 pub fn external_pub(&self) -> Option<&ExternalPubExtension> {
623 self.find_by_type(ExtensionType::ExternalPub)
624 .and_then(|e| match e {
625 Extension::ExternalPub(e) => Some(e),
626 _ => None,
627 })
628 }
629
630 pub fn external_senders(&self) -> Option<&ExternalSendersExtension> {
632 self.find_by_type(ExtensionType::ExternalSenders)
633 .and_then(|e| match e {
634 Extension::ExternalSenders(e) => Some(e),
635 _ => None,
636 })
637 }
638
639 #[cfg(feature = "extensions-draft-08")]
640 pub fn app_data_dictionary(&self) -> Option<&AppDataDictionaryExtension> {
642 self.find_by_type(ExtensionType::AppDataDictionary)
643 .and_then(|e| match e {
644 Extension::AppDataDictionary(e) => Some(e),
645 _ => None,
646 })
647 }
648
649 pub fn unknown(&self, extension_type_id: u16) -> Option<&UnknownExtension> {
651 let extension_type: ExtensionType = extension_type_id.into();
652
653 match extension_type {
654 ExtensionType::Unknown(_) => self.find_by_type(extension_type).and_then(|e| match e {
655 Extension::Unknown(_, e) => Some(e),
656 _ => None,
657 }),
658 _ => None,
659 }
660 }
661}
662
663impl Extension {
664 pub fn as_application_id_extension(&self) -> Result<&ApplicationIdExtension, ExtensionError> {
668 match self {
669 Self::ApplicationId(e) => Ok(e),
670 _ => Err(ExtensionError::InvalidExtensionType(
671 "This is not an ApplicationIdExtension".into(),
672 )),
673 }
674 }
675 #[cfg(feature = "extensions-draft-08")]
676 pub fn as_app_data_dictionary_extension(
680 &self,
681 ) -> Result<&AppDataDictionaryExtension, ExtensionError> {
682 match self {
683 Self::AppDataDictionary(e) => Ok(e),
684 _ => Err(ExtensionError::InvalidExtensionType(
685 "This is not an AppDataDictionaryExtension".into(),
686 )),
687 }
688 }
689
690 pub fn as_ratchet_tree_extension(&self) -> Result<&RatchetTreeExtension, ExtensionError> {
694 match self {
695 Self::RatchetTree(rte) => Ok(rte),
696 _ => Err(ExtensionError::InvalidExtensionType(
697 "This is not a RatchetTreeExtension".into(),
698 )),
699 }
700 }
701
702 pub fn as_required_capabilities_extension(
706 &self,
707 ) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
708 match self {
709 Self::RequiredCapabilities(e) => Ok(e),
710 _ => Err(ExtensionError::InvalidExtensionType(
711 "This is not a RequiredCapabilitiesExtension".into(),
712 )),
713 }
714 }
715
716 pub fn as_external_pub_extension(&self) -> Result<&ExternalPubExtension, ExtensionError> {
720 match self {
721 Self::ExternalPub(e) => Ok(e),
722 _ => Err(ExtensionError::InvalidExtensionType(
723 "This is not an ExternalPubExtension".into(),
724 )),
725 }
726 }
727
728 pub fn as_external_senders_extension(
732 &self,
733 ) -> Result<&ExternalSendersExtension, ExtensionError> {
734 match self {
735 Self::ExternalSenders(e) => Ok(e),
736 _ => Err(ExtensionError::InvalidExtensionType(
737 "This is not an ExternalSendersExtension".into(),
738 )),
739 }
740 }
741
742 #[inline]
744 pub const fn extension_type(&self) -> ExtensionType {
745 match self {
746 Extension::ApplicationId(_) => ExtensionType::ApplicationId,
747 Extension::RatchetTree(_) => ExtensionType::RatchetTree,
748 Extension::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
749 Extension::ExternalPub(_) => ExtensionType::ExternalPub,
750 Extension::ExternalSenders(_) => ExtensionType::ExternalSenders,
751 #[cfg(feature = "extensions-draft-08")]
752 Extension::AppDataDictionary(_) => ExtensionType::AppDataDictionary,
753 Extension::LastResort(_) => ExtensionType::LastResort,
754 Extension::Unknown(kind, _) => ExtensionType::Unknown(*kind),
755 }
756 }
757}
758
/// Generates conversions between `Extensions<AnyObject>` and `Extensions<$validator>`.
///
/// Converting to `AnyObject` is infallible. Converting back runs `$validator`'s
/// `validate_extension_type` on every extension and fails with `$error` on the first
/// rejected one.
macro_rules! impl_from_extensions_validator {
    ($validator:ty, $error:ty) => {
        impl From<Extensions<$validator>> for Extensions<AnyObject> {
            fn from(value: Extensions<$validator>) -> Self {
                let Extensions { unique, .. } = value;
                Self {
                    unique,
                    _object: PhantomData,
                }
            }
        }

        impl TryFrom<Extensions<AnyObject>> for Extensions<$validator> {
            type Error = $error;

            fn try_from(value: Extensions<AnyObject>) -> Result<Self, $error> {
                for extension in &value.unique {
                    <$validator as ExtensionValidator>::validate_extension_type(extension)?;
                }

                Ok(Self {
                    unique: value.unique,
                    _object: PhantomData,
                })
            }
        }
    };
}
787
// Conversions between `Extensions<AnyObject>` and the object-specific extension lists.
// NOTE(review): `GroupInfo` implements `ExtensionValidator` in this module but gets no
// conversion here; confirm this omission is intentional.
impl_from_extensions_validator!(GroupContext, ExtensionTypeNotValidInGroupContextError);
impl_from_extensions_validator!(LeafNode, ExtensionTypeNotValidInLeafNodeError);
impl_from_extensions_validator!(KeyPackage, ExtensionTypeNotValidInKeyPackageError);
791
#[cfg(any(feature = "test-utils", test))]
impl Extensions<AnyObject> {
    /// Re-tags this list as belonging to `T` **without** running `T`'s validator.
    /// Test-only helper for building lists that the checked constructors would reject.
    pub(crate) fn coerce<T: ExtensionValidator>(self) -> Extensions<T> {
        let Self { unique, .. } = self;
        Extensions {
            unique,
            _object: PhantomData,
        }
    }
}
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use tls_codec::{Deserialize, Serialize, VLBytes};

    use crate::{ciphersuite::HpkePublicKey, extensions::*};

    /// Adding a second extension of the same type must fail.
    #[test]
    fn add() {
        let mut extensions: Extensions<AnyObject> = Extensions::default();
        extensions
            .add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default(),
            ))
            .unwrap();
        assert!(extensions
            .add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default()
            ))
            .is_err());
    }

    /// `add` and `TryFrom<Vec<Extension>>` must agree on which lists contain duplicates.
    #[test]
    fn add_try_from() {
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        let tests = [
            (vec![], true),
            (vec![ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_y.clone()], true),
            (vec![ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_x.clone(), ext_y.clone()], true),
            (vec![ext_y.clone(), ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_y.clone(), ext_x, ext_y], false),
        ];

        for (test, should_work) in tests.into_iter() {
            // Build the list incrementally via `add`.
            {
                let mut extensions: Extensions<AnyObject> = Extensions::default();

                let mut works = true;
                for ext in test.iter() {
                    match extensions.add(ext.clone()) {
                        Ok(_) => {}
                        Err(InvalidExtensionError::Duplicate) => {
                            works = false;
                        }
                        _ => panic!("This should have never happened."),
                    }
                }

                assert_eq!(works, should_work, "unexpected `add` result for {test:?}");
            }

            // Build the list in one go via `TryFrom`.
            if should_work {
                assert!(Extensions::<AnyObject>::try_from(test).is_ok());
            } else {
                assert!(Extensions::<AnyObject>::try_from(test).is_err());
            }
        }
    }

    /// Serialization must round-trip and preserve extension order for every permutation.
    #[test]
    fn ensure_ordering() {
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::ExternalPub(ExternalPubExtension::new(HpkePublicKey::new(vec![])));
        let ext_z = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        for candidate in [ext_x, ext_y, ext_z].into_iter().permutations(3) {
            let candidate: Extensions<AnyObject> = Extensions::try_from(candidate).unwrap();
            let bytes = candidate.tls_serialize_detached().unwrap();
            let got = Extensions::tls_deserialize(&mut bytes.as_slice()).unwrap();
            assert_eq!(candidate, got);
        }
    }

    /// Unrecognised extension types must decode to `Extension::Unknown` and re-encode
    /// to the exact same bytes.
    #[test]
    fn that_unknown_extensions_are_de_serialized_correctly() {
        let extension_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF100, 0xFFFF];
        let extension_datas = [vec![], vec![0], vec![1, 2, 3]];

        for extension_type in extension_types.into_iter() {
            for extension_data in extension_datas.iter() {
                // Hand-encode: big-endian type id followed by the length-prefixed payload.
                let test = {
                    let mut buf = extension_type.to_be_bytes().to_vec();
                    buf.append(
                        &mut VLBytes::new(extension_data.clone())
                            .tls_serialize_detached()
                            .unwrap(),
                    );
                    buf
                };

                // Decode and check that the raw type id and payload are preserved.
                let got = Extension::tls_deserialize_exact(&test).unwrap();

                match got {
                    Extension::Unknown(got_extension_type, ref got_extension_data) => {
                        assert_eq!(extension_type, got_extension_type);
                        assert_eq!(extension_data, &got_extension_data.0);
                    }
                    other => panic!("Expected `Extension::Unknown`, got {other:?}"),
                }

                // Re-encode and compare byte-for-byte with the input.
                let got_serialized = got.tls_serialize_detached().unwrap();
                assert_eq!(test, got_serialized);
            }
        }
    }
}