Skip to main content

openmls/extensions/
mod.rs

1//! # Extensions
2//!
//! In MLS, extensions appear in the following places:
//!
//! - In [`KeyPackages`](`crate::key_packages`), to describe additional
//!   information about the client (e.g. [`LastResortExtension`]).
//!
//! - In `LeafNode`s, to describe the client and aspects of its participation
//!   in the group (e.g. [`ApplicationIdExtension`]).
//!
//! - In `GroupInfo`, to inform new members of the group's parameters and to
//!   provide any additional information required to join the group.
//!
//! - In the `GroupContext` object, to ensure that all members of the group have
//!   a consistent view of the parameters in use.
13//!
14//! Note that `GroupInfo` and `GroupContext` are not exposed via OpenMLS' public
15//! API.
16//!
//! OpenMLS supports the following extensions:
//!
//! - [`ApplicationIdExtension`] (LeafNode extension)
//! - [`RatchetTreeExtension`] (GroupInfo extension)
//! - [`RequiredCapabilitiesExtension`] (GroupContext extension)
//! - [`ExternalPubExtension`] (GroupInfo extension)
//! - [`ExternalSendersExtension`] (GroupContext extension)
//! - [`LastResortExtension`] (KeyPackage extension)
23
24use std::{
25    convert::Infallible,
26    fmt::Debug,
27    io::{Read, Write},
28    marker::PhantomData,
29};
30
31use serde::{Deserialize, Serialize};
32
33// Private
34#[cfg(feature = "extensions-draft-08")]
35mod app_data_dict_extension;
36mod application_id_extension;
37mod codec;
38mod external_pub_extension;
39mod external_sender_extension;
40mod last_resort;
41mod ratchet_tree_extension;
42mod required_capabilities;
43use errors::*;
44
45// Public
46pub mod errors;
47
48// Public re-exports
49#[cfg(feature = "extensions-draft-08")]
50pub use app_data_dict_extension::{AppDataDictionary, AppDataDictionaryExtension};
51pub use application_id_extension::ApplicationIdExtension;
52pub use external_pub_extension::ExternalPubExtension;
53pub use external_sender_extension::{
54    ExternalSender, ExternalSendersExtension, SenderExtensionIndex,
55};
56pub use last_resort::LastResortExtension;
57pub use ratchet_tree_extension::RatchetTreeExtension;
58pub use required_capabilities::RequiredCapabilitiesExtension;
59
60use tls_codec::{
61    Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
62    Size, TlsDeserialize, TlsSerialize, TlsSize,
63};
64
65use crate::{
66    group::GroupContext, key_packages::KeyPackage, messages::group_info::GroupInfo,
67    treesync::LeafNode,
68};
69
70#[cfg(test)]
71mod tests;
72
73/// MLS Extension Types
74///
75/// Copied from draft-ietf-mls-protocol-16:
76///
77/// | Value            | Name                     | Message(s) | Recommended | Reference |
78/// |:-----------------|:-------------------------|:-----------|:------------|:----------|
79/// | 0x0000           | RESERVED                 | N/A        | N/A         | RFC XXXX  |
80/// | 0x0001           | application_id           | LN         | Y           | RFC XXXX  |
81/// | 0x0002           | ratchet_tree             | GI         | Y           | RFC XXXX  |
82/// | 0x0003           | required_capabilities    | GC         | Y           | RFC XXXX  |
83/// | 0x0004           | external_pub             | GI         | Y           | RFC XXXX  |
84/// | 0x0005           | external_senders         | GC         | Y           | RFC XXXX  |
85/// | 0xff00  - 0xffff | Reserved for Private Use | N/A        | N/A         | RFC XXXX  |
86///
87/// Note: OpenMLS does not provide a `Reserved` variant in [ExtensionType].
88#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
89#[cfg_attr(
90    feature = "0-8-1-storage-format",
91    derive(serde::Serialize, serde::Deserialize)
92)]
93#[cfg_attr(
94    not(feature = "0-8-1-storage-format"),
95    derive(
96        openmls_serialization_helpers::Serialize,
97        openmls_serialization_helpers::Deserialize,
98    )
99)]
100pub enum ExtensionType {
101    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 0)]
102    /// The application id extension allows applications to add an explicit,
103    /// application-defined identifier to a KeyPackage.
104    ApplicationId,
105
106    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 1)]
107    /// The ratchet tree extensions provides the whole public state of the
108    /// ratchet tree.
109    RatchetTree,
110
111    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 2)]
112    /// The required capabilities extension defines the configuration of a group
113    /// that imposes certain requirements on clients in the group.
114    RequiredCapabilities,
115
116    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 3)]
117    /// To join a group via an External Commit, a new member needs a GroupInfo
118    /// with an ExternalPub extension present in its extensions field.
119    ExternalPub,
120
121    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 4)]
122    /// Group context extension that contains the credentials and signature keys
123    /// of senders that are permitted to send external proposals to the group.
124    ExternalSenders,
125
126    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 5)]
127    /// KeyPackage extension that marks a KeyPackage for use in a last resort
128    /// scenario.
129    LastResort,
130
131    #[cfg(feature = "extensions-draft-08")]
132    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 8)]
133    /// AppDataDictionary extension
134    AppDataDictionary,
135
136    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 7)]
137    /// A GREASE extension type for ensuring extensibility.
138    Grease(u16),
139
140    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 6)]
141    /// A currently unknown extension type.
142    Unknown(u16),
143}
144
145impl ExtensionType {
146    /// Returns true for all extension types that are considered "default" by the spec.
147    pub(crate) fn is_default(self) -> bool {
148        match self {
149            ExtensionType::ApplicationId
150            | ExtensionType::RatchetTree
151            | ExtensionType::RequiredCapabilities
152            | ExtensionType::ExternalPub
153            | ExtensionType::ExternalSenders => true,
154            ExtensionType::LastResort | ExtensionType::Grease(_) | ExtensionType::Unknown(_) => {
155                false
156            }
157            #[cfg(feature = "extensions-draft-08")]
158            ExtensionType::AppDataDictionary => false,
159        }
160    }
161
162    /// Returns whether an extension type is valid when used in leaf nodes.
163    /// Returns None if validity can not be determined.
164    /// This is the case for unknown extensions.
165    //  https://validation.openmls.tech/#valn1601
166    pub(crate) fn is_valid_in_leaf_node(self) -> bool {
167        match self {
168            ExtensionType::Grease(_)
169            | ExtensionType::LastResort
170            | ExtensionType::RatchetTree
171            | ExtensionType::RequiredCapabilities
172            | ExtensionType::ExternalPub
173            | ExtensionType::ExternalSenders => false,
174            ExtensionType::Unknown(_) | ExtensionType::ApplicationId => true,
175            #[cfg(feature = "extensions-draft-08")]
176            ExtensionType::AppDataDictionary => true,
177        }
178    }
179    pub(crate) fn is_valid_in_group_info(self) -> Option<bool> {
180        match self {
181            ExtensionType::Grease(_)
182            | ExtensionType::LastResort
183            | ExtensionType::RequiredCapabilities
184            | ExtensionType::ExternalSenders
185            | ExtensionType::ApplicationId => Some(false),
186            ExtensionType::RatchetTree | ExtensionType::ExternalPub => Some(true),
187            ExtensionType::Unknown(_) => None,
188            #[cfg(feature = "extensions-draft-08")]
189            ExtensionType::AppDataDictionary => Some(true),
190        }
191    }
192
193    pub(crate) fn is_valid_in_key_package(self) -> bool {
194        match self {
195            ExtensionType::Grease(_)
196            | ExtensionType::RatchetTree
197            | ExtensionType::RequiredCapabilities
198            | ExtensionType::ExternalPub
199            | ExtensionType::ExternalSenders
200            | ExtensionType::ApplicationId => false,
201            ExtensionType::Unknown(_) | ExtensionType::LastResort => true,
202            #[cfg(feature = "extensions-draft-08")]
203            ExtensionType::AppDataDictionary => true,
204        }
205    }
206
207    pub(crate) fn is_valid_in_group_context(self) -> bool {
208        match self {
209            ExtensionType::RequiredCapabilities
210            | ExtensionType::ExternalSenders
211            | ExtensionType::Unknown(_) => true,
212            #[cfg(feature = "extensions-draft-08")]
213            ExtensionType::AppDataDictionary => true,
214            _ => false,
215        }
216    }
217
218    /// Returns true if this is a GREASE extension type.
219    ///
220    /// GREASE values are used to ensure implementations properly handle unknown
221    /// extension types. See [RFC 9420 Section 13.5](https://www.rfc-editor.org/rfc/rfc9420.html#section-13.5).
222    pub fn is_grease(&self) -> bool {
223        matches!(self, ExtensionType::Grease(_))
224    }
225}
226
227impl Size for ExtensionType {
228    fn tls_serialized_len(&self) -> usize {
229        2
230    }
231}
232
233impl TlsDeserializeTrait for ExtensionType {
234    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
235    where
236        Self: Sized,
237    {
238        let mut extension_type = [0u8; 2];
239        bytes.read_exact(&mut extension_type)?;
240
241        Ok(ExtensionType::from(u16::from_be_bytes(extension_type)))
242    }
243}
244
245impl DeserializeBytes for ExtensionType {
246    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
247    where
248        Self: Sized,
249    {
250        let mut bytes_ref = bytes;
251        let extension_type = ExtensionType::tls_deserialize(&mut bytes_ref)?;
252        let remainder = &bytes[extension_type.tls_serialized_len()..];
253        Ok((extension_type, remainder))
254    }
255}
256
257impl TlsSerializeTrait for ExtensionType {
258    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
259        writer.write_all(&u16::from(*self).to_be_bytes())?;
260
261        Ok(2)
262    }
263}
264
265impl From<u16> for ExtensionType {
266    fn from(a: u16) -> Self {
267        match a {
268            1 => ExtensionType::ApplicationId,
269            2 => ExtensionType::RatchetTree,
270            3 => ExtensionType::RequiredCapabilities,
271            4 => ExtensionType::ExternalPub,
272            5 => ExtensionType::ExternalSenders,
273            #[cfg(feature = "extensions-draft-08")]
274            6 => ExtensionType::AppDataDictionary,
275            10 => ExtensionType::LastResort,
276            unknown if crate::grease::is_grease_value(unknown) => ExtensionType::Grease(unknown),
277            unknown => ExtensionType::Unknown(unknown),
278        }
279    }
280}
281
282impl From<ExtensionType> for u16 {
283    fn from(value: ExtensionType) -> Self {
284        match value {
285            ExtensionType::ApplicationId => 1,
286            ExtensionType::RatchetTree => 2,
287            ExtensionType::RequiredCapabilities => 3,
288            ExtensionType::ExternalPub => 4,
289            ExtensionType::ExternalSenders => 5,
290            #[cfg(feature = "extensions-draft-08")]
291            ExtensionType::AppDataDictionary => 6,
292            ExtensionType::LastResort => 10,
293            ExtensionType::Grease(value) => value,
294            ExtensionType::Unknown(unknown) => unknown,
295        }
296    }
297}
298
299/// # Extension
300///
301/// An extension is one of the [`Extension`] enum values.
302/// The enum provides a set of common functionality for all extensions.
303///
304/// See the individual extensions for more details on each extension.
305///
306/// ```c
307/// // draft-ietf-mls-protocol-16
308/// struct {
309///     ExtensionType extension_type;
310///     opaque extension_data<V>;
311/// } Extension;
312/// ```
313#[derive(Debug, Clone, PartialEq, Eq)]
314#[cfg_attr(
315    feature = "0-8-1-storage-format",
316    derive(serde::Serialize, serde::Deserialize)
317)]
318#[cfg_attr(
319    not(feature = "0-8-1-storage-format"),
320    derive(
321        openmls_serialization_helpers::Serialize,
322        openmls_serialization_helpers::Deserialize,
323    )
324)]
325pub enum Extension {
326    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 0)]
327    /// An [`ApplicationIdExtension`]
328    ApplicationId(ApplicationIdExtension),
329
330    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 1)]
331    /// A [`RatchetTreeExtension`]
332    RatchetTree(RatchetTreeExtension),
333
334    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 2)]
335    /// A [`RequiredCapabilitiesExtension`]
336    RequiredCapabilities(RequiredCapabilitiesExtension),
337
338    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 3)]
339    /// An [`ExternalPubExtension`]
340    ExternalPub(ExternalPubExtension),
341
342    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 4)]
343    /// An [`ExternalSendersExtension`]
344    ExternalSenders(ExternalSendersExtension),
345
346    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 7)]
347    /// An [`AppDataDictionaryExtension`]
348    #[cfg(feature = "extensions-draft-08")]
349    AppDataDictionary(AppDataDictionaryExtension),
350
351    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 5)]
352    /// A [`LastResortExtension`]
353    LastResort(LastResortExtension),
354
355    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 6)]
356    /// A currently unknown extension.
357    Unknown(u16, UnknownExtension),
358}
359
360/// A unknown/unparsed extension represented by raw bytes.
361#[derive(
362    PartialEq, Eq, Clone, Debug, Serialize, Deserialize, TlsSize, TlsSerialize, TlsDeserialize,
363)]
364pub struct UnknownExtension(pub Vec<u8>);
365
366/// A Extension for Object of type T
367#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
368pub struct Extensions<T> {
369    unique: Vec<Extension>,
370    #[serde(skip)]
371    _object: core::marker::PhantomData<T>,
372}
373
374#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, TlsSize, TlsSerialize, TlsDeserialize)]
375/// Any object
376pub struct AnyObject;
377
378impl<T> Default for Extensions<T> {
379    fn default() -> Self {
380        Self {
381            unique: vec![],
382            _object: PhantomData,
383        }
384    }
385}
386
387impl<T> Size for Extensions<T> {
388    fn tls_serialized_len(&self) -> usize {
389        Vec::tls_serialized_len(&self.unique)
390    }
391}
392
393impl<T> TlsSerializeTrait for Extensions<T> {
394    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
395        self.unique.tls_serialize(writer)
396    }
397}
398
399impl<T: ExtensionValidator> TlsDeserializeTrait for Extensions<T>
400where
401    InvalidExtensionError: From<T::Error>,
402{
403    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
404    where
405        Self: Sized,
406    {
407        let candidate: Vec<Extension> = Vec::tls_deserialize(bytes)?;
408        Extensions::<T>::try_from(candidate)
409            .map_err(|_| Error::DecodingError("Found duplicate extensions".into()))
410    }
411}
412
413impl<T: ExtensionValidator> DeserializeBytes for Extensions<T>
414where
415    InvalidExtensionError: From<T::Error>,
416{
417    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
418    where
419        Self: Sized,
420    {
421        let mut bytes_ref = bytes;
422        let extensions = Extensions::<T>::tls_deserialize(&mut bytes_ref)?;
423        let remainder = &bytes[extensions.tls_serialized_len()..];
424        Ok((extensions, remainder))
425    }
426}
427
428impl<T: ExtensionValidator> Extensions<T> {
429    /// Create an empty extension list.
430    pub fn empty() -> Self {
431        Self {
432            unique: vec![],
433            _object: PhantomData,
434        }
435    }
436
437    /// Returns an iterator over the extension list.
438    pub fn iter(&self) -> impl Iterator<Item = &Extension> {
439        self.unique.iter()
440    }
441
442    /// Remove an extension from the extension list.
443    ///
444    /// Returns the removed extension or `None` when there is no extension with
445    /// the given extension type.
446    pub fn remove(&mut self, extension_type: ExtensionType) -> Option<Extension> {
447        if let Some(pos) = self
448            .unique
449            .iter()
450            .position(|ext| ext.extension_type() == extension_type)
451        {
452            Some(self.unique.remove(pos))
453        } else {
454            None
455        }
456    }
457
458    /// Returns `true` iff the extension list contains an extension with the
459    /// given extension type.
460    pub fn contains(&self, extension_type: ExtensionType) -> bool {
461        self.unique
462            .iter()
463            .any(|ext| ext.extension_type() == extension_type)
464    }
465}
466
467impl<T> Extensions<T>
468where
469    T: ExtensionValidator,
470    InvalidExtensionError: From<T::Error>,
471{
472    /// Create an extension list with a single extension.
473    pub fn single(extension: Extension) -> Result<Self, InvalidExtensionError> {
474        T::validate_extension_type(&extension)?;
475        Ok(Self {
476            unique: vec![extension],
477            _object: PhantomData,
478        })
479    }
480
481    /// Create an extension list with multiple extensions.
482    ///
483    /// This function will fail when the list of extensions contains duplicate
484    /// extension types.
485    pub fn from_vec(extensions: Vec<Extension>) -> Result<Self, InvalidExtensionError> {
486        extensions.try_into()
487    }
488
489    /// Validate if the extensions are valid for this context
490    pub fn validate<'a>(
491        extensions: impl Iterator<Item = &'a Extension>,
492    ) -> Result<(), InvalidExtensionError> {
493        for ext in extensions {
494            T::validate_extension_type(ext)?;
495        }
496        Ok(())
497    }
498
499    /// Add an extension to the extension list.
500    ///
501    /// Returns an error when there already is an extension with the same
502    /// extension type.
503    pub fn add(&mut self, extension: Extension) -> Result<(), InvalidExtensionError> {
504        T::validate_extension_type(&extension)?;
505        if self.contains(extension.extension_type()) {
506            return Err(InvalidExtensionError::Duplicate);
507        }
508
509        self.unique.push(extension);
510
511        Ok(())
512    }
513
514    /// Add an extension to the extension list (or replace an existing one.)
515    ///
516    /// Returns the replaced extension (if any).
517    pub fn add_or_replace(
518        &mut self,
519        extension: Extension,
520    ) -> Result<Option<Extension>, InvalidExtensionError> {
521        T::validate_extension_type(&extension)?;
522        let replaced = self.remove(extension.extension_type());
523        self.unique.push(extension);
524        Ok(replaced)
525    }
526}
527
528/// Can be implemented by a type to validate extensions.
529pub trait ExtensionValidator {
530    /// The error returned by the validator
531    type Error;
532
533    /// Check if the extension is valid.
534    fn validate_extension_type(ext: &Extension) -> Result<(), Self::Error>;
535}
536
537impl ExtensionValidator for AnyObject {
538    type Error = Infallible;
539
540    fn validate_extension_type(_ext: &Extension) -> Result<(), Infallible> {
541        Ok(())
542    }
543}
544
545impl<T: ExtensionValidator> TryFrom<Vec<Extension>> for Extensions<T>
546where
547    InvalidExtensionError: From<T::Error>,
548{
549    type Error = InvalidExtensionError;
550
551    fn try_from(candidate: Vec<Extension>) -> Result<Self, Self::Error> {
552        let mut unique: Vec<Extension> = Vec::new();
553        for extension in candidate.into_iter() {
554            T::validate_extension_type(&extension)?;
555
556            if unique
557                .iter()
558                .any(|ext| ext.extension_type() == extension.extension_type())
559            {
560                return Err(InvalidExtensionError::Duplicate);
561            } else {
562                unique.push(extension);
563            }
564        }
565
566        Ok(Self {
567            unique,
568            _object: PhantomData,
569        })
570    }
571}
572
573// https://validation.openmls.tech/#valn1602
574impl ExtensionValidator for GroupInfo {
575    type Error = ExtensionTypeNotValidInGroupInfoError;
576
577    fn validate_extension_type(
578        ext: &Extension,
579    ) -> Result<(), ExtensionTypeNotValidInGroupInfoError> {
580        if ext.extension_type().is_valid_in_group_info() == Some(true)
581            || ext.extension_type().is_valid_in_group_info().is_none()
582        {
583            Ok(())
584        } else {
585            Err(ExtensionTypeNotValidInGroupInfoError(ext.extension_type()))
586        }
587    }
588}
589
590// https://validation.openmls.tech/#valn1603
591impl ExtensionValidator for GroupContext {
592    type Error = ExtensionTypeNotValidInGroupContextError;
593
594    fn validate_extension_type(
595        ext: &Extension,
596    ) -> Result<(), ExtensionTypeNotValidInGroupContextError> {
597        if ext.extension_type().is_valid_in_group_context() {
598            Ok(())
599        } else {
600            Err(ExtensionTypeNotValidInGroupContextError(
601                ext.extension_type(),
602            ))
603        }
604    }
605}
606
607// https://validation.openmls.tech/#valn1604
608impl ExtensionValidator for KeyPackage {
609    type Error = ExtensionTypeNotValidInKeyPackageError;
610
611    fn validate_extension_type(
612        ext: &Extension,
613    ) -> Result<(), ExtensionTypeNotValidInKeyPackageError> {
614        if ext.extension_type().is_valid_in_key_package() {
615            Ok(())
616        } else {
617            Err(ExtensionTypeNotValidInKeyPackageError(ext.extension_type()))
618        }
619    }
620}
621
622// https://validation.openmls.tech/#valn1601
623impl ExtensionValidator for LeafNode {
624    type Error = ExtensionTypeNotValidInLeafNodeError;
625
626    fn validate_extension_type(
627        ext: &Extension,
628    ) -> Result<(), ExtensionTypeNotValidInLeafNodeError> {
629        if ext.extension_type().is_valid_in_leaf_node() {
630            Ok(())
631        } else {
632            Err(ExtensionTypeNotValidInLeafNodeError(ext.extension_type()))
633        }
634    }
635}
636
637impl<T> Extensions<T> {
638    fn find_by_type(&self, extension_type: ExtensionType) -> Option<&Extension> {
639        self.unique
640            .iter()
641            .find(|ext| ext.extension_type() == extension_type)
642    }
643
644    /// Get a reference to the [`ApplicationIdExtension`] if there is any.
645    pub fn application_id(&self) -> Option<&ApplicationIdExtension> {
646        self.find_by_type(ExtensionType::ApplicationId)
647            .and_then(|e| match e {
648                Extension::ApplicationId(e) => Some(e),
649                _ => None,
650            })
651    }
652
653    /// Get a reference to the [`RatchetTreeExtension`] if there is any.
654    pub fn ratchet_tree(&self) -> Option<&RatchetTreeExtension> {
655        self.find_by_type(ExtensionType::RatchetTree)
656            .and_then(|e| match e {
657                Extension::RatchetTree(e) => Some(e),
658                _ => None,
659            })
660    }
661
662    /// Get a reference to the [`RequiredCapabilitiesExtension`] if there is
663    /// any.
664    pub fn required_capabilities(&self) -> Option<&RequiredCapabilitiesExtension> {
665        self.find_by_type(ExtensionType::RequiredCapabilities)
666            .and_then(|e| match e {
667                Extension::RequiredCapabilities(e) => Some(e),
668                _ => None,
669            })
670    }
671
672    /// Get a reference to the [`ExternalPubExtension`] if there is any.
673    pub fn external_pub(&self) -> Option<&ExternalPubExtension> {
674        self.find_by_type(ExtensionType::ExternalPub)
675            .and_then(|e| match e {
676                Extension::ExternalPub(e) => Some(e),
677                _ => None,
678            })
679    }
680
681    /// Get a reference to the [`ExternalSendersExtension`] if there is any.
682    pub fn external_senders(&self) -> Option<&ExternalSendersExtension> {
683        self.find_by_type(ExtensionType::ExternalSenders)
684            .and_then(|e| match e {
685                Extension::ExternalSenders(e) => Some(e),
686                _ => None,
687            })
688    }
689
690    #[cfg(feature = "extensions-draft-08")]
691    /// Get a reference to the [`AppDataDictionaryExtension`] if there is any.
692    pub fn app_data_dictionary(&self) -> Option<&AppDataDictionaryExtension> {
693        self.find_by_type(ExtensionType::AppDataDictionary)
694            .and_then(|e| match e {
695                Extension::AppDataDictionary(e) => Some(e),
696                _ => None,
697            })
698    }
699
700    /// Get a reference to the [`UnknownExtension`] with the given type id, if there is any.
701    pub fn unknown(&self, extension_type_id: u16) -> Option<&UnknownExtension> {
702        let extension_type: ExtensionType = extension_type_id.into();
703
704        match extension_type {
705            ExtensionType::Unknown(_) => self.find_by_type(extension_type).and_then(|e| match e {
706                Extension::Unknown(_, e) => Some(e),
707                _ => None,
708            }),
709            _ => None,
710        }
711    }
712}
713
714impl Extension {
715    /// Get a reference to this extension as [`ApplicationIdExtension`].
716    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on an
717    /// [`Extension`] that's not an [`ApplicationIdExtension`].
718    pub fn as_application_id_extension(&self) -> Result<&ApplicationIdExtension, ExtensionError> {
719        match self {
720            Self::ApplicationId(e) => Ok(e),
721            _ => Err(ExtensionError::InvalidExtensionType(
722                "This is not an ApplicationIdExtension".into(),
723            )),
724        }
725    }
726    #[cfg(feature = "extensions-draft-08")]
727    /// Get a reference to this extension as [`AppDataDictionaryExtension`].
728    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on an
729    /// [`Extension`] that's not an [`AppDataDictionaryExtension`].
730    pub fn as_app_data_dictionary_extension(
731        &self,
732    ) -> Result<&AppDataDictionaryExtension, ExtensionError> {
733        match self {
734            Self::AppDataDictionary(e) => Ok(e),
735            _ => Err(ExtensionError::InvalidExtensionType(
736                "This is not an AppDataDictionaryExtension".into(),
737            )),
738        }
739    }
740
741    /// Get a reference to this extension as [`RatchetTreeExtension`].
742    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on
743    /// an [`Extension`] that's not a [`RatchetTreeExtension`].
744    pub fn as_ratchet_tree_extension(&self) -> Result<&RatchetTreeExtension, ExtensionError> {
745        match self {
746            Self::RatchetTree(rte) => Ok(rte),
747            _ => Err(ExtensionError::InvalidExtensionType(
748                "This is not a RatchetTreeExtension".into(),
749            )),
750        }
751    }
752
753    /// Get a reference to this extension as [`RequiredCapabilitiesExtension`].
754    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
755    /// an [`Extension`] that's not a [`RequiredCapabilitiesExtension`].
756    pub fn as_required_capabilities_extension(
757        &self,
758    ) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
759        match self {
760            Self::RequiredCapabilities(e) => Ok(e),
761            _ => Err(ExtensionError::InvalidExtensionType(
762                "This is not a RequiredCapabilitiesExtension".into(),
763            )),
764        }
765    }
766
767    /// Get a reference to this extension as [`ExternalPubExtension`].
768    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
769    /// an [`Extension`] that's not a [`ExternalPubExtension`].
770    pub fn as_external_pub_extension(&self) -> Result<&ExternalPubExtension, ExtensionError> {
771        match self {
772            Self::ExternalPub(e) => Ok(e),
773            _ => Err(ExtensionError::InvalidExtensionType(
774                "This is not an ExternalPubExtension".into(),
775            )),
776        }
777    }
778
779    /// Get a reference to this extension as [`ExternalSendersExtension`].
780    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
781    /// an [`Extension`] that's not a [`ExternalSendersExtension`].
782    pub fn as_external_senders_extension(
783        &self,
784    ) -> Result<&ExternalSendersExtension, ExtensionError> {
785        match self {
786            Self::ExternalSenders(e) => Ok(e),
787            _ => Err(ExtensionError::InvalidExtensionType(
788                "This is not an ExternalSendersExtension".into(),
789            )),
790        }
791    }
792
793    /// Returns the [`ExtensionType`]
794    #[inline]
795    pub const fn extension_type(&self) -> ExtensionType {
796        match self {
797            Extension::ApplicationId(_) => ExtensionType::ApplicationId,
798            Extension::RatchetTree(_) => ExtensionType::RatchetTree,
799            Extension::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
800            Extension::ExternalPub(_) => ExtensionType::ExternalPub,
801            Extension::ExternalSenders(_) => ExtensionType::ExternalSenders,
802            #[cfg(feature = "extensions-draft-08")]
803            Extension::AppDataDictionary(_) => ExtensionType::AppDataDictionary,
804            Extension::LastResort(_) => ExtensionType::LastResort,
805            Extension::Unknown(kind, _) => ExtensionType::Unknown(*kind),
806        }
807    }
808}
809
810macro_rules! impl_from_extensions_validator {
811    ($validator:ty, $error:ty) => {
812        impl From<Extensions<$validator>> for Extensions<AnyObject> {
813            fn from(value: Extensions<$validator>) -> Self {
814                Extensions {
815                    unique: value.unique,
816                    _object: PhantomData,
817                }
818            }
819        }
820
821        impl TryFrom<Extensions<AnyObject>> for Extensions<$validator> {
822            type Error = $error;
823
824            fn try_from(value: Extensions<AnyObject>) -> Result<Self, $error> {
825                value
826                    .unique
827                    .iter()
828                    .try_for_each(<$validator as ExtensionValidator>::validate_extension_type)?;
829
830                Ok(Extensions {
831                    unique: value.unique,
832                    _object: PhantomData,
833                })
834            }
835        }
836    };
837}
838
839impl_from_extensions_validator!(GroupContext, ExtensionTypeNotValidInGroupContextError);
840impl_from_extensions_validator!(LeafNode, ExtensionTypeNotValidInLeafNodeError);
841impl_from_extensions_validator!(KeyPackage, ExtensionTypeNotValidInKeyPackageError);
842
#[cfg(any(feature = "test-utils", test))]
impl Extensions<AnyObject> {
    /// Reinterprets these extensions as an `Extensions<T>` without running
    /// `T`'s validation.
    ///
    /// This is not `unsafe` in the Rust sense, but the result may hold
    /// extensions that `T` would reject. It is only available in tests and
    /// with the `test-utils` feature.
    pub(crate) fn coerce<T: ExtensionValidator>(self) -> Extensions<T> {
        let Self { unique, .. } = self;
        Extensions {
            unique,
            _object: PhantomData,
        }
    }
}
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use tls_codec::{Deserialize, Serialize, VLBytes};

    use crate::{ciphersuite::HpkePublicKey, extensions::*};

    /// Adding an extension whose type is already present must be rejected
    /// specifically as a duplicate.
    #[test]
    fn add() {
        let mut extensions: Extensions<AnyObject> = Extensions::default();
        extensions
            .add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default(),
            ))
            .unwrap();

        // Pin the error variant rather than accepting any error.
        assert!(matches!(
            extensions.add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default()
            )),
            Err(InvalidExtensionError::Duplicate)
        ));
    }

    /// Both `add` and `try_from` must accept lists with distinct extension
    /// types and reject lists containing a repeated type.
    #[test]
    fn add_try_from() {
        // Create some extensions with different extension types and test that
        // duplicates are rejected. The extension content does not matter in this test.
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        let tests = [
            (vec![], true),
            (vec![ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_y.clone()], true),
            (vec![ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_x.clone(), ext_y.clone()], true),
            (vec![ext_y.clone(), ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_y.clone(), ext_x, ext_y], false),
        ];

        for (test, should_work) in tests {
            // Test `add`: every insertion either succeeds or is rejected as a
            // duplicate. Any other error is a bug and is reported verbatim.
            let mut extensions: Extensions<AnyObject> = Extensions::default();
            let mut works = true;
            for ext in &test {
                match extensions.add(ext.clone()) {
                    Ok(_) => {}
                    Err(InvalidExtensionError::Duplicate) => works = false,
                    Err(other) => panic!("Unexpected error from `add`: {other:?}"),
                }
            }
            assert_eq!(works, should_work, "`add` mismatch for {test:?}");

            // Test `try_from`.
            assert_eq!(
                Extensions::<AnyObject>::try_from(test.clone()).is_ok(),
                should_work,
                "`try_from` mismatch for {test:?}"
            );
        }
    }

    /// Every ordering of distinct extensions must survive a serialize /
    /// deserialize round trip unchanged.
    #[test]
    fn ensure_ordering() {
        // Create some extensions with different extension types and test
        // that all permutations keep their order after being (de)serialized.
        // The extension content does not matter in this test.
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::ExternalPub(ExternalPubExtension::new(HpkePublicKey::new(vec![])));
        let ext_z = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        // `permutations` is already an iterator; no need to collect first.
        for candidate in [ext_x, ext_y, ext_z].into_iter().permutations(3) {
            let candidate: Extensions<AnyObject> = Extensions::try_from(candidate).unwrap();
            let bytes = candidate.tls_serialize_detached().unwrap();
            // `_exact` also fails if deserialization leaves trailing bytes.
            let got = Extensions::tls_deserialize_exact(&bytes).unwrap();
            assert_eq!(candidate, got);
        }
    }

    /// Extensions with unrecognised type codes must deserialize into
    /// `Extension::Unknown` and re-serialize to the exact same bytes.
    #[test]
    fn that_unknown_extensions_are_de_serialized_correctly() {
        let extension_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF100, 0xFFFF];
        let extension_datas = [vec![], vec![0], vec![1, 2, 3]];

        for extension_type in extension_types {
            for extension_data in extension_datas.iter() {
                // Construct an unknown extension manually: a big-endian u16
                // type followed by the variable-length encoded payload.
                let test = {
                    let mut buf = extension_type.to_be_bytes().to_vec();
                    buf.extend(
                        VLBytes::new(extension_data.clone())
                            .tls_serialize_detached()
                            .unwrap(),
                    );
                    buf
                };

                // Test deserialization.
                let got = Extension::tls_deserialize_exact(&test).unwrap();

                match got {
                    Extension::Unknown(got_extension_type, ref got_extension_data) => {
                        assert_eq!(extension_type, got_extension_type);
                        assert_eq!(extension_data, &got_extension_data.0);
                    }
                    other => panic!("Expected `Extension::Unknown`, got {other:?}"),
                }

                // Test serialization.
                let got_serialized = got.tls_serialize_detached().unwrap();
                assert_eq!(test, got_serialized);
            }
        }
    }
}