1use std::{
25 convert::Infallible,
26 fmt::Debug,
27 io::{Read, Write},
28 marker::PhantomData,
29};
30
31use serde::{Deserialize, Serialize};
32
33#[cfg(feature = "extensions-draft-08")]
35mod app_data_dict_extension;
36mod application_id_extension;
37mod codec;
38mod external_pub_extension;
39mod external_sender_extension;
40mod last_resort;
41mod ratchet_tree_extension;
42mod required_capabilities;
43use errors::*;
44
45pub mod errors;
47
48#[cfg(feature = "extensions-draft-08")]
50pub use app_data_dict_extension::{AppDataDictionary, AppDataDictionaryExtension};
51pub use application_id_extension::ApplicationIdExtension;
52pub use external_pub_extension::ExternalPubExtension;
53pub use external_sender_extension::{
54 ExternalSender, ExternalSendersExtension, SenderExtensionIndex,
55};
56pub use last_resort::LastResortExtension;
57pub use ratchet_tree_extension::RatchetTreeExtension;
58pub use required_capabilities::RequiredCapabilitiesExtension;
59
60use tls_codec::{
61 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
62 Size, TlsDeserialize, TlsSerialize, TlsSize,
63};
64
65use crate::{
66 group::GroupContext, key_packages::KeyPackage, messages::group_info::GroupInfo,
67 treesync::LeafNode,
68};
69
70#[cfg(test)]
71mod tests;
72
/// MLS extension type, carried on the wire as a 2-byte big-endian code.
///
/// The numeric wire codes are defined by the `From<u16>` / `From<ExtensionType> for u16`
/// conversions in this module. `Grease` and `Unknown` carry the raw `u16` code.
///
/// NOTE: the variant declaration order determines the derived `Ord`/`PartialOrd`.
/// The `storage_tag` values (used when `0-8-1-storage-format` is disabled) presumably
/// identify variants in persisted data. Do not reorder variants or change tags casually.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
#[cfg_attr(
    feature = "0-8-1-storage-format",
    derive(serde::Serialize, serde::Deserialize)
)]
#[cfg_attr(
    not(feature = "0-8-1-storage-format"),
    derive(
        openmls_serialization_helpers::Serialize,
        openmls_serialization_helpers::Deserialize,
    )
)]
pub enum ExtensionType {
    /// Application identifier extension (wire code 1).
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 0)]
    ApplicationId,

    /// Ratchet tree extension (wire code 2).
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 1)]
    RatchetTree,

    /// Required capabilities extension (wire code 3).
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 2)]
    RequiredCapabilities,

    /// External public key extension (wire code 4).
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 3)]
    ExternalPub,

    /// External senders extension (wire code 5).
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 4)]
    ExternalSenders,

    /// Last-resort key package marker extension (wire code 10).
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 5)]
    LastResort,

    /// App data dictionary extension (wire code 6). Only present with the
    /// `extensions-draft-08` feature.
    #[cfg(feature = "extensions-draft-08")]
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 8)]
    AppDataDictionary,

    /// GREASE extension type. Holds a raw code for which
    /// `crate::grease::is_grease_value` returned `true` during `From<u16>` conversion.
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 7)]
    Grease(u16),

    /// Any other extension type not modelled by this crate. Holds the raw code.
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 6)]
    Unknown(u16),
}
144
145impl ExtensionType {
146 pub(crate) fn is_default(self) -> bool {
148 match self {
149 ExtensionType::ApplicationId
150 | ExtensionType::RatchetTree
151 | ExtensionType::RequiredCapabilities
152 | ExtensionType::ExternalPub
153 | ExtensionType::ExternalSenders => true,
154 ExtensionType::LastResort | ExtensionType::Grease(_) | ExtensionType::Unknown(_) => {
155 false
156 }
157 #[cfg(feature = "extensions-draft-08")]
158 ExtensionType::AppDataDictionary => false,
159 }
160 }
161
162 pub(crate) fn is_valid_in_leaf_node(self) -> bool {
167 match self {
168 ExtensionType::Grease(_)
169 | ExtensionType::LastResort
170 | ExtensionType::RatchetTree
171 | ExtensionType::RequiredCapabilities
172 | ExtensionType::ExternalPub
173 | ExtensionType::ExternalSenders => false,
174 ExtensionType::Unknown(_) | ExtensionType::ApplicationId => true,
175 #[cfg(feature = "extensions-draft-08")]
176 ExtensionType::AppDataDictionary => true,
177 }
178 }
179 pub(crate) fn is_valid_in_group_info(self) -> Option<bool> {
180 match self {
181 ExtensionType::Grease(_)
182 | ExtensionType::LastResort
183 | ExtensionType::RequiredCapabilities
184 | ExtensionType::ExternalSenders
185 | ExtensionType::ApplicationId => Some(false),
186 ExtensionType::RatchetTree | ExtensionType::ExternalPub => Some(true),
187 ExtensionType::Unknown(_) => None,
188 #[cfg(feature = "extensions-draft-08")]
189 ExtensionType::AppDataDictionary => Some(true),
190 }
191 }
192
193 pub(crate) fn is_valid_in_key_package(self) -> bool {
194 match self {
195 ExtensionType::Grease(_)
196 | ExtensionType::RatchetTree
197 | ExtensionType::RequiredCapabilities
198 | ExtensionType::ExternalPub
199 | ExtensionType::ExternalSenders
200 | ExtensionType::ApplicationId => false,
201 ExtensionType::Unknown(_) | ExtensionType::LastResort => true,
202 #[cfg(feature = "extensions-draft-08")]
203 ExtensionType::AppDataDictionary => true,
204 }
205 }
206
207 pub(crate) fn is_valid_in_group_context(self) -> bool {
208 match self {
209 ExtensionType::RequiredCapabilities
210 | ExtensionType::ExternalSenders
211 | ExtensionType::Unknown(_) => true,
212 #[cfg(feature = "extensions-draft-08")]
213 ExtensionType::AppDataDictionary => true,
214 _ => false,
215 }
216 }
217
218 pub fn is_grease(&self) -> bool {
223 matches!(self, ExtensionType::Grease(_))
224 }
225}
226
227impl Size for ExtensionType {
228 fn tls_serialized_len(&self) -> usize {
229 2
230 }
231}
232
233impl TlsDeserializeTrait for ExtensionType {
234 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
235 where
236 Self: Sized,
237 {
238 let mut extension_type = [0u8; 2];
239 bytes.read_exact(&mut extension_type)?;
240
241 Ok(ExtensionType::from(u16::from_be_bytes(extension_type)))
242 }
243}
244
245impl DeserializeBytes for ExtensionType {
246 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
247 where
248 Self: Sized,
249 {
250 let mut bytes_ref = bytes;
251 let extension_type = ExtensionType::tls_deserialize(&mut bytes_ref)?;
252 let remainder = &bytes[extension_type.tls_serialized_len()..];
253 Ok((extension_type, remainder))
254 }
255}
256
257impl TlsSerializeTrait for ExtensionType {
258 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
259 writer.write_all(&u16::from(*self).to_be_bytes())?;
260
261 Ok(2)
262 }
263}
264
265impl From<u16> for ExtensionType {
266 fn from(a: u16) -> Self {
267 match a {
268 1 => ExtensionType::ApplicationId,
269 2 => ExtensionType::RatchetTree,
270 3 => ExtensionType::RequiredCapabilities,
271 4 => ExtensionType::ExternalPub,
272 5 => ExtensionType::ExternalSenders,
273 #[cfg(feature = "extensions-draft-08")]
274 6 => ExtensionType::AppDataDictionary,
275 10 => ExtensionType::LastResort,
276 unknown if crate::grease::is_grease_value(unknown) => ExtensionType::Grease(unknown),
277 unknown => ExtensionType::Unknown(unknown),
278 }
279 }
280}
281
282impl From<ExtensionType> for u16 {
283 fn from(value: ExtensionType) -> Self {
284 match value {
285 ExtensionType::ApplicationId => 1,
286 ExtensionType::RatchetTree => 2,
287 ExtensionType::RequiredCapabilities => 3,
288 ExtensionType::ExternalPub => 4,
289 ExtensionType::ExternalSenders => 5,
290 #[cfg(feature = "extensions-draft-08")]
291 ExtensionType::AppDataDictionary => 6,
292 ExtensionType::LastResort => 10,
293 ExtensionType::Grease(value) => value,
294 ExtensionType::Unknown(unknown) => unknown,
295 }
296 }
297}
298
/// A single MLS extension.
///
/// There is one variant per extension type this crate models, plus a catch-all
/// `Unknown` that keeps the raw type code and the opaque payload.
///
/// NOTE: the `storage_tag` values (used when `0-8-1-storage-format` is disabled)
/// presumably identify variants in persisted data and must stay stable. They do not
/// line up with the tags `ExtensionType` uses for the same variants. For example,
/// `AppDataDictionary` is 7 here but 8 in `ExtensionType`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "0-8-1-storage-format",
    derive(serde::Serialize, serde::Deserialize)
)]
#[cfg_attr(
    not(feature = "0-8-1-storage-format"),
    derive(
        openmls_serialization_helpers::Serialize,
        openmls_serialization_helpers::Deserialize,
    )
)]
pub enum Extension {
    /// An [`ApplicationIdExtension`].
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 0)]
    ApplicationId(ApplicationIdExtension),

    /// A [`RatchetTreeExtension`].
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 1)]
    RatchetTree(RatchetTreeExtension),

    /// A [`RequiredCapabilitiesExtension`].
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 2)]
    RequiredCapabilities(RequiredCapabilitiesExtension),

    /// An [`ExternalPubExtension`].
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 3)]
    ExternalPub(ExternalPubExtension),

    /// An [`ExternalSendersExtension`].
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 4)]
    ExternalSenders(ExternalSendersExtension),

    /// An [`AppDataDictionaryExtension`]. Only present with the
    /// `extensions-draft-08` feature.
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 7)]
    #[cfg(feature = "extensions-draft-08")]
    AppDataDictionary(AppDataDictionaryExtension),

    /// A [`LastResortExtension`].
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 5)]
    LastResort(LastResortExtension),

    /// An extension type this crate does not model: the raw `u16` type code and the
    /// opaque payload.
    ///
    /// The round-trip test in this module expects codes such as 0x0A0A and 0x7A7A
    /// (presumably GREASE values) to decode into this variant as well.
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 6)]
    Unknown(u16, UnknownExtension),
}
359
/// Opaque payload bytes of an extension whose type this crate does not model.
#[derive(
    PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSize, TlsSerialize, TlsDeserialize,
)]
pub struct UnknownExtension(pub Vec<u8>);

