1use std::{
25 convert::Infallible,
26 fmt::Debug,
27 io::{Read, Write},
28 marker::PhantomData,
29};
30
31use serde::{Deserialize, Serialize};
32
33#[cfg(feature = "extensions-draft-08")]
35mod app_data_dict_extension;
36mod application_id_extension;
37mod codec;
38mod external_pub_extension;
39mod external_sender_extension;
40mod last_resort;
41mod ratchet_tree_extension;
42mod required_capabilities;
43use errors::*;
44
45pub mod errors;
47
48#[cfg(feature = "extensions-draft-08")]
50pub use app_data_dict_extension::{AppDataDictionary, AppDataDictionaryExtension};
51pub use application_id_extension::ApplicationIdExtension;
52pub use external_pub_extension::ExternalPubExtension;
53pub use external_sender_extension::{
54 ExternalSender, ExternalSendersExtension, SenderExtensionIndex,
55};
56pub use last_resort::LastResortExtension;
57pub use ratchet_tree_extension::RatchetTreeExtension;
58pub use required_capabilities::RequiredCapabilitiesExtension;
59
60use tls_codec::{
61 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
62 Size, TlsDeserialize, TlsSerialize, TlsSize,
63};
64
65use crate::{
66 group::GroupContext, key_packages::KeyPackage, messages::group_info::GroupInfo,
67 treesync::LeafNode,
68};
69
70#[cfg(test)]
71mod tests;
72
/// The type of an MLS extension, i.e. the two-byte `extension_type` value.
///
/// Known code points have dedicated variants. Other values become
/// [`ExtensionType::Grease`] when `crate::grease::is_grease_value` accepts them,
/// and [`ExtensionType::Unknown`] otherwise. The numeric mapping is defined by the
/// `From<u16>` and `From<ExtensionType> for u16` conversions in this module.
///
/// NOTE(review): the derived `Ord`/`PartialOrd` (and possibly the serde
/// representation) depend on variant declaration order — do not reorder.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)]
pub enum ExtensionType {
    /// Application identifier extension (wire value `1`).
    ApplicationId,

    /// Ratchet tree extension (wire value `2`).
    RatchetTree,

    /// Required capabilities extension (wire value `3`).
    RequiredCapabilities,

    /// External public key extension (wire value `4`).
    ExternalPub,

    /// External senders extension (wire value `5`).
    ExternalSenders,

    /// Last resort extension (wire value `10`).
    LastResort,

    /// Application data dictionary extension (wire value `6`); only present
    /// with the `extensions-draft-08` feature.
    #[cfg(feature = "extensions-draft-08")]
    AppDataDictionary,

    /// A GREASE value; the raw wire value is kept.
    Grease(u16),

