Skip to main content

openmls/extensions/
mod.rs

1//! # Extensions
2//!
3//! In MLS, extensions appear in the following places:
4//!
//! - In [`KeyPackages`](`crate::key_packages`), to describe client capabilities
//!   and aspects of their participation in the group.
//!
//! - In `LeafNode`s, to describe additional information about a member of the
//!   group.
7//!
8//! - In `GroupInfo`, to inform new members of the group's parameters and to
9//!   provide any additional information required to join the group.
10//!
11//! - In the `GroupContext` object, to ensure that all members of the group have
12//!   a consistent view of the parameters in use.
13//!
14//! Note that `GroupInfo` and `GroupContext` are not exposed via OpenMLS' public
15//! API.
16//!
17//! OpenMLS supports the following extensions:
18//!
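//! - [`ApplicationIdExtension`] (LeafNode extension)
//! - [`RatchetTreeExtension`] (GroupInfo extension)
//! - [`RequiredCapabilitiesExtension`] (GroupContext extension)
//! - [`ExternalPubExtension`] (GroupInfo extension)
//! - [`ExternalSendersExtension`] (GroupContext extension)
//! - [`LastResortExtension`] (KeyPackage extension)
//! - `AppDataDictionaryExtension` (only with the `extensions-draft-08` feature)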
23
24use std::{
25    convert::Infallible,
26    fmt::Debug,
27    io::{Read, Write},
28    marker::PhantomData,
29};
30
31use serde::{Deserialize, Serialize};
32
33// Private
34#[cfg(feature = "extensions-draft-08")]
35mod app_data_dict_extension;
36mod application_id_extension;
37mod codec;
38mod external_pub_extension;
39mod external_sender_extension;
40mod last_resort;
41mod ratchet_tree_extension;
42mod required_capabilities;
43use errors::*;
44
45// Public
46pub mod errors;
47
48// Public re-exports
49#[cfg(feature = "extensions-draft-08")]
50pub use app_data_dict_extension::{AppDataDictionary, AppDataDictionaryExtension};
51pub use application_id_extension::ApplicationIdExtension;
52pub use external_pub_extension::ExternalPubExtension;
53pub use external_sender_extension::{
54    ExternalSender, ExternalSendersExtension, SenderExtensionIndex,
55};
56pub use last_resort::LastResortExtension;
57pub use ratchet_tree_extension::RatchetTreeExtension;
58pub use required_capabilities::RequiredCapabilitiesExtension;
59
60use tls_codec::{
61    Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
62    Size, TlsDeserialize, TlsSerialize, TlsSize,
63};
64
65use crate::{
66    group::GroupContext, key_packages::KeyPackage, messages::group_info::GroupInfo,
67    treesync::LeafNode,
68};
69
70#[cfg(test)]
71mod tests;
72
/// MLS Extension Types
///
/// Registered values, following the "MLS Extension Types" registry of
/// [RFC 9420, Section 17.3](https://www.rfc-editor.org/rfc/rfc9420.html#section-17.3):
///
/// | Value            | Name                     | Message(s) | Recommended | Reference |
/// |:-----------------|:-------------------------|:-----------|:------------|:----------|
/// | 0x0000           | RESERVED                 | N/A        | N/A         | RFC 9420  |
/// | 0x0001           | application_id           | LN         | Y           | RFC 9420  |
/// | 0x0002           | ratchet_tree             | GI         | Y           | RFC 9420  |
/// | 0x0003           | required_capabilities    | GC         | Y           | RFC 9420  |
/// | 0x0004           | external_pub             | GI         | Y           | RFC 9420  |
/// | 0x0005           | external_senders         | GC         | Y           | RFC 9420  |
/// | 0x0A0A .. 0xEAEA | GREASE (0x?A?A values)   | KP, GI, LN | Y           | RFC 9420  |
/// | 0xF000 - 0xFFFF  | Reserved for Private Use | N/A        | N/A         | RFC 9420  |
///
/// Beyond that table, OpenMLS maps `0x000A` to [`ExtensionType::LastResort`] and,
/// with the `extensions-draft-08` feature enabled, `0x0006` to `AppDataDictionary`.
/// GREASE values become [`ExtensionType::Grease`]; every other value becomes
/// [`ExtensionType::Unknown`].
///
/// Note: OpenMLS does not provide a `Reserved` variant in [ExtensionType]; the
/// reserved value `0x0000` is treated like any other unknown value.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)]
pub enum ExtensionType {
    /// The application id extension allows applications to add an explicit,
    /// application-defined identifier to a leaf node.
    ApplicationId,

    /// The ratchet tree extension provides the whole public state of the
    /// ratchet tree.
    RatchetTree,

    /// The required capabilities extension defines the configuration of a group
    /// that imposes certain requirements on clients in the group.
    RequiredCapabilities,

    /// To join a group via an External Commit, a new member needs a GroupInfo
    /// with an ExternalPub extension present in its extensions field.
    ExternalPub,

    /// Group context extension that contains the credentials and signature keys
    /// of senders that are permitted to send external proposals to the group.
    ExternalSenders,

    /// KeyPackage extension that marks a KeyPackage for use in a last resort
    /// scenario.
    LastResort,

    #[cfg(feature = "extensions-draft-08")]
    /// AppDataDictionary extension (only available with the
    /// `extensions-draft-08` feature).
    AppDataDictionary,

    /// A GREASE extension type for ensuring extensibility.
    Grease(u16),

    /// A currently unknown extension type, carrying its raw value.
    Unknown(u16),
}
124
125impl ExtensionType {
126    /// Returns true for all extension types that are considered "default" by the spec.
127    pub(crate) fn is_default(self) -> bool {
128        match self {
129            ExtensionType::ApplicationId
130            | ExtensionType::RatchetTree
131            | ExtensionType::RequiredCapabilities
132            | ExtensionType::ExternalPub
133            | ExtensionType::ExternalSenders => true,
134            ExtensionType::LastResort | ExtensionType::Grease(_) | ExtensionType::Unknown(_) => {
135                false
136            }
137            #[cfg(feature = "extensions-draft-08")]
138            ExtensionType::AppDataDictionary => false,
139        }
140    }
141
142    /// Returns whether an extension type is valid when used in leaf nodes.
143    /// Returns None if validity can not be determined.
144    /// This is the case for unknown extensions.
145    //  https://validation.openmls.tech/#valn1601
146    pub(crate) fn is_valid_in_leaf_node(self) -> bool {
147        match self {
148            ExtensionType::Grease(_)
149            | ExtensionType::LastResort
150            | ExtensionType::RatchetTree
151            | ExtensionType::RequiredCapabilities
152            | ExtensionType::ExternalPub
153            | ExtensionType::ExternalSenders => false,
154            ExtensionType::Unknown(_) | ExtensionType::ApplicationId => true,
155            #[cfg(feature = "extensions-draft-08")]
156            ExtensionType::AppDataDictionary => true,
157        }
158    }
159    pub(crate) fn is_valid_in_group_info(self) -> Option<bool> {
160        match self {
161            ExtensionType::Grease(_)
162            | ExtensionType::LastResort
163            | ExtensionType::RequiredCapabilities
164            | ExtensionType::ExternalSenders
165            | ExtensionType::ApplicationId => Some(false),
166            ExtensionType::RatchetTree | ExtensionType::ExternalPub => Some(true),
167            ExtensionType::Unknown(_) => None,
168            #[cfg(feature = "extensions-draft-08")]
169            ExtensionType::AppDataDictionary => Some(true),
170        }
171    }
172
173    pub(crate) fn is_valid_in_group_context(self) -> bool {
174        match self {
175            ExtensionType::RequiredCapabilities
176            | ExtensionType::ExternalSenders
177            | ExtensionType::Unknown(_) => true,
178            #[cfg(feature = "extensions-draft-08")]
179            ExtensionType::AppDataDictionary => true,
180            _ => false,
181        }
182    }
183
184    /// Returns true if this is a GREASE extension type.
185    ///
186    /// GREASE values are used to ensure implementations properly handle unknown
187    /// extension types. See [RFC 9420 Section 13.5](https://www.rfc-editor.org/rfc/rfc9420.html#section-13.5).
188    pub fn is_grease(&self) -> bool {
189        matches!(self, ExtensionType::Grease(_))
190    }
191}
192
193impl Size for ExtensionType {
194    fn tls_serialized_len(&self) -> usize {
195        2
196    }
197}
198
199impl TlsDeserializeTrait for ExtensionType {
200    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
201    where
202        Self: Sized,
203    {
204        let mut extension_type = [0u8; 2];
205        bytes.read_exact(&mut extension_type)?;
206
207        Ok(ExtensionType::from(u16::from_be_bytes(extension_type)))
208    }
209}
210
211impl DeserializeBytes for ExtensionType {
212    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
213    where
214        Self: Sized,
215    {
216        let mut bytes_ref = bytes;
217        let extension_type = ExtensionType::tls_deserialize(&mut bytes_ref)?;
218        let remainder = &bytes[extension_type.tls_serialized_len()..];
219        Ok((extension_type, remainder))
220    }
221}
222
223impl TlsSerializeTrait for ExtensionType {
224    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
225        writer.write_all(&u16::from(*self).to_be_bytes())?;
226
227        Ok(2)
228    }
229}
230
231impl From<u16> for ExtensionType {
232    fn from(a: u16) -> Self {
233        match a {
234            1 => ExtensionType::ApplicationId,
235            2 => ExtensionType::RatchetTree,
236            3 => ExtensionType::RequiredCapabilities,
237            4 => ExtensionType::ExternalPub,
238            5 => ExtensionType::ExternalSenders,
239            #[cfg(feature = "extensions-draft-08")]
240            6 => ExtensionType::AppDataDictionary,
241            10 => ExtensionType::LastResort,
242            unknown if crate::grease::is_grease_value(unknown) => ExtensionType::Grease(unknown),
243            unknown => ExtensionType::Unknown(unknown),
244        }
245    }
246}
247
248impl From<ExtensionType> for u16 {
249    fn from(value: ExtensionType) -> Self {
250        match value {
251            ExtensionType::ApplicationId => 1,
252            ExtensionType::RatchetTree => 2,
253            ExtensionType::RequiredCapabilities => 3,
254            ExtensionType::ExternalPub => 4,
255            ExtensionType::ExternalSenders => 5,
256            #[cfg(feature = "extensions-draft-08")]
257            ExtensionType::AppDataDictionary => 6,
258            ExtensionType::LastResort => 10,
259            ExtensionType::Grease(value) => value,
260            ExtensionType::Unknown(unknown) => unknown,
261        }
262    }
263}
264
/// # Extension
///
/// A single MLS extension. There is one variant per extension type that OpenMLS
/// parses, plus [`Extension::Unknown`] for every other extension type.
///
/// See the individual extensions for more details on each extension.
///
/// ```c
/// // RFC 9420
/// struct {
///     ExtensionType extension_type;
///     opaque extension_data<V>;
/// } Extension;
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Extension {
    /// An [`ApplicationIdExtension`]
    ApplicationId(ApplicationIdExtension),

    /// A [`RatchetTreeExtension`]
    RatchetTree(RatchetTreeExtension),

    /// A [`RequiredCapabilitiesExtension`]
    RequiredCapabilities(RequiredCapabilitiesExtension),

    /// An [`ExternalPubExtension`]
    ExternalPub(ExternalPubExtension),

    /// An [`ExternalSendersExtension`]
    ExternalSenders(ExternalSendersExtension),

    /// An [`AppDataDictionaryExtension`]
    #[cfg(feature = "extensions-draft-08")]
    AppDataDictionary(AppDataDictionaryExtension),

    /// A [`LastResortExtension`]
    LastResort(LastResortExtension),

    /// A currently unknown extension: its raw extension type value and its
    /// unparsed data. Extensions whose type is a GREASE value are also
    /// represented by this variant.
    Unknown(u16, UnknownExtension),
}
306
/// An unknown/unparsed extension, represented by its raw extension data bytes.
#[derive(
    PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSize, TlsSerialize, TlsDeserialize,
)]
pub struct UnknownExtension(pub Vec<u8>);
312
/// A list of [`Extension`]s belonging to an object of type `T`, holding at most
/// one extension per [`ExtensionType`].
///
/// `T` selects the [`ExtensionValidator`] that checks extensions when the list is
/// built or decoded; [`AnyObject`] accepts every extension type. Extensions are
/// kept, and encoded, in the order in which they were added.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Extensions<T> {
    unique: Vec<Extension>,
    // Marker only: ties the list to its validator type without storing a `T`.
    #[serde(skip)]
    _object: core::marker::PhantomData<T>,
}
320
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, TlsSize, TlsSerialize, TlsDeserialize)]
/// Marker type for an [`Extensions`] list that is not tied to a particular
/// object. Its validator accepts every extension.
pub struct AnyObject;
324
325impl<T> Default for Extensions<T> {
326    fn default() -> Self {
327        Self {
328            unique: vec![],
329            _object: PhantomData,
330        }
331    }
332}
333
334impl<T> Size for Extensions<T> {
335    fn tls_serialized_len(&self) -> usize {
336        Vec::tls_serialized_len(&self.unique)
337    }
338}
339
340impl<T> TlsSerializeTrait for Extensions<T> {
341    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
342        self.unique.tls_serialize(writer)
343    }
344}
345
346impl<T: ExtensionValidator> TlsDeserializeTrait for Extensions<T>
347where
348    InvalidExtensionError: From<T::Error>,
349{
350    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
351    where
352        Self: Sized,
353    {
354        let candidate: Vec<Extension> = Vec::tls_deserialize(bytes)?;
355        Extensions::<T>::try_from(candidate)
356            .map_err(|_| Error::DecodingError("Found duplicate extensions".into()))
357    }
358}
359
360impl<T: ExtensionValidator> DeserializeBytes for Extensions<T>
361where
362    InvalidExtensionError: From<T::Error>,
363{
364    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
365    where
366        Self: Sized,
367    {
368        let mut bytes_ref = bytes;
369        let extensions = Extensions::<T>::tls_deserialize(&mut bytes_ref)?;
370        let remainder = &bytes[extensions.tls_serialized_len()..];
371        Ok((extensions, remainder))
372    }
373}
374
375impl<T: ExtensionValidator> Extensions<T> {
376    /// Create an empty extension list.
377    pub fn empty() -> Self {
378        Self {
379            unique: vec![],
380            _object: PhantomData,
381        }
382    }
383
384    /// Returns an iterator over the extension list.
385    pub fn iter(&self) -> impl Iterator<Item = &Extension> {
386        self.unique.iter()
387    }
388
389    /// Remove an extension from the extension list.
390    ///
391    /// Returns the removed extension or `None` when there is no extension with
392    /// the given extension type.
393    pub fn remove(&mut self, extension_type: ExtensionType) -> Option<Extension> {
394        if let Some(pos) = self
395            .unique
396            .iter()
397            .position(|ext| ext.extension_type() == extension_type)
398        {
399            Some(self.unique.remove(pos))
400        } else {
401            None
402        }
403    }
404
405    /// Returns `true` iff the extension list contains an extension with the
406    /// given extension type.
407    pub fn contains(&self, extension_type: ExtensionType) -> bool {
408        self.unique
409            .iter()
410            .any(|ext| ext.extension_type() == extension_type)
411    }
412}
413
414impl<T> Extensions<T>
415where
416    T: ExtensionValidator,
417    InvalidExtensionError: From<T::Error>,
418{
419    /// Create an extension list with a single extension.
420    pub fn single(extension: Extension) -> Result<Self, InvalidExtensionError> {
421        T::validate_extension_type(&extension)?;
422        Ok(Self {
423            unique: vec![extension],
424            _object: PhantomData,
425        })
426    }
427
428    /// Create an extension list with multiple extensions.
429    ///
430    /// This function will fail when the list of extensions contains duplicate
431    /// extension types.
432    pub fn from_vec(extensions: Vec<Extension>) -> Result<Self, InvalidExtensionError> {
433        extensions.try_into()
434    }
435
436    /// Validate if the extensions are valid for this context
437    pub fn validate<'a>(
438        extensions: impl Iterator<Item = &'a Extension>,
439    ) -> Result<(), InvalidExtensionError> {
440        for ext in extensions {
441            T::validate_extension_type(ext)?;
442        }
443        Ok(())
444    }
445
446    /// Add an extension to the extension list.
447    ///
448    /// Returns an error when there already is an extension with the same
449    /// extension type.
450    pub fn add(&mut self, extension: Extension) -> Result<(), InvalidExtensionError> {
451        T::validate_extension_type(&extension)?;
452        if self.contains(extension.extension_type()) {
453            return Err(InvalidExtensionError::Duplicate);
454        }
455
456        self.unique.push(extension);
457
458        Ok(())
459    }
460
461    /// Add an extension to the extension list (or replace an existing one.)
462    ///
463    /// Returns the replaced extension (if any).
464    pub fn add_or_replace(
465        &mut self,
466        extension: Extension,
467    ) -> Result<Option<Extension>, InvalidExtensionError> {
468        T::validate_extension_type(&extension)?;
469        let replaced = self.remove(extension.extension_type());
470        self.unique.push(extension);
471        Ok(replaced)
472    }
473}
474
/// Can be implemented by a type to validate extensions.
///
/// The implementing type names the context an [`Extensions`] list belongs to
/// (e.g. `Extensions<GroupContext>`) and decides which extensions that list may
/// contain.
pub trait ExtensionValidator {
    /// The error returned when an extension is rejected.
    type Error;

    /// Returns `Ok(())` if `ext` may appear in this context, and an error
    /// otherwise.
    fn validate_extension_type(ext: &Extension) -> Result<(), Self::Error>;
}
483
484impl ExtensionValidator for AnyObject {
485    type Error = Infallible;
486
487    fn validate_extension_type(_ext: &Extension) -> Result<(), Infallible> {
488        Ok(())
489    }
490}
491
492impl<T: ExtensionValidator> TryFrom<Vec<Extension>> for Extensions<T>
493where
494    InvalidExtensionError: From<T::Error>,
495{
496    type Error = InvalidExtensionError;
497
498    fn try_from(candidate: Vec<Extension>) -> Result<Self, Self::Error> {
499        let mut unique: Vec<Extension> = Vec::new();
500        for extension in candidate.into_iter() {
501            T::validate_extension_type(&extension)?;
502
503            if unique
504                .iter()
505                .any(|ext| ext.extension_type() == extension.extension_type())
506            {
507                return Err(InvalidExtensionError::Duplicate);
508            } else {
509                unique.push(extension);
510            }
511        }
512
513        Ok(Self {
514            unique,
515            _object: PhantomData,
516        })
517    }
518}
519
520// https://validation.openmls.tech/#valn1602
521impl ExtensionValidator for GroupInfo {
522    type Error = ExtensionTypeNotValidInGroupInfoError;
523
524    fn validate_extension_type(
525        ext: &Extension,
526    ) -> Result<(), ExtensionTypeNotValidInGroupInfoError> {
527        if ext.extension_type().is_valid_in_group_info() == Some(true)
528            || ext.extension_type().is_valid_in_group_info().is_none()
529        {
530            Ok(())
531        } else {
532            Err(ExtensionTypeNotValidInGroupInfoError(ext.extension_type()))
533        }
534    }
535}
536
537// https://validation.openmls.tech/#valn1603
538impl ExtensionValidator for GroupContext {
539    type Error = ExtensionTypeNotValidInGroupContextError;
540
541    fn validate_extension_type(
542        ext: &Extension,
543    ) -> Result<(), ExtensionTypeNotValidInGroupContextError> {
544        if ext.extension_type().is_valid_in_group_context() {
545            Ok(())
546        } else {
547            Err(ExtensionTypeNotValidInGroupContextError(
548                ext.extension_type(),
549            ))
550        }
551    }
552}
553
554// https://validation.openmls.tech/#valn1604
555impl ExtensionValidator for KeyPackage {
556    type Error = ExtensionTypeNotValidInKeyPackageError;
557
558    fn validate_extension_type(
559        ext: &Extension,
560    ) -> Result<(), ExtensionTypeNotValidInKeyPackageError> {
561        if ext.extension_type() == ExtensionType::LastResort
562            || matches!(ext.extension_type(), ExtensionType::Unknown(_))
563        {
564            Ok(())
565        } else {
566            Err(ExtensionTypeNotValidInKeyPackageError(ext.extension_type()))
567        }
568    }
569}
570
571// https://validation.openmls.tech/#valn1601
572impl ExtensionValidator for LeafNode {
573    type Error = ExtensionTypeNotValidInLeafNodeError;
574
575    fn validate_extension_type(
576        ext: &Extension,
577    ) -> Result<(), ExtensionTypeNotValidInLeafNodeError> {
578        if ext.extension_type().is_valid_in_leaf_node() {
579            Ok(())
580        } else {
581            Err(ExtensionTypeNotValidInLeafNodeError(ext.extension_type()))
582        }
583    }
584}
585
586impl<T> Extensions<T> {
587    fn find_by_type(&self, extension_type: ExtensionType) -> Option<&Extension> {
588        self.unique
589            .iter()
590            .find(|ext| ext.extension_type() == extension_type)
591    }
592
593    /// Get a reference to the [`ApplicationIdExtension`] if there is any.
594    pub fn application_id(&self) -> Option<&ApplicationIdExtension> {
595        self.find_by_type(ExtensionType::ApplicationId)
596            .and_then(|e| match e {
597                Extension::ApplicationId(e) => Some(e),
598                _ => None,
599            })
600    }
601
602    /// Get a reference to the [`RatchetTreeExtension`] if there is any.
603    pub fn ratchet_tree(&self) -> Option<&RatchetTreeExtension> {
604        self.find_by_type(ExtensionType::RatchetTree)
605            .and_then(|e| match e {
606                Extension::RatchetTree(e) => Some(e),
607                _ => None,
608            })
609    }
610
611    /// Get a reference to the [`RequiredCapabilitiesExtension`] if there is
612    /// any.
613    pub fn required_capabilities(&self) -> Option<&RequiredCapabilitiesExtension> {
614        self.find_by_type(ExtensionType::RequiredCapabilities)
615            .and_then(|e| match e {
616                Extension::RequiredCapabilities(e) => Some(e),
617                _ => None,
618            })
619    }
620
621    /// Get a reference to the [`ExternalPubExtension`] if there is any.
622    pub fn external_pub(&self) -> Option<&ExternalPubExtension> {
623        self.find_by_type(ExtensionType::ExternalPub)
624            .and_then(|e| match e {
625                Extension::ExternalPub(e) => Some(e),
626                _ => None,
627            })
628    }
629
630    /// Get a reference to the [`ExternalSendersExtension`] if there is any.
631    pub fn external_senders(&self) -> Option<&ExternalSendersExtension> {
632        self.find_by_type(ExtensionType::ExternalSenders)
633            .and_then(|e| match e {
634                Extension::ExternalSenders(e) => Some(e),
635                _ => None,
636            })
637    }
638
639    #[cfg(feature = "extensions-draft-08")]
640    /// Get a reference to the [`AppDataDictionaryExtension`] if there is any.
641    pub fn app_data_dictionary(&self) -> Option<&AppDataDictionaryExtension> {
642        self.find_by_type(ExtensionType::AppDataDictionary)
643            .and_then(|e| match e {
644                Extension::AppDataDictionary(e) => Some(e),
645                _ => None,
646            })
647    }
648
649    /// Get a reference to the [`UnknownExtension`] with the given type id, if there is any.
650    pub fn unknown(&self, extension_type_id: u16) -> Option<&UnknownExtension> {
651        let extension_type: ExtensionType = extension_type_id.into();
652
653        match extension_type {
654            ExtensionType::Unknown(_) => self.find_by_type(extension_type).and_then(|e| match e {
655                Extension::Unknown(_, e) => Some(e),
656                _ => None,
657            }),
658            _ => None,
659        }
660    }
661}
662
663impl Extension {
664    /// Get a reference to this extension as [`ApplicationIdExtension`].
665    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on an
666    /// [`Extension`] that's not an [`ApplicationIdExtension`].
667    pub fn as_application_id_extension(&self) -> Result<&ApplicationIdExtension, ExtensionError> {
668        match self {
669            Self::ApplicationId(e) => Ok(e),
670            _ => Err(ExtensionError::InvalidExtensionType(
671                "This is not an ApplicationIdExtension".into(),
672            )),
673        }
674    }
675    #[cfg(feature = "extensions-draft-08")]
676    /// Get a reference to this extension as [`AppDataDictionaryExtension`].
677    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on an
678    /// [`Extension`] that's not an [`AppDataDictionaryExtension`].
679    pub fn as_app_data_dictionary_extension(
680        &self,
681    ) -> Result<&AppDataDictionaryExtension, ExtensionError> {
682        match self {
683            Self::AppDataDictionary(e) => Ok(e),
684            _ => Err(ExtensionError::InvalidExtensionType(
685                "This is not an AppDataDictionaryExtension".into(),
686            )),
687        }
688    }
689
690    /// Get a reference to this extension as [`RatchetTreeExtension`].
691    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on
692    /// an [`Extension`] that's not a [`RatchetTreeExtension`].
693    pub fn as_ratchet_tree_extension(&self) -> Result<&RatchetTreeExtension, ExtensionError> {
694        match self {
695            Self::RatchetTree(rte) => Ok(rte),
696            _ => Err(ExtensionError::InvalidExtensionType(
697                "This is not a RatchetTreeExtension".into(),
698            )),
699        }
700    }
701
702    /// Get a reference to this extension as [`RequiredCapabilitiesExtension`].
703    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
704    /// an [`Extension`] that's not a [`RequiredCapabilitiesExtension`].
705    pub fn as_required_capabilities_extension(
706        &self,
707    ) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
708        match self {
709            Self::RequiredCapabilities(e) => Ok(e),
710            _ => Err(ExtensionError::InvalidExtensionType(
711                "This is not a RequiredCapabilitiesExtension".into(),
712            )),
713        }
714    }
715
716    /// Get a reference to this extension as [`ExternalPubExtension`].
717    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
718    /// an [`Extension`] that's not a [`ExternalPubExtension`].
719    pub fn as_external_pub_extension(&self) -> Result<&ExternalPubExtension, ExtensionError> {
720        match self {
721            Self::ExternalPub(e) => Ok(e),
722            _ => Err(ExtensionError::InvalidExtensionType(
723                "This is not an ExternalPubExtension".into(),
724            )),
725        }
726    }
727
728    /// Get a reference to this extension as [`ExternalSendersExtension`].
729    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
730    /// an [`Extension`] that's not a [`ExternalSendersExtension`].
731    pub fn as_external_senders_extension(
732        &self,
733    ) -> Result<&ExternalSendersExtension, ExtensionError> {
734        match self {
735            Self::ExternalSenders(e) => Ok(e),
736            _ => Err(ExtensionError::InvalidExtensionType(
737                "This is not an ExternalSendersExtension".into(),
738            )),
739        }
740    }
741
742    /// Returns the [`ExtensionType`]
743    #[inline]
744    pub const fn extension_type(&self) -> ExtensionType {
745        match self {
746            Extension::ApplicationId(_) => ExtensionType::ApplicationId,
747            Extension::RatchetTree(_) => ExtensionType::RatchetTree,
748            Extension::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
749            Extension::ExternalPub(_) => ExtensionType::ExternalPub,
750            Extension::ExternalSenders(_) => ExtensionType::ExternalSenders,
751            #[cfg(feature = "extensions-draft-08")]
752            Extension::AppDataDictionary(_) => ExtensionType::AppDataDictionary,
753            Extension::LastResort(_) => ExtensionType::LastResort,
754            Extension::Unknown(kind, _) => ExtensionType::Unknown(*kind),
755        }
756    }
757}
758
759macro_rules! impl_from_extensions_validator {
760    ($validator:ty, $error:ty) => {
761        impl From<Extensions<$validator>> for Extensions<AnyObject> {
762            fn from(value: Extensions<$validator>) -> Self {
763                Extensions {
764                    unique: value.unique,
765                    _object: PhantomData,
766                }
767            }
768        }
769
770        impl TryFrom<Extensions<AnyObject>> for Extensions<$validator> {
771            type Error = $error;
772
773            fn try_from(value: Extensions<AnyObject>) -> Result<Self, $error> {
774                value
775                    .unique
776                    .iter()
777                    .try_for_each(<$validator as ExtensionValidator>::validate_extension_type)?;
778
779                Ok(Extensions {
780                    unique: value.unique,
781                    _object: PhantomData,
782                })
783            }
784        }
785    };
786}
787
788impl_from_extensions_validator!(GroupContext, ExtensionTypeNotValidInGroupContextError);
789impl_from_extensions_validator!(LeafNode, ExtensionTypeNotValidInLeafNodeError);
790impl_from_extensions_validator!(KeyPackage, ExtensionTypeNotValidInKeyPackageError);
791
#[cfg(any(feature = "test-utils", test))]
impl Extensions<AnyObject> {
    /// Reinterprets this list as an `Extensions<T>` **without** running `T`'s
    /// validator.
    ///
    /// This is not `unsafe` in the Rust sense. However, the result may contain
    /// extension types that `T` would reject, which is why this is only compiled
    /// for tests and the `test-utils` feature.
    pub(crate) fn coerce<T: ExtensionValidator>(self) -> Extensions<T> {
        let Self { unique, .. } = self;
        Extensions {
            unique,
            _object: PhantomData,
        }
    }
}
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use tls_codec::{Deserialize, Serialize, VLBytes};

    use crate::{ciphersuite::HpkePublicKey, extensions::*};

    /// Adding a second extension of an already present type must fail.
    #[test]
    fn add() {
        let required_capabilities =
            || Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        let mut extensions: Extensions<AnyObject> = Extensions::default();
        extensions.add(required_capabilities()).unwrap();
        assert!(extensions.add(required_capabilities()).is_err());
    }

    /// `add` and `try_from` must agree on which lists contain duplicates. The
    /// extension content is irrelevant here; only the types matter.
    #[test]
    fn add_try_from() {
        let x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let y = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        let cases = [
            (vec![], true),
            (vec![x.clone()], true),
            (vec![x.clone(), x.clone()], false),
            (vec![x.clone(), x.clone(), x.clone()], false),
            (vec![y.clone()], true),
            (vec![y.clone(), y.clone()], false),
            (vec![y.clone(), y.clone(), y.clone()], false),
            (vec![x.clone(), y.clone()], true),
            (vec![y.clone(), x.clone()], true),
            (vec![x.clone(), x.clone(), y.clone()], false),
            (vec![y.clone(), y.clone(), x.clone()], false),
            (vec![x.clone(), y.clone(), y.clone()], false),
            (vec![y.clone(), x.clone(), x.clone()], false),
            (vec![x.clone(), y.clone(), x.clone()], false),
            (vec![y.clone(), x, y], false),
        ];

        for (input, expected_ok) in cases {
            // `add`: feed the extensions one by one and record whether any of
            // them was rejected as a duplicate.
            let mut extensions: Extensions<AnyObject> = Extensions::default();
            let mut accepted_all = true;
            for ext in &input {
                match extensions.add(ext.clone()) {
                    Ok(()) => {}
                    Err(InvalidExtensionError::Duplicate) => accepted_all = false,
                    Err(_) => panic!("This should have never happened."),
                }
            }

            println!("{:?}, {:?}", input, expected_ok);
            assert_eq!(accepted_all, expected_ok);

            // `try_from`: the whole list is accepted or rejected at once.
            assert_eq!(
                Extensions::<AnyObject>::try_from(input).is_ok(),
                expected_ok
            );
        }
    }

    /// Every permutation of three distinct extensions must keep its order
    /// through a serialize/deserialize round trip.
    #[test]
    fn ensure_ordering() {
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::ExternalPub(ExternalPubExtension::new(HpkePublicKey::new(vec![])));
        let ext_z = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        for permutation in [ext_x, ext_y, ext_z].into_iter().permutations(3) {
            let original: Extensions<AnyObject> = Extensions::try_from(permutation).unwrap();
            let encoded = original.tls_serialize_detached().unwrap();
            let decoded = Extensions::tls_deserialize(&mut encoded.as_slice()).unwrap();
            assert_eq!(original, decoded);
        }
    }

    /// Unknown extension types, including GREASE and private-use values, must
    /// decode to `Extension::Unknown` and re-encode to the same bytes.
    #[test]
    fn that_unknown_extensions_are_de_serialized_correctly() {
        let extension_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF100, 0xFFFF];
        let extension_datas = [vec![], vec![0], vec![1, 2, 3]];

        for extension_type in extension_types {
            for extension_data in &extension_datas {
                // Hand-encode: big-endian type followed by the length-prefixed data.
                let mut encoded = extension_type.to_be_bytes().to_vec();
                encoded.extend(
                    VLBytes::new(extension_data.clone())
                        .tls_serialize_detached()
                        .unwrap(),
                );

                // Deserialization.
                let decoded = Extension::tls_deserialize_exact(&encoded).unwrap();
                match &decoded {
                    Extension::Unknown(decoded_type, decoded_data) => {
                        assert_eq!(extension_type, *decoded_type);
                        assert_eq!(extension_data, &decoded_data.0);
                    }
                    other => panic!("Expected `Extension::Unknown`, got {other:?}"),
                }

                // Serialization.
                let reencoded = decoded.tls_serialize_detached().unwrap();
                assert_eq!(encoded, reencoded);
            }
        }
    }
}