/// An ordered list of extensions whose extension types are pairwise distinct.
///
/// `T` is a zero-sized context marker, such as `GroupContext`, `LeafNode`,
/// `KeyPackage`, `GroupInfo` or `AnyObject`. The validating constructors of this type
/// enforce the `ExtensionValidator` rules of that context.
///
/// NOTE(review): serde `Deserialize` is derived, so values loaded through serde are
/// not re-checked for duplicates or context validity.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Extensions<T> {
    // Extensions in insertion / wire order. `add` and `try_from` reject duplicate types.
    unique: Vec<Extension>,
    // Context marker only; never serialized.
    #[serde(skip)]
    _object: core::marker::PhantomData<T>,
}

/// Context marker that accepts every extension type (its validator never fails).
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, TlsSize, TlsSerialize, TlsDeserialize)]
pub struct AnyObject;
377
378impl<T> Default for Extensions<T> {
379 fn default() -> Self {
380 Self {
381 unique: vec![],
382 _object: PhantomData,
383 }
384 }
385}
386
387impl<T> Size for Extensions<T> {
388 fn tls_serialized_len(&self) -> usize {
389 Vec::tls_serialized_len(&self.unique)
390 }
391}
392
393impl<T> TlsSerializeTrait for Extensions<T> {
394 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
395 self.unique.tls_serialize(writer)
396 }
397}
398
399impl<T: ExtensionValidator> TlsDeserializeTrait for Extensions<T>
400where
401 InvalidExtensionError: From<T::Error>,
402{
403 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
404 where
405 Self: Sized,
406 {
407 let candidate: Vec<Extension> = Vec::tls_deserialize(bytes)?;
408 Extensions::<T>::try_from(candidate)
409 .map_err(|_| Error::DecodingError("Found duplicate extensions".into()))
410 }
411}
412
413impl<T: ExtensionValidator> DeserializeBytes for Extensions<T>
414where
415 InvalidExtensionError: From<T::Error>,
416{
417 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
418 where
419 Self: Sized,
420 {
421 let mut bytes_ref = bytes;
422 let extensions = Extensions::<T>::tls_deserialize(&mut bytes_ref)?;
423 let remainder = &bytes[extensions.tls_serialized_len()..];
424 Ok((extensions, remainder))
425 }
426}
427
428impl<T: ExtensionValidator> Extensions<T> {
429 pub fn empty() -> Self {
431 Self {
432 unique: vec![],
433 _object: PhantomData,
434 }
435 }
436
437 pub fn iter(&self) -> impl Iterator<Item = &Extension> {
439 self.unique.iter()
440 }
441
442 pub fn remove(&mut self, extension_type: ExtensionType) -> Option<Extension> {
447 if let Some(pos) = self
448 .unique
449 .iter()
450 .position(|ext| ext.extension_type() == extension_type)
451 {
452 Some(self.unique.remove(pos))
453 } else {
454 None
455 }
456 }
457
458 pub fn contains(&self, extension_type: ExtensionType) -> bool {
461 self.unique
462 .iter()
463 .any(|ext| ext.extension_type() == extension_type)
464 }
465}
466
467impl<T> Extensions<T>
468where
469 T: ExtensionValidator,
470 InvalidExtensionError: From<T::Error>,
471{
472 pub fn single(extension: Extension) -> Result<Self, InvalidExtensionError> {
474 T::validate_extension_type(&extension)?;
475 Ok(Self {
476 unique: vec![extension],
477 _object: PhantomData,
478 })
479 }
480
481 pub fn from_vec(extensions: Vec<Extension>) -> Result<Self, InvalidExtensionError> {
486 extensions.try_into()
487 }
488
489 pub fn validate<'a>(
491 extensions: impl Iterator<Item = &'a Extension>,
492 ) -> Result<(), InvalidExtensionError> {
493 for ext in extensions {
494 T::validate_extension_type(ext)?;
495 }
496 Ok(())
497 }
498
499 pub fn add(&mut self, extension: Extension) -> Result<(), InvalidExtensionError> {
504 T::validate_extension_type(&extension)?;
505 if self.contains(extension.extension_type()) {
506 return Err(InvalidExtensionError::Duplicate);
507 }
508
509 self.unique.push(extension);
510
511 Ok(())
512 }
513
514 pub fn add_or_replace(
518 &mut self,
519 extension: Extension,
520 ) -> Result<Option<Extension>, InvalidExtensionError> {
521 T::validate_extension_type(&extension)?;
522 let replaced = self.remove(extension.extension_type());
523 self.unique.push(extension);
524 Ok(replaced)
525 }
526}
527
/// Context-specific rule for which extension types may appear in an [`Extensions`] list.
///
/// Implemented in this module for the context markers `AnyObject`, `GroupInfo`,
/// `GroupContext`, `KeyPackage` and `LeafNode`. The validating constructors of
/// `Extensions<T>` call it.
pub trait ExtensionValidator {
    /// Error returned when an extension's type is not allowed in this context.
    type Error;