    /// Any other extension type; the raw wire value is kept.
    Unknown(u16),
}
124
125impl ExtensionType {
126 pub(crate) fn is_default(self) -> bool {
128 match self {
129 ExtensionType::ApplicationId
130 | ExtensionType::RatchetTree
131 | ExtensionType::RequiredCapabilities
132 | ExtensionType::ExternalPub
133 | ExtensionType::ExternalSenders => true,
134 ExtensionType::LastResort | ExtensionType::Grease(_) | ExtensionType::Unknown(_) => {
135 false
136 }
137 #[cfg(feature = "extensions-draft-08")]
138 ExtensionType::AppDataDictionary => false,
139 }
140 }
141
142 pub(crate) fn is_valid_in_leaf_node(self) -> bool {
147 match self {
148 ExtensionType::Grease(_)
149 | ExtensionType::LastResort
150 | ExtensionType::RatchetTree
151 | ExtensionType::RequiredCapabilities
152 | ExtensionType::ExternalPub
153 | ExtensionType::ExternalSenders => false,
154 ExtensionType::Unknown(_) | ExtensionType::ApplicationId => true,
155 #[cfg(feature = "extensions-draft-08")]
156 ExtensionType::AppDataDictionary => true,
157 }
158 }
159 pub(crate) fn is_valid_in_group_info(self) -> Option<bool> {
160 match self {
161 ExtensionType::Grease(_)
162 | ExtensionType::LastResort
163 | ExtensionType::RequiredCapabilities
164 | ExtensionType::ExternalSenders
165 | ExtensionType::ApplicationId => Some(false),
166 ExtensionType::RatchetTree | ExtensionType::ExternalPub => Some(true),
167 ExtensionType::Unknown(_) => None,
168 #[cfg(feature = "extensions-draft-08")]
169 ExtensionType::AppDataDictionary => Some(true),
170 }
171 }
172
173 pub(crate) fn is_valid_in_key_package(self) -> bool {
174 match self {
175 ExtensionType::Grease(_)
176 | ExtensionType::RatchetTree
177 | ExtensionType::RequiredCapabilities
178 | ExtensionType::ExternalPub
179 | ExtensionType::ExternalSenders
180 | ExtensionType::ApplicationId => false,
181 ExtensionType::Unknown(_) | ExtensionType::LastResort => true,
182 #[cfg(feature = "extensions-draft-08")]
183 ExtensionType::AppDataDictionary => true,
184 }
185 }
186
187 pub(crate) fn is_valid_in_group_context(self) -> bool {
188 match self {
189 ExtensionType::RequiredCapabilities
190 | ExtensionType::ExternalSenders
191 | ExtensionType::Unknown(_) => true,
192 #[cfg(feature = "extensions-draft-08")]
193 ExtensionType::AppDataDictionary => true,
194 _ => false,
195 }
196 }
197
198 pub fn is_grease(&self) -> bool {
203 matches!(self, ExtensionType::Grease(_))
204 }
205}
206
207impl Size for ExtensionType {
208 fn tls_serialized_len(&self) -> usize {
209 2
210 }
211}
212
213impl TlsDeserializeTrait for ExtensionType {
214 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
215 where
216 Self: Sized,
217 {
218 let mut extension_type = [0u8; 2];
219 bytes.read_exact(&mut extension_type)?;
220
221 Ok(ExtensionType::from(u16::from_be_bytes(extension_type)))
222 }
223}
224
225impl DeserializeBytes for ExtensionType {
226 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
227 where
228 Self: Sized,
229 {
230 let mut bytes_ref = bytes;
231 let extension_type = ExtensionType::tls_deserialize(&mut bytes_ref)?;
232 let remainder = &bytes[extension_type.tls_serialized_len()..];
233 Ok((extension_type, remainder))
234 }
235}
236
237impl TlsSerializeTrait for ExtensionType {
238 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
239 writer.write_all(&u16::from(*self).to_be_bytes())?;
240
241 Ok(2)
242 }
243}
244
245impl From<u16> for ExtensionType {
246 fn from(a: u16) -> Self {
247 match a {
248 1 => ExtensionType::ApplicationId,
249 2 => ExtensionType::RatchetTree,
250 3 => ExtensionType::RequiredCapabilities,
251 4 => ExtensionType::ExternalPub,
252 5 => ExtensionType::ExternalSenders,
253 #[cfg(feature = "extensions-draft-08")]
254 6 => ExtensionType::AppDataDictionary,
255 10 => ExtensionType::LastResort,
256 unknown if crate::grease::is_grease_value(unknown) => ExtensionType::Grease(unknown),
257 unknown => ExtensionType::Unknown(unknown),
258 }
259 }
260}
261
262impl From<ExtensionType> for u16 {
263 fn from(value: ExtensionType) -> Self {
264 match value {
265 ExtensionType::ApplicationId => 1,
266 ExtensionType::RatchetTree => 2,
267 ExtensionType::RequiredCapabilities => 3,
268 ExtensionType::ExternalPub => 4,
269 ExtensionType::ExternalSenders => 5,
270 #[cfg(feature = "extensions-draft-08")]
271 ExtensionType::AppDataDictionary => 6,
272 ExtensionType::LastResort => 10,
273 ExtensionType::Grease(value) => value,
274 ExtensionType::Unknown(unknown) => unknown,
275 }
276 }
277}
278
/// A single MLS extension together with its decoded payload.
///
/// Each known extension type has a dedicated variant. Any other extension is
/// kept as [`Extension::Unknown`] with its raw type value and opaque payload.
///
/// NOTE(review): variant order may affect the serde representation — do not
/// reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Extension {
    /// An [`ApplicationIdExtension`].
    ApplicationId(ApplicationIdExtension),

    /// A [`RatchetTreeExtension`].
    RatchetTree(RatchetTreeExtension),

    /// A [`RequiredCapabilitiesExtension`].
    RequiredCapabilities(RequiredCapabilitiesExtension),

    /// An [`ExternalPubExtension`].
    ExternalPub(ExternalPubExtension),

    /// An [`ExternalSendersExtension`].
    ExternalSenders(ExternalSendersExtension),

    /// An [`AppDataDictionaryExtension`] (only with `extensions-draft-08`).
    #[cfg(feature = "extensions-draft-08")]
    AppDataDictionary(AppDataDictionaryExtension),

    /// A [`LastResortExtension`].
    LastResort(LastResortExtension),

    /// An extension with no dedicated variant: the raw `u16` type value and
    /// its opaque payload.
    ///
    /// `Extension::extension_type` reports this variant as
    /// `ExtensionType::Unknown(raw)` for every `raw`, including GREASE values.
    Unknown(u16, UnknownExtension),
}
320
/// Opaque payload bytes of an extension that has no dedicated [`Extension`] variant.
#[derive(
    PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSize, TlsSerialize, TlsDeserialize,
)]
pub struct UnknownExtension(pub Vec<u8>);
326
/// An ordered list of extensions attached to an MLS object.
///
/// The type parameter `T` names the object the list belongs to (for example
/// `GroupContext`, `KeyPackage`, `LeafNode`, `GroupInfo`, or [`AnyObject`]).
/// Through `T`'s [`ExtensionValidator`] impl it decides which extension types
/// the validating constructors and `add` accept. Those constructors and `add`
/// also reject duplicate extension types.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Extensions<T> {
    // The stored extensions, in list order.
    unique: Vec<Extension>,
    // Type-level marker only; carries no data and is skipped by serde.
    #[serde(skip)]
    _object: core::marker::PhantomData<T>,
}
334
/// Marker for an extension list that is not tied to a particular MLS object.
///
/// Its [`ExtensionValidator`] impl accepts every extension type.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, TlsSize, TlsSerialize, TlsDeserialize)]
pub struct AnyObject;
338
339impl<T> Default for Extensions<T> {
340 fn default() -> Self {
341 Self {
342 unique: vec![],
343 _object: PhantomData,
344 }
345 }
346}
347
348impl<T> Size for Extensions<T> {
349 fn tls_serialized_len(&self) -> usize {
350 Vec::tls_serialized_len(&self.unique)
351 }
352}
353
354impl<T> TlsSerializeTrait for Extensions<T> {
355 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
356 self.unique.tls_serialize(writer)
357 }
358}
359
360impl<T: ExtensionValidator> TlsDeserializeTrait for Extensions<T>
361where
362 InvalidExtensionError: From<T::Error>,
363{
364 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
365 where
366 Self: Sized,
367 {
368 let candidate: Vec<Extension> = Vec::tls_deserialize(bytes)?;
369 Extensions::<T>::try_from(candidate)
370 .map_err(|_| Error::DecodingError("Found duplicate extensions".into()))
371 }
372}
373
374impl<T: ExtensionValidator> DeserializeBytes for Extensions<T>
375where
376 InvalidExtensionError: From<T::Error>,
377{
378 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
379 where
380 Self: Sized,
381 {
382 let mut bytes_ref = bytes;
383 let extensions = Extensions::<T>::tls_deserialize(&mut bytes_ref)?;
384 let remainder = &bytes[extensions.tls_serialized_len()..];
385 Ok((extensions, remainder))
386 }
387}
388
389impl<T: ExtensionValidator> Extensions<T> {
390 pub fn empty() -> Self {
392 Self {
393 unique: vec![],
394 _object: PhantomData,
395 }
396 }
397
398 pub fn iter(&self) -> impl Iterator<Item = &Extension> {
400 self.unique.iter()
401 }
402
403 pub fn remove(&mut self, extension_type: ExtensionType) -> Option<Extension> {
408 if let Some(pos) = self
409 .unique
410 .iter()
411 .position(|ext| ext.extension_type() == extension_type)
412 {
413 Some(self.unique.remove(pos))
414 } else {
415 None
416 }
417 }
418
419 pub fn contains(&self, extension_type: ExtensionType) -> bool {
422 self.unique
423 .iter()
424 .any(|ext| ext.extension_type() == extension_type)
425 }
426}
427
428impl<T> Extensions<T>
429where
430 T: ExtensionValidator,
431 InvalidExtensionError: From<T::Error>,
432{
433 pub fn single(extension: Extension) -> Result<Self, InvalidExtensionError> {
435 T::validate_extension_type(&extension)?;
436 Ok(Self {
437 unique: vec![extension],
438 _object: PhantomData,
439 })
440 }
441
442 pub fn from_vec(extensions: Vec<Extension>) -> Result<Self, InvalidExtensionError> {
447 extensions.try_into()
448 }
449
450 pub fn validate<'a>(
452 extensions: impl Iterator<Item = &'a Extension>,
453 ) -> Result<(), InvalidExtensionError> {
454 for ext in extensions {
455 T::validate_extension_type(ext)?;
456 }
457 Ok(())
458 }
459
460 pub fn add(&mut self, extension: Extension) -> Result<(), InvalidExtensionError> {
465 T::validate_extension_type(&extension)?;
466 if self.contains(extension.extension_type()) {
467 return Err(InvalidExtensionError::Duplicate);
468 }
469
470 self.unique.push(extension);
471
472 Ok(())
473 }
474
475 pub fn add_or_replace(
479 &mut self,
480 extension: Extension,
481 ) -> Result<Option<Extension>, InvalidExtensionError> {
482 T::validate_extension_type(&extension)?;
483 let replaced = self.remove(extension.extension_type());
484 self.unique.push(extension);
485 Ok(replaced)
486 }
487}
488
/// Decides which extension types may appear in an [`Extensions`] list that is
/// attached to a given MLS object.
pub trait ExtensionValidator {
    /// Error returned when an extension's type is not allowed.
    type Error;

