Skip to main content

openmls/extensions/
mod.rs

1//! # Extensions
2//!
3//! In MLS, extensions appear in the following places:
4//!
5//! - In [`KeyPackages`](`crate::key_packages`), to describe client capabilities
6//!   and aspects of their participation in the group.
7//!
8//! - In `GroupInfo`, to inform new members of the group's parameters and to
9//!   provide any additional information required to join the group.
10//!
11//! - In the `GroupContext` object, to ensure that all members of the group have
12//!   a consistent view of the parameters in use.
13//!
14//! Note that `GroupInfo` and `GroupContext` are not exposed via OpenMLS' public
15//! API.
16//!
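//! OpenMLS supports the following extensions:
//!
//! - [`ApplicationIdExtension`] (LeafNode extension)
//! - [`RatchetTreeExtension`] (GroupInfo extension)
//! - [`RequiredCapabilitiesExtension`] (GroupContext extension)
//! - [`ExternalPubExtension`] (GroupInfo extension)
//! - [`ExternalSendersExtension`] (GroupContext extension)
//! - [`LastResortExtension`] (KeyPackage extension)
//! - `AppDataDictionaryExtension` (only with the `extensions-draft-08` feature)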
23
24use std::{
25    convert::Infallible,
26    fmt::Debug,
27    io::{Read, Write},
28    marker::PhantomData,
29};
30
31use serde::{Deserialize, Serialize};
32
33// Private
34#[cfg(feature = "extensions-draft-08")]
35mod app_data_dict_extension;
36mod application_id_extension;
37mod codec;
38mod external_pub_extension;
39mod external_sender_extension;
40mod last_resort;
41mod ratchet_tree_extension;
42mod required_capabilities;
43use errors::*;
44
45// Public
46pub mod errors;
47
48// Public re-exports
49#[cfg(feature = "extensions-draft-08")]
50pub use app_data_dict_extension::{AppDataDictionary, AppDataDictionaryExtension};
51pub use application_id_extension::ApplicationIdExtension;
52pub use external_pub_extension::ExternalPubExtension;
53pub use external_sender_extension::{
54    ExternalSender, ExternalSendersExtension, SenderExtensionIndex,
55};
56pub use last_resort::LastResortExtension;
57pub use ratchet_tree_extension::RatchetTreeExtension;
58pub use required_capabilities::RequiredCapabilitiesExtension;
59
60use tls_codec::{
61    Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
62    Size, TlsDeserialize, TlsSerialize, TlsSize,
63};
64
65use crate::{
66    group::GroupContext, key_packages::KeyPackage, messages::group_info::GroupInfo,
67    treesync::LeafNode,
68};
69
70#[cfg(test)]
71mod tests;
72
/// MLS Extension Types
///
/// Copied from draft-ietf-mls-protocol-16:
///
/// | Value            | Name                     | Message(s) | Recommended | Reference |
/// |:-----------------|:-------------------------|:-----------|:------------|:----------|
/// | 0x0000           | RESERVED                 | N/A        | N/A         | RFC XXXX  |
/// | 0x0001           | application_id           | LN         | Y           | RFC XXXX  |
/// | 0x0002           | ratchet_tree             | GI         | Y           | RFC XXXX  |
/// | 0x0003           | required_capabilities    | GC         | Y           | RFC XXXX  |
/// | 0x0004           | external_pub             | GI         | Y           | RFC XXXX  |
/// | 0x0005           | external_senders         | GC         | Y           | RFC XXXX  |
/// | 0xff00  - 0xffff | Reserved for Private Use | N/A        | N/A         | RFC XXXX  |
///
/// Besides the types in the table, this enum has a variant for the
/// `last_resort` extension (wire value `0x000A`). Behind the
/// `extensions-draft-08` feature, it also has a variant for the
/// `app_data_dictionary` extension (wire value `0x0006`).
///
/// When decoding a `u16`:
/// - values classified as GREASE become [`ExtensionType::Grease`];
/// - any other unrecognized value becomes [`ExtensionType::Unknown`].
///
/// NOTE(review): this table predates the published RFC 9420 (note the
/// "RFC XXXX" placeholders); confirm the ranges against the final RFC.
///
/// Note: OpenMLS does not provide a `Reserved` variant in [ExtensionType].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)]
pub enum ExtensionType {
    /// The application id extension allows applications to add an explicit,
    /// application-defined identifier to a KeyPackage.
    ApplicationId,

    /// The ratchet tree extensions provides the whole public state of the
    /// ratchet tree.
    RatchetTree,

    /// The required capabilities extension defines the configuration of a group
    /// that imposes certain requirements on clients in the group.
    RequiredCapabilities,

    /// To join a group via an External Commit, a new member needs a GroupInfo
    /// with an ExternalPub extension present in its extensions field.
    ExternalPub,

    /// Group context extension that contains the credentials and signature keys
    /// of senders that are permitted to send external proposals to the group.
    ExternalSenders,

    /// KeyPackage extension that marks a KeyPackage for use in a last resort
    /// scenario.
    LastResort,

    #[cfg(feature = "extensions-draft-08")]
    /// AppDataDictionary extension (wire value `0x0006`, only available with
    /// the `extensions-draft-08` feature).
    AppDataDictionary,

    /// A GREASE extension type for ensuring extensibility. Carries the raw
    /// wire value.
    Grease(u16),

    /// An extension type not recognized by this implementation. Carries the
    /// raw wire value.
    Unknown(u16),
}
124
125impl ExtensionType {
126    /// Returns true for all extension types that are considered "default" by the spec.
127    pub(crate) fn is_default(self) -> bool {
128        match self {
129            ExtensionType::ApplicationId
130            | ExtensionType::RatchetTree
131            | ExtensionType::RequiredCapabilities
132            | ExtensionType::ExternalPub
133            | ExtensionType::ExternalSenders => true,
134            ExtensionType::LastResort | ExtensionType::Grease(_) | ExtensionType::Unknown(_) => {
135                false
136            }
137            #[cfg(feature = "extensions-draft-08")]
138            ExtensionType::AppDataDictionary => false,
139        }
140    }
141
142    /// Returns whether an extension type is valid when used in leaf nodes.
143    /// Returns None if validity can not be determined.
144    /// This is the case for unknown extensions.
145    //  https://validation.openmls.tech/#valn1601
146    pub(crate) fn is_valid_in_leaf_node(self) -> bool {
147        match self {
148            ExtensionType::Grease(_)
149            | ExtensionType::LastResort
150            | ExtensionType::RatchetTree
151            | ExtensionType::RequiredCapabilities
152            | ExtensionType::ExternalPub
153            | ExtensionType::ExternalSenders => false,
154            ExtensionType::Unknown(_) | ExtensionType::ApplicationId => true,
155            #[cfg(feature = "extensions-draft-08")]
156            ExtensionType::AppDataDictionary => true,
157        }
158    }
159    pub(crate) fn is_valid_in_group_info(self) -> Option<bool> {
160        match self {
161            ExtensionType::Grease(_)
162            | ExtensionType::LastResort
163            | ExtensionType::RequiredCapabilities
164            | ExtensionType::ExternalSenders
165            | ExtensionType::ApplicationId => Some(false),
166            ExtensionType::RatchetTree | ExtensionType::ExternalPub => Some(true),
167            ExtensionType::Unknown(_) => None,
168            #[cfg(feature = "extensions-draft-08")]
169            ExtensionType::AppDataDictionary => Some(true),
170        }
171    }
172
173    pub(crate) fn is_valid_in_key_package(self) -> bool {
174        match self {
175            ExtensionType::Grease(_)
176            | ExtensionType::RatchetTree
177            | ExtensionType::RequiredCapabilities
178            | ExtensionType::ExternalPub
179            | ExtensionType::ExternalSenders
180            | ExtensionType::ApplicationId => false,
181            ExtensionType::Unknown(_) | ExtensionType::LastResort => true,
182            #[cfg(feature = "extensions-draft-08")]
183            ExtensionType::AppDataDictionary => true,
184        }
185    }
186
187    pub(crate) fn is_valid_in_group_context(self) -> bool {
188        match self {
189            ExtensionType::RequiredCapabilities
190            | ExtensionType::ExternalSenders
191            | ExtensionType::Unknown(_) => true,
192            #[cfg(feature = "extensions-draft-08")]
193            ExtensionType::AppDataDictionary => true,
194            _ => false,
195        }
196    }
197
198    /// Returns true if this is a GREASE extension type.
199    ///
200    /// GREASE values are used to ensure implementations properly handle unknown
201    /// extension types. See [RFC 9420 Section 13.5](https://www.rfc-editor.org/rfc/rfc9420.html#section-13.5).
202    pub fn is_grease(&self) -> bool {
203        matches!(self, ExtensionType::Grease(_))
204    }
205}
206
207impl Size for ExtensionType {
208    fn tls_serialized_len(&self) -> usize {
209        2
210    }
211}
212
213impl TlsDeserializeTrait for ExtensionType {
214    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
215    where
216        Self: Sized,
217    {
218        let mut extension_type = [0u8; 2];
219        bytes.read_exact(&mut extension_type)?;
220
221        Ok(ExtensionType::from(u16::from_be_bytes(extension_type)))
222    }
223}
224
225impl DeserializeBytes for ExtensionType {
226    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
227    where
228        Self: Sized,
229    {
230        let mut bytes_ref = bytes;
231        let extension_type = ExtensionType::tls_deserialize(&mut bytes_ref)?;
232        let remainder = &bytes[extension_type.tls_serialized_len()..];
233        Ok((extension_type, remainder))
234    }
235}
236
237impl TlsSerializeTrait for ExtensionType {
238    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
239        writer.write_all(&u16::from(*self).to_be_bytes())?;
240
241        Ok(2)
242    }
243}
244
245impl From<u16> for ExtensionType {
246    fn from(a: u16) -> Self {
247        match a {
248            1 => ExtensionType::ApplicationId,
249            2 => ExtensionType::RatchetTree,
250            3 => ExtensionType::RequiredCapabilities,
251            4 => ExtensionType::ExternalPub,
252            5 => ExtensionType::ExternalSenders,
253            #[cfg(feature = "extensions-draft-08")]
254            6 => ExtensionType::AppDataDictionary,
255            10 => ExtensionType::LastResort,
256            unknown if crate::grease::is_grease_value(unknown) => ExtensionType::Grease(unknown),
257            unknown => ExtensionType::Unknown(unknown),
258        }
259    }
260}
261
262impl From<ExtensionType> for u16 {
263    fn from(value: ExtensionType) -> Self {
264        match value {
265            ExtensionType::ApplicationId => 1,
266            ExtensionType::RatchetTree => 2,
267            ExtensionType::RequiredCapabilities => 3,
268            ExtensionType::ExternalPub => 4,
269            ExtensionType::ExternalSenders => 5,
270            #[cfg(feature = "extensions-draft-08")]
271            ExtensionType::AppDataDictionary => 6,
272            ExtensionType::LastResort => 10,
273            ExtensionType::Grease(value) => value,
274            ExtensionType::Unknown(unknown) => unknown,
275        }
276    }
277}
278
/// # Extension
///
/// A single MLS extension: either one of the extension types OpenMLS knows how
/// to parse, or an [`UnknownExtension`] kept as raw bytes.
/// Use [`Extension::extension_type`] to get the [`ExtensionType`] of a value.
///
/// See the individual extensions for more details on each extension.
///
/// ```c
/// // draft-ietf-mls-protocol-16
/// struct {
///     ExtensionType extension_type;
///     opaque extension_data<V>;
/// } Extension;
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Extension {
    /// An [`ApplicationIdExtension`]
    ApplicationId(ApplicationIdExtension),

    /// A [`RatchetTreeExtension`]
    RatchetTree(RatchetTreeExtension),

    /// A [`RequiredCapabilitiesExtension`]
    RequiredCapabilities(RequiredCapabilitiesExtension),

    /// An [`ExternalPubExtension`]
    ExternalPub(ExternalPubExtension),

    /// An [`ExternalSendersExtension`]
    ExternalSenders(ExternalSendersExtension),

    /// An [`AppDataDictionaryExtension`]
    #[cfg(feature = "extensions-draft-08")]
    AppDataDictionary(AppDataDictionaryExtension),

    /// A [`LastResortExtension`]
    LastResort(LastResortExtension),

    /// A currently unknown extension: the raw `u16` extension type followed by
    /// its unparsed data.
    Unknown(u16, UnknownExtension),
}
320
/// An unknown or unparsed extension, represented by the raw bytes of its
/// extension data.
#[derive(
    PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSize, TlsSerialize, TlsDeserialize,
)]
pub struct UnknownExtension(pub Vec<u8>);
326
/// A list of [`Extension`]s intended for an object of type `T`.
///
/// `T` selects the [`ExtensionValidator`] that the checked constructors run on
/// every entry, for example [`Extensions::from_vec`], [`Extensions::add`] and
/// TLS decoding. [`AnyObject`] accepts every extension type. Those
/// constructors also reject duplicate extension types.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Extensions<T> {
    unique: Vec<Extension>,
    // Type-level marker only; it holds no data, so it is not serialized.
    #[serde(skip)]
    _object: core::marker::PhantomData<T>,
}
334
/// Marker for an [`Extensions`] list that is not tied to a particular object.
///
/// Its [`ExtensionValidator`] implementation accepts every extension.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, TlsSize, TlsSerialize, TlsDeserialize)]
pub struct AnyObject;
338
339impl<T> Default for Extensions<T> {
340    fn default() -> Self {
341        Self {
342            unique: vec![],
343            _object: PhantomData,
344        }
345    }
346}
347
348impl<T> Size for Extensions<T> {
349    fn tls_serialized_len(&self) -> usize {
350        Vec::tls_serialized_len(&self.unique)
351    }
352}
353
354impl<T> TlsSerializeTrait for Extensions<T> {
355    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
356        self.unique.tls_serialize(writer)
357    }
358}
359
360impl<T: ExtensionValidator> TlsDeserializeTrait for Extensions<T>
361where
362    InvalidExtensionError: From<T::Error>,
363{
364    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
365    where
366        Self: Sized,
367    {
368        let candidate: Vec<Extension> = Vec::tls_deserialize(bytes)?;
369        Extensions::<T>::try_from(candidate)
370            .map_err(|_| Error::DecodingError("Found duplicate extensions".into()))
371    }
372}
373
374impl<T: ExtensionValidator> DeserializeBytes for Extensions<T>
375where
376    InvalidExtensionError: From<T::Error>,
377{
378    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
379    where
380        Self: Sized,
381    {
382        let mut bytes_ref = bytes;
383        let extensions = Extensions::<T>::tls_deserialize(&mut bytes_ref)?;
384        let remainder = &bytes[extensions.tls_serialized_len()..];
385        Ok((extensions, remainder))
386    }
387}
388
389impl<T: ExtensionValidator> Extensions<T> {
390    /// Create an empty extension list.
391    pub fn empty() -> Self {
392        Self {
393            unique: vec![],
394            _object: PhantomData,
395        }
396    }
397
398    /// Returns an iterator over the extension list.
399    pub fn iter(&self) -> impl Iterator<Item = &Extension> {
400        self.unique.iter()
401    }
402
403    /// Remove an extension from the extension list.
404    ///
405    /// Returns the removed extension or `None` when there is no extension with
406    /// the given extension type.
407    pub fn remove(&mut self, extension_type: ExtensionType) -> Option<Extension> {
408        if let Some(pos) = self
409            .unique
410            .iter()
411            .position(|ext| ext.extension_type() == extension_type)
412        {
413            Some(self.unique.remove(pos))
414        } else {
415            None
416        }
417    }
418
419    /// Returns `true` iff the extension list contains an extension with the
420    /// given extension type.
421    pub fn contains(&self, extension_type: ExtensionType) -> bool {
422        self.unique
423            .iter()
424            .any(|ext| ext.extension_type() == extension_type)
425    }
426}
427
428impl<T> Extensions<T>
429where
430    T: ExtensionValidator,
431    InvalidExtensionError: From<T::Error>,
432{
433    /// Create an extension list with a single extension.
434    pub fn single(extension: Extension) -> Result<Self, InvalidExtensionError> {
435        T::validate_extension_type(&extension)?;
436        Ok(Self {
437            unique: vec![extension],
438            _object: PhantomData,
439        })
440    }
441
442    /// Create an extension list with multiple extensions.
443    ///
444    /// This function will fail when the list of extensions contains duplicate
445    /// extension types.
446    pub fn from_vec(extensions: Vec<Extension>) -> Result<Self, InvalidExtensionError> {
447        extensions.try_into()
448    }
449
450    /// Validate if the extensions are valid for this context
451    pub fn validate<'a>(
452        extensions: impl Iterator<Item = &'a Extension>,
453    ) -> Result<(), InvalidExtensionError> {
454        for ext in extensions {
455            T::validate_extension_type(ext)?;
456        }
457        Ok(())
458    }
459
460    /// Add an extension to the extension list.
461    ///
462    /// Returns an error when there already is an extension with the same
463    /// extension type.
464    pub fn add(&mut self, extension: Extension) -> Result<(), InvalidExtensionError> {
465        T::validate_extension_type(&extension)?;
466        if self.contains(extension.extension_type()) {
467            return Err(InvalidExtensionError::Duplicate);
468        }
469
470        self.unique.push(extension);
471
472        Ok(())
473    }
474
475    /// Add an extension to the extension list (or replace an existing one.)
476    ///
477    /// Returns the replaced extension (if any).
478    pub fn add_or_replace(
479        &mut self,
480        extension: Extension,
481    ) -> Result<Option<Extension>, InvalidExtensionError> {
482        T::validate_extension_type(&extension)?;
483        let replaced = self.remove(extension.extension_type());
484        self.unique.push(extension);
485        Ok(replaced)
486    }
487}
488
/// Implemented by marker types to decide which extensions are allowed in an
/// [`Extensions`] list for that object.
pub trait ExtensionValidator {
    /// The error returned when an extension is not allowed.
    type Error;

    /// Checks whether `ext` is allowed; returns `Ok(())` if it is.
    fn validate_extension_type(ext: &Extension) -> Result<(), Self::Error>;
}
497
498impl ExtensionValidator for AnyObject {
499    type Error = Infallible;
500
501    fn validate_extension_type(_ext: &Extension) -> Result<(), Infallible> {
502        Ok(())
503    }
504}
505
506impl<T: ExtensionValidator> TryFrom<Vec<Extension>> for Extensions<T>
507where
508    InvalidExtensionError: From<T::Error>,
509{
510    type Error = InvalidExtensionError;
511
512    fn try_from(candidate: Vec<Extension>) -> Result<Self, Self::Error> {
513        let mut unique: Vec<Extension> = Vec::new();
514        for extension in candidate.into_iter() {
515            T::validate_extension_type(&extension)?;
516
517            if unique
518                .iter()
519                .any(|ext| ext.extension_type() == extension.extension_type())
520            {
521                return Err(InvalidExtensionError::Duplicate);
522            } else {
523                unique.push(extension);
524            }
525        }
526
527        Ok(Self {
528            unique,
529            _object: PhantomData,
530        })
531    }
532}
533
534// https://validation.openmls.tech/#valn1602
535impl ExtensionValidator for GroupInfo {
536    type Error = ExtensionTypeNotValidInGroupInfoError;
537
538    fn validate_extension_type(
539        ext: &Extension,
540    ) -> Result<(), ExtensionTypeNotValidInGroupInfoError> {
541        if ext.extension_type().is_valid_in_group_info() == Some(true)
542            || ext.extension_type().is_valid_in_group_info().is_none()
543        {
544            Ok(())
545        } else {
546            Err(ExtensionTypeNotValidInGroupInfoError(ext.extension_type()))
547        }
548    }
549}
550
551// https://validation.openmls.tech/#valn1603
552impl ExtensionValidator for GroupContext {
553    type Error = ExtensionTypeNotValidInGroupContextError;
554
555    fn validate_extension_type(
556        ext: &Extension,
557    ) -> Result<(), ExtensionTypeNotValidInGroupContextError> {
558        if ext.extension_type().is_valid_in_group_context() {
559            Ok(())
560        } else {
561            Err(ExtensionTypeNotValidInGroupContextError(
562                ext.extension_type(),
563            ))
564        }
565    }
566}
567
568// https://validation.openmls.tech/#valn1604
569impl ExtensionValidator for KeyPackage {
570    type Error = ExtensionTypeNotValidInKeyPackageError;
571
572    fn validate_extension_type(
573        ext: &Extension,
574    ) -> Result<(), ExtensionTypeNotValidInKeyPackageError> {
575        if ext.extension_type().is_valid_in_key_package() {
576            Ok(())
577        } else {
578            Err(ExtensionTypeNotValidInKeyPackageError(ext.extension_type()))
579        }
580    }
581}
582
583// https://validation.openmls.tech/#valn1601
584impl ExtensionValidator for LeafNode {
585    type Error = ExtensionTypeNotValidInLeafNodeError;
586
587    fn validate_extension_type(
588        ext: &Extension,
589    ) -> Result<(), ExtensionTypeNotValidInLeafNodeError> {
590        if ext.extension_type().is_valid_in_leaf_node() {
591            Ok(())
592        } else {
593            Err(ExtensionTypeNotValidInLeafNodeError(ext.extension_type()))
594        }
595    }
596}
597
598impl<T> Extensions<T> {
599    fn find_by_type(&self, extension_type: ExtensionType) -> Option<&Extension> {
600        self.unique
601            .iter()
602            .find(|ext| ext.extension_type() == extension_type)
603    }
604
605    /// Get a reference to the [`ApplicationIdExtension`] if there is any.
606    pub fn application_id(&self) -> Option<&ApplicationIdExtension> {
607        self.find_by_type(ExtensionType::ApplicationId)
608            .and_then(|e| match e {
609                Extension::ApplicationId(e) => Some(e),
610                _ => None,
611            })
612    }
613
614    /// Get a reference to the [`RatchetTreeExtension`] if there is any.
615    pub fn ratchet_tree(&self) -> Option<&RatchetTreeExtension> {
616        self.find_by_type(ExtensionType::RatchetTree)
617            .and_then(|e| match e {
618                Extension::RatchetTree(e) => Some(e),
619                _ => None,
620            })
621    }
622
623    /// Get a reference to the [`RequiredCapabilitiesExtension`] if there is
624    /// any.
625    pub fn required_capabilities(&self) -> Option<&RequiredCapabilitiesExtension> {
626        self.find_by_type(ExtensionType::RequiredCapabilities)
627            .and_then(|e| match e {
628                Extension::RequiredCapabilities(e) => Some(e),
629                _ => None,
630            })
631    }
632
633    /// Get a reference to the [`ExternalPubExtension`] if there is any.
634    pub fn external_pub(&self) -> Option<&ExternalPubExtension> {
635        self.find_by_type(ExtensionType::ExternalPub)
636            .and_then(|e| match e {
637                Extension::ExternalPub(e) => Some(e),
638                _ => None,
639            })
640    }
641
642    /// Get a reference to the [`ExternalSendersExtension`] if there is any.
643    pub fn external_senders(&self) -> Option<&ExternalSendersExtension> {
644        self.find_by_type(ExtensionType::ExternalSenders)
645            .and_then(|e| match e {
646                Extension::ExternalSenders(e) => Some(e),
647                _ => None,
648            })
649    }
650
651    #[cfg(feature = "extensions-draft-08")]
652    /// Get a reference to the [`AppDataDictionaryExtension`] if there is any.
653    pub fn app_data_dictionary(&self) -> Option<&AppDataDictionaryExtension> {
654        self.find_by_type(ExtensionType::AppDataDictionary)
655            .and_then(|e| match e {
656                Extension::AppDataDictionary(e) => Some(e),
657                _ => None,
658            })
659    }
660
661    /// Get a reference to the [`UnknownExtension`] with the given type id, if there is any.
662    pub fn unknown(&self, extension_type_id: u16) -> Option<&UnknownExtension> {
663        let extension_type: ExtensionType = extension_type_id.into();
664
665        match extension_type {
666            ExtensionType::Unknown(_) => self.find_by_type(extension_type).and_then(|e| match e {
667                Extension::Unknown(_, e) => Some(e),
668                _ => None,
669            }),
670            _ => None,
671        }
672    }
673}
674
675impl Extension {
676    /// Get a reference to this extension as [`ApplicationIdExtension`].
677    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on an
678    /// [`Extension`] that's not an [`ApplicationIdExtension`].
679    pub fn as_application_id_extension(&self) -> Result<&ApplicationIdExtension, ExtensionError> {
680        match self {
681            Self::ApplicationId(e) => Ok(e),
682            _ => Err(ExtensionError::InvalidExtensionType(
683                "This is not an ApplicationIdExtension".into(),
684            )),
685        }
686    }
687    #[cfg(feature = "extensions-draft-08")]
688    /// Get a reference to this extension as [`AppDataDictionaryExtension`].
689    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on an
690    /// [`Extension`] that's not an [`AppDataDictionaryExtension`].
691    pub fn as_app_data_dictionary_extension(
692        &self,
693    ) -> Result<&AppDataDictionaryExtension, ExtensionError> {
694        match self {
695            Self::AppDataDictionary(e) => Ok(e),
696            _ => Err(ExtensionError::InvalidExtensionType(
697                "This is not an AppDataDictionaryExtension".into(),
698            )),
699        }
700    }
701
702    /// Get a reference to this extension as [`RatchetTreeExtension`].
703    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on
704    /// an [`Extension`] that's not a [`RatchetTreeExtension`].
705    pub fn as_ratchet_tree_extension(&self) -> Result<&RatchetTreeExtension, ExtensionError> {
706        match self {
707            Self::RatchetTree(rte) => Ok(rte),
708            _ => Err(ExtensionError::InvalidExtensionType(
709                "This is not a RatchetTreeExtension".into(),
710            )),
711        }
712    }
713
714    /// Get a reference to this extension as [`RequiredCapabilitiesExtension`].
715    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
716    /// an [`Extension`] that's not a [`RequiredCapabilitiesExtension`].
717    pub fn as_required_capabilities_extension(
718        &self,
719    ) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
720        match self {
721            Self::RequiredCapabilities(e) => Ok(e),
722            _ => Err(ExtensionError::InvalidExtensionType(
723                "This is not a RequiredCapabilitiesExtension".into(),
724            )),
725        }
726    }
727
728    /// Get a reference to this extension as [`ExternalPubExtension`].
729    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
730    /// an [`Extension`] that's not a [`ExternalPubExtension`].
731    pub fn as_external_pub_extension(&self) -> Result<&ExternalPubExtension, ExtensionError> {
732        match self {
733            Self::ExternalPub(e) => Ok(e),
734            _ => Err(ExtensionError::InvalidExtensionType(
735                "This is not an ExternalPubExtension".into(),
736            )),
737        }
738    }
739
740    /// Get a reference to this extension as [`ExternalSendersExtension`].
741    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
742    /// an [`Extension`] that's not a [`ExternalSendersExtension`].
743    pub fn as_external_senders_extension(
744        &self,
745    ) -> Result<&ExternalSendersExtension, ExtensionError> {
746        match self {
747            Self::ExternalSenders(e) => Ok(e),
748            _ => Err(ExtensionError::InvalidExtensionType(
749                "This is not an ExternalSendersExtension".into(),
750            )),
751        }
752    }
753
754    /// Returns the [`ExtensionType`]
755    #[inline]
756    pub const fn extension_type(&self) -> ExtensionType {
757        match self {
758            Extension::ApplicationId(_) => ExtensionType::ApplicationId,
759            Extension::RatchetTree(_) => ExtensionType::RatchetTree,
760            Extension::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
761            Extension::ExternalPub(_) => ExtensionType::ExternalPub,
762            Extension::ExternalSenders(_) => ExtensionType::ExternalSenders,
763            #[cfg(feature = "extensions-draft-08")]
764            Extension::AppDataDictionary(_) => ExtensionType::AppDataDictionary,
765            Extension::LastResort(_) => ExtensionType::LastResort,
766            Extension::Unknown(kind, _) => ExtensionType::Unknown(*kind),
767        }
768    }
769}
770
771macro_rules! impl_from_extensions_validator {
772    ($validator:ty, $error:ty) => {
773        impl From<Extensions<$validator>> for Extensions<AnyObject> {
774            fn from(value: Extensions<$validator>) -> Self {
775                Extensions {
776                    unique: value.unique,
777                    _object: PhantomData,
778                }
779            }
780        }
781
782        impl TryFrom<Extensions<AnyObject>> for Extensions<$validator> {
783            type Error = $error;
784
785            fn try_from(value: Extensions<AnyObject>) -> Result<Self, $error> {
786                value
787                    .unique
788                    .iter()
789                    .try_for_each(<$validator as ExtensionValidator>::validate_extension_type)?;
790
791                Ok(Extensions {
792                    unique: value.unique,
793                    _object: PhantomData,
794                })
795            }
796        }
797    };
798}
799
800impl_from_extensions_validator!(GroupContext, ExtensionTypeNotValidInGroupContextError);
801impl_from_extensions_validator!(LeafNode, ExtensionTypeNotValidInLeafNodeError);
802impl_from_extensions_validator!(KeyPackage, ExtensionTypeNotValidInKeyPackageError);
803
#[cfg(any(feature = "test-utils", test))]
impl Extensions<AnyObject> {
    /// Reinterprets this list as an `Extensions<T>` without running `T`'s
    /// validator.
    ///
    /// Not `unsafe` in the Rust sense, but it bypasses validation, so the
    /// result may hold extension types that `T` would reject. Only compiled
    /// for tests or with the `test-utils` feature.
    pub(crate) fn coerce<T: ExtensionValidator>(self) -> Extensions<T> {
        let Self { unique, .. } = self;
        Extensions {
            unique,
            _object: PhantomData,
        }
    }
}
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use tls_codec::{Deserialize, Serialize, VLBytes};

    use crate::{ciphersuite::HpkePublicKey, extensions::*};

    #[test]
    fn add() {
        let mut extensions: Extensions<AnyObject> = Extensions::default();
        extensions
            .add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default(),
            ))
            .unwrap();
        // A second extension of the same type must be rejected.
        assert!(extensions
            .add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default()
            ))
            .is_err());
    }

    #[test]
    fn add_try_from() {
        // Create some extensions with different extension types and test that
        // duplicates are rejected. The extension content does not matter in this test.
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        let tests = [
            (vec![], true),
            (vec![ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_y.clone()], true),
            (vec![ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_x.clone(), ext_y.clone()], true),
            (vec![ext_y.clone(), ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_y.clone(), ext_x, ext_y], false),
        ];

        for (test, should_work) in tests {
            // Test `add`: every extension is added, and any duplicate marks the case as failing.
            let mut extensions: Extensions<AnyObject> = Extensions::default();
            let mut works = true;
            for ext in &test {
                match extensions.add(ext.clone()) {
                    Ok(()) => {}
                    Err(InvalidExtensionError::Duplicate) => works = false,
                    _ => panic!("This should have never happened."),
                }
            }
            assert_eq!(
                works, should_work,
                "`add` gave an unexpected result for {test:?}"
            );

            // Test `try_from`.
            assert_eq!(
                Extensions::<AnyObject>::try_from(test).is_ok(),
                should_work,
                "`try_from` gave an unexpected result"
            );
        }
    }

    #[test]
    fn ensure_ordering() {
        // Create some extensions with different extension types and test
        // that all permutations keep their order after being (de)serialized.
        // The extension content does not matter in this test.
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::ExternalPub(ExternalPubExtension::new(HpkePublicKey::new(vec![])));
        let ext_z = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        for candidate in [ext_x, ext_y, ext_z].into_iter().permutations(3) {
            let candidate: Extensions<AnyObject> = Extensions::try_from(candidate).unwrap();
            let bytes = candidate.tls_serialize_detached().unwrap();
            let got = Extensions::tls_deserialize(&mut bytes.as_slice()).unwrap();
            assert_eq!(candidate, got);
        }
    }

    #[test]
    fn that_unknown_extensions_are_de_serialized_correctly() {
        let extension_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF100, 0xFFFF];
        let extension_datas = [vec![], vec![0], vec![1, 2, 3]];

        for extension_type in extension_types {
            for extension_data in &extension_datas {
                // Construct an unknown extension manually: a big-endian type
                // followed by the data as a variable-length byte vector.
                let test = {
                    let mut buf = extension_type.to_be_bytes().to_vec();
                    buf.append(
                        &mut VLBytes::new(extension_data.clone())
                            .tls_serialize_detached()
                            .unwrap(),
                    );
                    buf
                };

                // Test deserialization.
                let got = Extension::tls_deserialize_exact(&test).unwrap();

                match got {
                    Extension::Unknown(got_extension_type, ref got_extension_data) => {
                        assert_eq!(extension_type, got_extension_type);
                        assert_eq!(extension_data, &got_extension_data.0);
                    }
                    other => panic!("Expected `Extension::Unknown`, got {other:?}"),
                }

                // Test serialization: it must reproduce the original bytes.
                let got_serialized = got.tls_serialize_detached().unwrap();
                assert_eq!(test, got_serialized);
            }
        }
    }
}