    /// Checks whether the extension type of `ext` is allowed in this context.
    ///
    /// # Errors
    /// Returns `Self::Error` if the type is not allowed.
    fn validate_extension_type(ext: &Extension) -> Result<(), Self::Error>;
}
536
537impl ExtensionValidator for AnyObject {
538 type Error = Infallible;
539
540 fn validate_extension_type(_ext: &Extension) -> Result<(), Infallible> {
541 Ok(())
542 }
543}
544
545impl<T: ExtensionValidator> TryFrom<Vec<Extension>> for Extensions<T>
546where
547 InvalidExtensionError: From<T::Error>,
548{
549 type Error = InvalidExtensionError;
550
551 fn try_from(candidate: Vec<Extension>) -> Result<Self, Self::Error> {
552 let mut unique: Vec<Extension> = Vec::new();
553 for extension in candidate.into_iter() {
554 T::validate_extension_type(&extension)?;
555
556 if unique
557 .iter()
558 .any(|ext| ext.extension_type() == extension.extension_type())
559 {
560 return Err(InvalidExtensionError::Duplicate);
561 } else {
562 unique.push(extension);
563 }
564 }
565
566 Ok(Self {
567 unique,
568 _object: PhantomData,
569 })
570 }
571}
572
573impl ExtensionValidator for GroupInfo {
575 type Error = ExtensionTypeNotValidInGroupInfoError;
576
577 fn validate_extension_type(
578 ext: &Extension,
579 ) -> Result<(), ExtensionTypeNotValidInGroupInfoError> {
580 if ext.extension_type().is_valid_in_group_info() == Some(true)
581 || ext.extension_type().is_valid_in_group_info().is_none()
582 {
583 Ok(())
584 } else {
585 Err(ExtensionTypeNotValidInGroupInfoError(ext.extension_type()))
586 }
587 }
588}
589
590impl ExtensionValidator for GroupContext {
592 type Error = ExtensionTypeNotValidInGroupContextError;
593
594 fn validate_extension_type(
595 ext: &Extension,
596 ) -> Result<(), ExtensionTypeNotValidInGroupContextError> {
597 if ext.extension_type().is_valid_in_group_context() {
598 Ok(())
599 } else {
600 Err(ExtensionTypeNotValidInGroupContextError(
601 ext.extension_type(),
602 ))
603 }
604 }
605}
606
607impl ExtensionValidator for KeyPackage {
609 type Error = ExtensionTypeNotValidInKeyPackageError;
610
611 fn validate_extension_type(
612 ext: &Extension,
613 ) -> Result<(), ExtensionTypeNotValidInKeyPackageError> {
614 if ext.extension_type().is_valid_in_key_package() {
615 Ok(())
616 } else {
617 Err(ExtensionTypeNotValidInKeyPackageError(ext.extension_type()))
618 }
619 }
620}
621
622impl ExtensionValidator for LeafNode {
624 type Error = ExtensionTypeNotValidInLeafNodeError;
625
626 fn validate_extension_type(
627 ext: &Extension,
628 ) -> Result<(), ExtensionTypeNotValidInLeafNodeError> {
629 if ext.extension_type().is_valid_in_leaf_node() {
630 Ok(())
631 } else {
632 Err(ExtensionTypeNotValidInLeafNodeError(ext.extension_type()))
633 }
634 }
635}
636
637impl<T> Extensions<T> {
638 fn find_by_type(&self, extension_type: ExtensionType) -> Option<&Extension> {
639 self.unique
640 .iter()
641 .find(|ext| ext.extension_type() == extension_type)
642 }
643
644 pub fn application_id(&self) -> Option<&ApplicationIdExtension> {
646 self.find_by_type(ExtensionType::ApplicationId)
647 .and_then(|e| match e {
648 Extension::ApplicationId(e) => Some(e),
649 _ => None,
650 })
651 }
652
653 pub fn ratchet_tree(&self) -> Option<&RatchetTreeExtension> {
655 self.find_by_type(ExtensionType::RatchetTree)
656 .and_then(|e| match e {
657 Extension::RatchetTree(e) => Some(e),
658 _ => None,
659 })
660 }
661
662 pub fn required_capabilities(&self) -> Option<&RequiredCapabilitiesExtension> {
665 self.find_by_type(ExtensionType::RequiredCapabilities)
666 .and_then(|e| match e {
667 Extension::RequiredCapabilities(e) => Some(e),
668 _ => None,
669 })
670 }
671
672 pub fn external_pub(&self) -> Option<&ExternalPubExtension> {
674 self.find_by_type(ExtensionType::ExternalPub)
675 .and_then(|e| match e {
676 Extension::ExternalPub(e) => Some(e),
677 _ => None,
678 })
679 }
680
681 pub fn external_senders(&self) -> Option<&ExternalSendersExtension> {
683 self.find_by_type(ExtensionType::ExternalSenders)
684 .and_then(|e| match e {
685 Extension::ExternalSenders(e) => Some(e),
686 _ => None,
687 })
688 }
689
690 #[cfg(feature = "extensions-draft-08")]
691 pub fn app_data_dictionary(&self) -> Option<&AppDataDictionaryExtension> {
693 self.find_by_type(ExtensionType::AppDataDictionary)
694 .and_then(|e| match e {
695 Extension::AppDataDictionary(e) => Some(e),
696 _ => None,
697 })
698 }
699
700 pub fn unknown(&self, extension_type_id: u16) -> Option<&UnknownExtension> {
702 let extension_type: ExtensionType = extension_type_id.into();
703
704 match extension_type {
705 ExtensionType::Unknown(_) => self.find_by_type(extension_type).and_then(|e| match e {
706 Extension::Unknown(_, e) => Some(e),
707 _ => None,
708 }),
709 _ => None,
710 }
711 }
712}
713
714impl Extension {
715 pub fn as_application_id_extension(&self) -> Result<&ApplicationIdExtension, ExtensionError> {
719 match self {
720 Self::ApplicationId(e) => Ok(e),
721 _ => Err(ExtensionError::InvalidExtensionType(
722 "This is not an ApplicationIdExtension".into(),
723 )),
724 }
725 }
726 #[cfg(feature = "extensions-draft-08")]
727 pub fn as_app_data_dictionary_extension(
731 &self,
732 ) -> Result<&AppDataDictionaryExtension, ExtensionError> {
733 match self {
734 Self::AppDataDictionary(e) => Ok(e),
735 _ => Err(ExtensionError::InvalidExtensionType(
736 "This is not an AppDataDictionaryExtension".into(),
737 )),
738 }
739 }
740
741 pub fn as_ratchet_tree_extension(&self) -> Result<&RatchetTreeExtension, ExtensionError> {
745 match self {
746 Self::RatchetTree(rte) => Ok(rte),
747 _ => Err(ExtensionError::InvalidExtensionType(
748 "This is not a RatchetTreeExtension".into(),
749 )),
750 }
751 }
752
753 pub fn as_required_capabilities_extension(
757 &self,
758 ) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
759 match self {
760 Self::RequiredCapabilities(e) => Ok(e),
761 _ => Err(ExtensionError::InvalidExtensionType(
762 "This is not a RequiredCapabilitiesExtension".into(),
763 )),
764 }
765 }
766
767 pub fn as_external_pub_extension(&self) -> Result<&ExternalPubExtension, ExtensionError> {
771 match self {
772 Self::ExternalPub(e) => Ok(e),
773 _ => Err(ExtensionError::InvalidExtensionType(
774 "This is not an ExternalPubExtension".into(),
775 )),
776 }
777 }
778
779 pub fn as_external_senders_extension(
783 &self,
784 ) -> Result<&ExternalSendersExtension, ExtensionError> {
785 match self {
786 Self::ExternalSenders(e) => Ok(e),
787 _ => Err(ExtensionError::InvalidExtensionType(
788 "This is not an ExternalSendersExtension".into(),
789 )),
790 }
791 }
792
793 #[inline]
795 pub const fn extension_type(&self) -> ExtensionType {
796 match self {
797 Extension::ApplicationId(_) => ExtensionType::ApplicationId,
798 Extension::RatchetTree(_) => ExtensionType::RatchetTree,
799 Extension::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
800 Extension::ExternalPub(_) => ExtensionType::ExternalPub,
801 Extension::ExternalSenders(_) => ExtensionType::ExternalSenders,
802 #[cfg(feature = "extensions-draft-08")]
803 Extension::AppDataDictionary(_) => ExtensionType::AppDataDictionary,
804 Extension::LastResort(_) => ExtensionType::LastResort,
805 Extension::Unknown(kind, _) => ExtensionType::Unknown(*kind),
806 }
807 }
808}
809
810macro_rules! impl_from_extensions_validator {
811 ($validator:ty, $error:ty) => {
812 impl From<Extensions<$validator>> for Extensions<AnyObject> {
813 fn from(value: Extensions<$validator>) -> Self {
814 Extensions {
815 unique: value.unique,
816 _object: PhantomData,
817 }
818 }
819 }
820
821 impl TryFrom<Extensions<AnyObject>> for Extensions<$validator> {
822 type Error = $error;
823
824 fn try_from(value: Extensions<AnyObject>) -> Result<Self, $error> {
825 value
826 .unique
827 .iter()
828 .try_for_each(<$validator as ExtensionValidator>::validate_extension_type)?;
829
830 Ok(Extensions {
831 unique: value.unique,
832 _object: PhantomData,
833 })
834 }
835 }
836 };
837}
838
839impl_from_extensions_validator!(GroupContext, ExtensionTypeNotValidInGroupContextError);
840impl_from_extensions_validator!(LeafNode, ExtensionTypeNotValidInLeafNodeError);
841impl_from_extensions_validator!(KeyPackage, ExtensionTypeNotValidInKeyPackageError);
842
#[cfg(any(feature = "test-utils", test))]
impl Extensions<AnyObject> {
    /// Relabels this list as `Extensions<T>` without running `T`'s validator.
    ///
    /// Only compiled for tests and the `test-utils` feature. Callers are responsible
    /// for the contents actually being valid for `T`.
    pub(crate) fn coerce<T: ExtensionValidator>(self) -> Extensions<T> {
        let Self { unique, .. } = self;
        Extensions {
            unique,
            _object: PhantomData,
        }
    }
}
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use tls_codec::{Deserialize, Serialize, VLBytes};

    use crate::{ciphersuite::HpkePublicKey, extensions::*};

    /// Adding a second extension of an already-present type must be rejected.
    #[test]
    fn add() {
        let mut extensions: Extensions<AnyObject> = Extensions::default();
        extensions
            .add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default(),
            ))
            .unwrap();
        assert!(matches!(
            extensions.add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default()
            )),
            Err(InvalidExtensionError::Duplicate)
        ));
    }

    /// `add` and `try_from` must agree on which lists contain duplicates.
    #[test]
    fn add_try_from() {
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        let tests = [
            (vec![], true),
            (vec![ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_y.clone()], true),
            (vec![ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_x.clone(), ext_y.clone()], true),
            (vec![ext_y.clone(), ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_y.clone(), ext_x, ext_y], false),
        ];

        for (test, should_work) in tests.into_iter() {
            // Incremental `add`.
            let mut extensions: Extensions<AnyObject> = Extensions::default();
            let mut works = true;
            for ext in test.iter() {
                match extensions.add(ext.clone()) {
                    Ok(_) => {}
                    Err(InvalidExtensionError::Duplicate) => works = false,
                    _ => panic!("This should have never happened."),
                }
            }

            println!("{test:?}, {should_work:?}");
            assert_eq!(works, should_work);

            // Bulk `try_from` must reach the same verdict.
            assert_eq!(
                Extensions::<AnyObject>::try_from(test).is_ok(),
                should_work
            );
        }
    }

    /// Serializing and deserializing preserves the extension order for every permutation.
    #[test]
    fn ensure_ordering() {
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::ExternalPub(ExternalPubExtension::new(HpkePublicKey::new(vec![])));
        let ext_z = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        for candidate in [ext_x, ext_y, ext_z].into_iter().permutations(3) {
            let candidate: Extensions<AnyObject> = Extensions::try_from(candidate).unwrap();
            let bytes = candidate.tls_serialize_detached().unwrap();
            let got = Extensions::tls_deserialize(&mut bytes.as_slice()).unwrap();
            assert_eq!(candidate, got);
        }
    }

    /// Extension types without a dedicated variant decode to `Extension::Unknown`
    /// and round-trip byte-for-byte.
    #[test]
    fn that_unknown_extensions_are_de_serialized_correctly() {
        let extension_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF100, 0xFFFF];
        let extension_datas = [vec![], vec![0], vec![1, 2, 3]];

        for extension_type in extension_types.into_iter() {
            for extension_data in extension_datas.iter() {
                // Wire form: 2-byte type code followed by the length-prefixed payload.
                let test = {
                    let mut buf = extension_type.to_be_bytes().to_vec();
                    buf.append(
                        &mut VLBytes::new(extension_data.clone())
                            .tls_serialize_detached()
                            .unwrap(),
                    );
                    buf
                };

                let got = Extension::tls_deserialize_exact(&test).unwrap();

                match got {
                    Extension::Unknown(got_extension_type, ref got_extension_data) => {
                        assert_eq!(extension_type, got_extension_type);
                        assert_eq!(extension_data, &got_extension_data.0);
                    }
                    other => panic!("Expected `Extension::Unknown`, got {other:?}"),
                }

                let got_serialized = got.tls_serialize_detached().unwrap();
                assert_eq!(test, got_serialized);
            }
        }
    }
}