    /// Returns `Ok(())` if `ext`'s type is allowed for this object, and an
    /// error otherwise.
    fn validate_extension_type(ext: &Extension) -> Result<(), Self::Error>;
}
497
498impl ExtensionValidator for AnyObject {
499 type Error = Infallible;
500
501 fn validate_extension_type(_ext: &Extension) -> Result<(), Infallible> {
502 Ok(())
503 }
504}
505
506impl<T: ExtensionValidator> TryFrom<Vec<Extension>> for Extensions<T>
507where
508 InvalidExtensionError: From<T::Error>,
509{
510 type Error = InvalidExtensionError;
511
512 fn try_from(candidate: Vec<Extension>) -> Result<Self, Self::Error> {
513 let mut unique: Vec<Extension> = Vec::new();
514 for extension in candidate.into_iter() {
515 T::validate_extension_type(&extension)?;
516
517 if unique
518 .iter()
519 .any(|ext| ext.extension_type() == extension.extension_type())
520 {
521 return Err(InvalidExtensionError::Duplicate);
522 } else {
523 unique.push(extension);
524 }
525 }
526
527 Ok(Self {
528 unique,
529 _object: PhantomData,
530 })
531 }
532}
533
534impl ExtensionValidator for GroupInfo {
536 type Error = ExtensionTypeNotValidInGroupInfoError;
537
538 fn validate_extension_type(
539 ext: &Extension,
540 ) -> Result<(), ExtensionTypeNotValidInGroupInfoError> {
541 if ext.extension_type().is_valid_in_group_info() == Some(true)
542 || ext.extension_type().is_valid_in_group_info().is_none()
543 {
544 Ok(())
545 } else {
546 Err(ExtensionTypeNotValidInGroupInfoError(ext.extension_type()))
547 }
548 }
549}
550
551impl ExtensionValidator for GroupContext {
553 type Error = ExtensionTypeNotValidInGroupContextError;
554
555 fn validate_extension_type(
556 ext: &Extension,
557 ) -> Result<(), ExtensionTypeNotValidInGroupContextError> {
558 if ext.extension_type().is_valid_in_group_context() {
559 Ok(())
560 } else {
561 Err(ExtensionTypeNotValidInGroupContextError(
562 ext.extension_type(),
563 ))
564 }
565 }
566}
567
568impl ExtensionValidator for KeyPackage {
570 type Error = ExtensionTypeNotValidInKeyPackageError;
571
572 fn validate_extension_type(
573 ext: &Extension,
574 ) -> Result<(), ExtensionTypeNotValidInKeyPackageError> {
575 if ext.extension_type().is_valid_in_key_package() {
576 Ok(())
577 } else {
578 Err(ExtensionTypeNotValidInKeyPackageError(ext.extension_type()))
579 }
580 }
581}
582
583impl ExtensionValidator for LeafNode {
585 type Error = ExtensionTypeNotValidInLeafNodeError;
586
587 fn validate_extension_type(
588 ext: &Extension,
589 ) -> Result<(), ExtensionTypeNotValidInLeafNodeError> {
590 if ext.extension_type().is_valid_in_leaf_node() {
591 Ok(())
592 } else {
593 Err(ExtensionTypeNotValidInLeafNodeError(ext.extension_type()))
594 }
595 }
596}
597
598impl<T> Extensions<T> {
599 fn find_by_type(&self, extension_type: ExtensionType) -> Option<&Extension> {
600 self.unique
601 .iter()
602 .find(|ext| ext.extension_type() == extension_type)
603 }
604
605 pub fn application_id(&self) -> Option<&ApplicationIdExtension> {
607 self.find_by_type(ExtensionType::ApplicationId)
608 .and_then(|e| match e {
609 Extension::ApplicationId(e) => Some(e),
610 _ => None,
611 })
612 }
613
614 pub fn ratchet_tree(&self) -> Option<&RatchetTreeExtension> {
616 self.find_by_type(ExtensionType::RatchetTree)
617 .and_then(|e| match e {
618 Extension::RatchetTree(e) => Some(e),
619 _ => None,
620 })
621 }
622
623 pub fn required_capabilities(&self) -> Option<&RequiredCapabilitiesExtension> {
626 self.find_by_type(ExtensionType::RequiredCapabilities)
627 .and_then(|e| match e {
628 Extension::RequiredCapabilities(e) => Some(e),
629 _ => None,
630 })
631 }
632
633 pub fn external_pub(&self) -> Option<&ExternalPubExtension> {
635 self.find_by_type(ExtensionType::ExternalPub)
636 .and_then(|e| match e {
637 Extension::ExternalPub(e) => Some(e),
638 _ => None,
639 })
640 }
641
642 pub fn external_senders(&self) -> Option<&ExternalSendersExtension> {
644 self.find_by_type(ExtensionType::ExternalSenders)
645 .and_then(|e| match e {
646 Extension::ExternalSenders(e) => Some(e),
647 _ => None,
648 })
649 }
650
651 #[cfg(feature = "extensions-draft-08")]
652 pub fn app_data_dictionary(&self) -> Option<&AppDataDictionaryExtension> {
654 self.find_by_type(ExtensionType::AppDataDictionary)
655 .and_then(|e| match e {
656 Extension::AppDataDictionary(e) => Some(e),
657 _ => None,
658 })
659 }
660
661 pub fn unknown(&self, extension_type_id: u16) -> Option<&UnknownExtension> {
663 let extension_type: ExtensionType = extension_type_id.into();
664
665 match extension_type {
666 ExtensionType::Unknown(_) => self.find_by_type(extension_type).and_then(|e| match e {
667 Extension::Unknown(_, e) => Some(e),
668 _ => None,
669 }),
670 _ => None,
671 }
672 }
673}
674
675impl Extension {
676 pub fn as_application_id_extension(&self) -> Result<&ApplicationIdExtension, ExtensionError> {
680 match self {
681 Self::ApplicationId(e) => Ok(e),
682 _ => Err(ExtensionError::InvalidExtensionType(
683 "This is not an ApplicationIdExtension".into(),
684 )),
685 }
686 }
687 #[cfg(feature = "extensions-draft-08")]
688 pub fn as_app_data_dictionary_extension(
692 &self,
693 ) -> Result<&AppDataDictionaryExtension, ExtensionError> {
694 match self {
695 Self::AppDataDictionary(e) => Ok(e),
696 _ => Err(ExtensionError::InvalidExtensionType(
697 "This is not an AppDataDictionaryExtension".into(),
698 )),
699 }
700 }
701
702 pub fn as_ratchet_tree_extension(&self) -> Result<&RatchetTreeExtension, ExtensionError> {
706 match self {
707 Self::RatchetTree(rte) => Ok(rte),
708 _ => Err(ExtensionError::InvalidExtensionType(
709 "This is not a RatchetTreeExtension".into(),
710 )),
711 }
712 }
713
714 pub fn as_required_capabilities_extension(
718 &self,
719 ) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
720 match self {
721 Self::RequiredCapabilities(e) => Ok(e),
722 _ => Err(ExtensionError::InvalidExtensionType(
723 "This is not a RequiredCapabilitiesExtension".into(),
724 )),
725 }
726 }
727
728 pub fn as_external_pub_extension(&self) -> Result<&ExternalPubExtension, ExtensionError> {
732 match self {
733 Self::ExternalPub(e) => Ok(e),
734 _ => Err(ExtensionError::InvalidExtensionType(
735 "This is not an ExternalPubExtension".into(),
736 )),
737 }
738 }
739
740 pub fn as_external_senders_extension(
744 &self,
745 ) -> Result<&ExternalSendersExtension, ExtensionError> {
746 match self {
747 Self::ExternalSenders(e) => Ok(e),
748 _ => Err(ExtensionError::InvalidExtensionType(
749 "This is not an ExternalSendersExtension".into(),
750 )),
751 }
752 }
753
754 #[inline]
756 pub const fn extension_type(&self) -> ExtensionType {
757 match self {
758 Extension::ApplicationId(_) => ExtensionType::ApplicationId,
759 Extension::RatchetTree(_) => ExtensionType::RatchetTree,
760 Extension::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
761 Extension::ExternalPub(_) => ExtensionType::ExternalPub,
762 Extension::ExternalSenders(_) => ExtensionType::ExternalSenders,
763 #[cfg(feature = "extensions-draft-08")]
764 Extension::AppDataDictionary(_) => ExtensionType::AppDataDictionary,
765 Extension::LastResort(_) => ExtensionType::LastResort,
766 Extension::Unknown(kind, _) => ExtensionType::Unknown(*kind),
767 }
768 }
769}
770
771macro_rules! impl_from_extensions_validator {
772 ($validator:ty, $error:ty) => {
773 impl From<Extensions<$validator>> for Extensions<AnyObject> {
774 fn from(value: Extensions<$validator>) -> Self {
775 Extensions {
776 unique: value.unique,
777 _object: PhantomData,
778 }
779 }
780 }
781
782 impl TryFrom<Extensions<AnyObject>> for Extensions<$validator> {
783 type Error = $error;
784
785 fn try_from(value: Extensions<AnyObject>) -> Result<Self, $error> {
786 value
787 .unique
788 .iter()
789 .try_for_each(<$validator as ExtensionValidator>::validate_extension_type)?;
790
791 Ok(Extensions {
792 unique: value.unique,
793 _object: PhantomData,
794 })
795 }
796 }
797 };
798}
799
800impl_from_extensions_validator!(GroupContext, ExtensionTypeNotValidInGroupContextError);
801impl_from_extensions_validator!(LeafNode, ExtensionTypeNotValidInLeafNodeError);
802impl_from_extensions_validator!(KeyPackage, ExtensionTypeNotValidInKeyPackageError);
803
#[cfg(any(feature = "test-utils", test))]
impl Extensions<AnyObject> {
    /// Re-tags this list for object type `T` **without** validating its
    /// extensions.
    ///
    /// Test-only helper for building lists a validator would normally reject.
    pub(crate) fn coerce<T: ExtensionValidator>(self) -> Extensions<T> {
        let Self { unique, .. } = self;
        Extensions {
            unique,
            _object: PhantomData,
        }
    }
}
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use tls_codec::{Deserialize, Serialize, VLBytes};

    use crate::{ciphersuite::HpkePublicKey, extensions::*};

    /// Adding a second extension of the same type must be rejected.
    #[test]
    fn add() {
        let required_capabilities =
            || Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());
        let mut extensions: Extensions<AnyObject> = Extensions::default();

        extensions.add(required_capabilities()).unwrap();
        assert!(extensions.add(required_capabilities()).is_err());
    }

    /// `add` and `TryFrom<Vec<Extension>>` must agree on which sequences
    /// contain duplicate extension types.
    #[test]
    fn add_try_from() {
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        let cases = [
            (vec![], true),
            (vec![ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_y.clone()], true),
            (vec![ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_x.clone(), ext_y.clone()], true),
            (vec![ext_y.clone(), ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_y.clone(), ext_x, ext_y], false),
        ];

        for (sequence, expected_ok) in cases {
            let mut extensions: Extensions<AnyObject> = Extensions::default();
            let mut all_added = true;
            for ext in &sequence {
                match extensions.add(ext.clone()) {
                    Ok(()) => {}
                    Err(InvalidExtensionError::Duplicate) => all_added = false,
                    Err(_) => panic!("This should have never happened."),
                }
            }

            println!("{sequence:?}, {expected_ok:?}");
            assert_eq!(all_added, expected_ok);

            // The bulk constructor must reach the same verdict.
            let converted = Extensions::<AnyObject>::try_from(sequence);
            assert_eq!(converted.is_ok(), expected_ok);
        }
    }

    /// Every ordering of distinct extensions must survive a TLS round trip.
    #[test]
    fn ensure_ordering() {
        let extensions = [
            Extension::ApplicationId(ApplicationIdExtension::new(b"Test")),
            Extension::ExternalPub(ExternalPubExtension::new(HpkePublicKey::new(vec![]))),
            Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default()),
        ];

        for permutation in extensions.into_iter().permutations(3) {
            let original: Extensions<AnyObject> = Extensions::try_from(permutation).unwrap();
            let encoded = original.tls_serialize_detached().unwrap();
            let decoded = Extensions::tls_deserialize(&mut encoded.as_slice()).unwrap();
            assert_eq!(original, decoded);
        }
    }

    /// Unknown (and GREASE-valued) extension types decode into
    /// `Extension::Unknown` and re-encode byte-for-byte.
    #[test]
    fn that_unknown_extensions_are_de_serialized_correctly() {
        let extension_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF100, 0xFFFF];
        let extension_datas = [vec![], vec![0], vec![1, 2, 3]];

        for (extension_type, extension_data) in extension_types
            .into_iter()
            .cartesian_product(extension_datas.iter())
        {
            // Hand-built encoding: 2-byte type followed by the variable-length payload.
            let mut encoded = extension_type.to_be_bytes().to_vec();
            encoded.extend(
                VLBytes::new(extension_data.clone())
                    .tls_serialize_detached()
                    .unwrap(),
            );

            let decoded = Extension::tls_deserialize_exact(&encoded).unwrap();

            match &decoded {
                Extension::Unknown(got_type, got_data) => {
                    assert_eq!(extension_type, *got_type);
                    assert_eq!(extension_data, &got_data.0);
                }
                other => panic!("Expected `Extension::Unknown`, got {other:?}"),
            }

            assert_eq!(encoded, decoded.tls_serialize_detached().unwrap());
        }
    }
}