// openmls/extensions/mod.rs

1//! # Extensions
2//!
3//! In MLS, extensions appear in the following places:
4//!
5//! - In [`KeyPackages`](`crate::key_packages`), to describe client capabilities
6//!   and aspects of their participation in the group.
7//!
8//! - In `GroupInfo`, to inform new members of the group's parameters and to
9//!   provide any additional information required to join the group.
10//!
11//! - In the `GroupContext` object, to ensure that all members of the group have
12//!   a consistent view of the parameters in use.
13//!
14//! Note that `GroupInfo` and `GroupContext` are not exposed via OpenMLS' public
15//! API.
16//!
//! OpenMLS supports the following extensions:
//!
//! - [`ApplicationIdExtension`] (KeyPackage extension)
//! - [`RatchetTreeExtension`] (GroupInfo extension)
//! - [`RequiredCapabilitiesExtension`] (GroupContext extension)
//! - [`ExternalPubExtension`] (GroupInfo extension)
//! - [`ExternalSendersExtension`] (GroupContext extension)
//! - [`LastResortExtension`] (KeyPackage extension)
23
24use std::{
25    fmt::Debug,
26    io::{Read, Write},
27};
28
29use serde::{Deserialize, Serialize};
30
31// Private
32mod application_id_extension;
33mod codec;
34mod external_pub_extension;
35mod external_sender_extension;
36mod last_resort;
37mod ratchet_tree_extension;
38mod required_capabilities;
39use errors::*;
40
41// Public
42pub mod errors;
43
44// Public re-exports
45pub use application_id_extension::ApplicationIdExtension;
46pub use external_pub_extension::ExternalPubExtension;
47pub use external_sender_extension::{
48    ExternalSender, ExternalSendersExtension, SenderExtensionIndex,
49};
50pub use last_resort::LastResortExtension;
51pub use ratchet_tree_extension::RatchetTreeExtension;
52pub use required_capabilities::RequiredCapabilitiesExtension;
53use tls_codec::{
54    Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
55    Size, TlsSize,
56};
57
58#[cfg(test)]
59mod tests;
60
61/// MLS Extension Types
62///
63/// Copied from draft-ietf-mls-protocol-16:
64///
65/// | Value            | Name                     | Message(s) | Recommended | Reference |
66/// |:-----------------|:-------------------------|:-----------|:------------|:----------|
67/// | 0x0000           | RESERVED                 | N/A        | N/A         | RFC XXXX  |
68/// | 0x0001           | application_id           | LN         | Y           | RFC XXXX  |
69/// | 0x0002           | ratchet_tree             | GI         | Y           | RFC XXXX  |
70/// | 0x0003           | required_capabilities    | GC         | Y           | RFC XXXX  |
71/// | 0x0004           | external_pub             | GI         | Y           | RFC XXXX  |
72/// | 0x0005           | external_senders         | GC         | Y           | RFC XXXX  |
73/// | 0xff00  - 0xffff | Reserved for Private Use | N/A        | N/A         | RFC XXXX  |
74///
75/// Note: OpenMLS does not provide a `Reserved` variant in [ExtensionType].
76#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)]
77pub enum ExtensionType {
78    /// The application id extension allows applications to add an explicit,
79    /// application-defined identifier to a KeyPackage.
80    ApplicationId,
81
82    /// The ratchet tree extensions provides the whole public state of the
83    /// ratchet tree.
84    RatchetTree,
85
86    /// The required capabilities extension defines the configuration of a group
87    /// that imposes certain requirements on clients in the group.
88    RequiredCapabilities,
89
90    /// To join a group via an External Commit, a new member needs a GroupInfo
91    /// with an ExternalPub extension present in its extensions field.
92    ExternalPub,
93
94    /// Group context extension that contains the credentials and signature keys
95    /// of senders that are permitted to send external proposals to the group.
96    ExternalSenders,
97
98    /// KeyPackage extension that marks a KeyPackage for use in a last resort
99    /// scenario.
100    LastResort,
101
102    /// A currently unknown extension type.
103    Unknown(u16),
104}
105
106impl ExtensionType {
107    /// Returns true for all extension types that are considered "default" by the spec.
108    pub(crate) fn is_default(self) -> bool {
109        match self {
110            ExtensionType::ApplicationId
111            | ExtensionType::RatchetTree
112            | ExtensionType::RequiredCapabilities
113            | ExtensionType::ExternalPub
114            | ExtensionType::ExternalSenders => true,
115            ExtensionType::LastResort | ExtensionType::Unknown(_) => false,
116        }
117    }
118
119    /// Returns whether an extension type is valid when used in leaf nodes.
120    /// Returns None if validity can not be determined.
121    /// This is the case for unknown extensions.
122    //  https://validation.openmls.tech/#valn1601
123    pub(crate) fn is_valid_in_leaf_node(self) -> Option<bool> {
124        match self {
125            ExtensionType::LastResort
126            | ExtensionType::RatchetTree
127            | ExtensionType::RequiredCapabilities
128            | ExtensionType::ExternalPub
129            | ExtensionType::ExternalSenders => Some(false),
130            ExtensionType::ApplicationId => Some(true),
131            ExtensionType::Unknown(_) => None,
132        }
133    }
134}
135
136impl Size for ExtensionType {
137    fn tls_serialized_len(&self) -> usize {
138        2
139    }
140}
141
142impl TlsDeserializeTrait for ExtensionType {
143    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
144    where
145        Self: Sized,
146    {
147        let mut extension_type = [0u8; 2];
148        bytes.read_exact(&mut extension_type)?;
149
150        Ok(ExtensionType::from(u16::from_be_bytes(extension_type)))
151    }
152}
153
154impl DeserializeBytes for ExtensionType {
155    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
156    where
157        Self: Sized,
158    {
159        let mut bytes_ref = bytes;
160        let extension_type = ExtensionType::tls_deserialize(&mut bytes_ref)?;
161        let remainder = &bytes[extension_type.tls_serialized_len()..];
162        Ok((extension_type, remainder))
163    }
164}
165
166impl TlsSerializeTrait for ExtensionType {
167    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
168        writer.write_all(&u16::from(*self).to_be_bytes())?;
169
170        Ok(2)
171    }
172}
173
174impl From<u16> for ExtensionType {
175    fn from(a: u16) -> Self {
176        match a {
177            1 => ExtensionType::ApplicationId,
178            2 => ExtensionType::RatchetTree,
179            3 => ExtensionType::RequiredCapabilities,
180            4 => ExtensionType::ExternalPub,
181            5 => ExtensionType::ExternalSenders,
182            10 => ExtensionType::LastResort,
183            unknown => ExtensionType::Unknown(unknown),
184        }
185    }
186}
187
188impl From<ExtensionType> for u16 {
189    fn from(value: ExtensionType) -> Self {
190        match value {
191            ExtensionType::ApplicationId => 1,
192            ExtensionType::RatchetTree => 2,
193            ExtensionType::RequiredCapabilities => 3,
194            ExtensionType::ExternalPub => 4,
195            ExtensionType::ExternalSenders => 5,
196            ExtensionType::LastResort => 10,
197            ExtensionType::Unknown(unknown) => unknown,
198        }
199    }
200}
201
202/// # Extension
203///
204/// An extension is one of the [`Extension`] enum values.
205/// The enum provides a set of common functionality for all extensions.
206///
207/// See the individual extensions for more details on each extension.
208///
209/// ```c
210/// // draft-ietf-mls-protocol-16
211/// struct {
212///     ExtensionType extension_type;
213///     opaque extension_data<V>;
214/// } Extension;
215/// ```
216#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
217pub enum Extension {
218    /// An [`ApplicationIdExtension`]
219    ApplicationId(ApplicationIdExtension),
220
221    /// A [`RatchetTreeExtension`]
222    RatchetTree(RatchetTreeExtension),
223
224    /// A [`RequiredCapabilitiesExtension`]
225    RequiredCapabilities(RequiredCapabilitiesExtension),
226
227    /// An [`ExternalPubExtension`]
228    ExternalPub(ExternalPubExtension),
229
230    /// An [`ExternalSendersExtension`]
231    ExternalSenders(ExternalSendersExtension),
232
233    /// A [`LastResortExtension`]
234    LastResort(LastResortExtension),
235
236    /// A currently unknown extension.
237    Unknown(u16, UnknownExtension),
238}
239
240/// A unknown/unparsed extension represented by raw bytes.
241#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize)]
242pub struct UnknownExtension(pub Vec<u8>);
243
244/// A list of extensions with unique extension types.
245#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TlsSize)]
246pub struct Extensions {
247    unique: Vec<Extension>,
248}
249
250impl TlsSerializeTrait for Extensions {
251    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
252        self.unique.tls_serialize(writer)
253    }
254}
255
256impl TlsDeserializeTrait for Extensions {
257    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
258    where
259        Self: Sized,
260    {
261        let candidate: Vec<Extension> = Vec::tls_deserialize(bytes)?;
262        Extensions::try_from(candidate)
263            .map_err(|_| Error::DecodingError("Found duplicate extensions".into()))
264    }
265}
266
267impl DeserializeBytes for Extensions {
268    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
269    where
270        Self: Sized,
271    {
272        let mut bytes_ref = bytes;
273        let extensions = Extensions::tls_deserialize(&mut bytes_ref)?;
274        let remainder = &bytes[extensions.tls_serialized_len()..];
275        Ok((extensions, remainder))
276    }
277}
278
279impl Extensions {
280    /// Create an empty extension list.
281    pub fn empty() -> Self {
282        Self { unique: vec![] }
283    }
284
285    /// Create an extension list with a single extension.
286    pub fn single(extension: Extension) -> Self {
287        Self {
288            unique: vec![extension],
289        }
290    }
291
292    /// Create an extension list with multiple extensions.
293    ///
294    /// This function will fail when the list of extensions contains duplicate
295    /// extension types.
296    pub fn from_vec(extensions: Vec<Extension>) -> Result<Self, InvalidExtensionError> {
297        extensions.try_into()
298    }
299
300    /// Returns an iterator over the extension list.
301    pub fn iter(&self) -> impl Iterator<Item = &Extension> {
302        self.unique.iter()
303    }
304
305    /// Add an extension to the extension list.
306    ///
307    /// Returns an error when there already is an extension with the same
308    /// extension type.
309    pub fn add(&mut self, extension: Extension) -> Result<(), InvalidExtensionError> {
310        if self.contains(extension.extension_type()) {
311            return Err(InvalidExtensionError::Duplicate);
312        }
313
314        self.unique.push(extension);
315
316        Ok(())
317    }
318
319    /// Add an extension to the extension list (or replace an existing one.)
320    ///
321    /// Returns the replaced extension (if any).
322    pub fn add_or_replace(&mut self, extension: Extension) -> Option<Extension> {
323        let replaced = self.remove(extension.extension_type());
324        self.unique.push(extension);
325        replaced
326    }
327
328    /// Remove an extension from the extension list.
329    ///
330    /// Returns the removed extension or `None` when there is no extension with
331    /// the given extension type.
332    pub fn remove(&mut self, extension_type: ExtensionType) -> Option<Extension> {
333        if let Some(pos) = self
334            .unique
335            .iter()
336            .position(|ext| ext.extension_type() == extension_type)
337        {
338            Some(self.unique.remove(pos))
339        } else {
340            None
341        }
342    }
343
344    /// Returns `true` iff the extension list contains an extension with the
345    /// given extension type.
346    pub fn contains(&self, extension_type: ExtensionType) -> bool {
347        self.unique
348            .iter()
349            .any(|ext| ext.extension_type() == extension_type)
350    }
351
352    // validate that all extensions can be added to a leaf node.
353    // https://validation.openmls.tech/#valn1601
354    pub(crate) fn validate_extension_types_for_leaf_node(
355        &self,
356    ) -> Result<(), InvalidExtensionError> {
357        for extension_type in self.unique.iter().map(Extension::extension_type) {
358            // also allow unknown extensions, which return `None` here
359            if extension_type.is_valid_in_leaf_node() == Some(false) {
360                return Err(InvalidExtensionError::IllegalInLeafNodes);
361            }
362        }
363        Ok(())
364    }
365}
366
367impl TryFrom<Vec<Extension>> for Extensions {
368    type Error = InvalidExtensionError;
369
370    fn try_from(candidate: Vec<Extension>) -> Result<Self, Self::Error> {
371        let mut unique: Vec<Extension> = Vec::new();
372
373        for extension in candidate.into_iter() {
374            if unique
375                .iter()
376                .any(|ext| ext.extension_type() == extension.extension_type())
377            {
378                return Err(InvalidExtensionError::Duplicate);
379            } else {
380                unique.push(extension);
381            }
382        }
383
384        Ok(Self { unique })
385    }
386}
387
388impl Extensions {
389    fn find_by_type(&self, extension_type: ExtensionType) -> Option<&Extension> {
390        self.unique
391            .iter()
392            .find(|ext| ext.extension_type() == extension_type)
393    }
394
395    /// Get a reference to the [`ApplicationIdExtension`] if there is any.
396    pub fn application_id(&self) -> Option<&ApplicationIdExtension> {
397        self.find_by_type(ExtensionType::ApplicationId)
398            .and_then(|e| match e {
399                Extension::ApplicationId(e) => Some(e),
400                _ => None,
401            })
402    }
403
404    /// Get a reference to the [`RatchetTreeExtension`] if there is any.
405    pub fn ratchet_tree(&self) -> Option<&RatchetTreeExtension> {
406        self.find_by_type(ExtensionType::RatchetTree)
407            .and_then(|e| match e {
408                Extension::RatchetTree(e) => Some(e),
409                _ => None,
410            })
411    }
412
413    /// Get a reference to the [`RequiredCapabilitiesExtension`] if there is
414    /// any.
415    pub fn required_capabilities(&self) -> Option<&RequiredCapabilitiesExtension> {
416        self.find_by_type(ExtensionType::RequiredCapabilities)
417            .and_then(|e| match e {
418                Extension::RequiredCapabilities(e) => Some(e),
419                _ => None,
420            })
421    }
422
423    /// Get a reference to the [`ExternalPubExtension`] if there is any.
424    pub fn external_pub(&self) -> Option<&ExternalPubExtension> {
425        self.find_by_type(ExtensionType::ExternalPub)
426            .and_then(|e| match e {
427                Extension::ExternalPub(e) => Some(e),
428                _ => None,
429            })
430    }
431
432    /// Get a reference to the [`ExternalSendersExtension`] if there is any.
433    pub fn external_senders(&self) -> Option<&ExternalSendersExtension> {
434        self.find_by_type(ExtensionType::ExternalSenders)
435            .and_then(|e| match e {
436                Extension::ExternalSenders(e) => Some(e),
437                _ => None,
438            })
439    }
440
441    /// Get a reference to the [`UnknownExtension`] with the given type id, if there is any.
442    pub fn unknown(&self, extension_type_id: u16) -> Option<&UnknownExtension> {
443        let extension_type: ExtensionType = extension_type_id.into();
444
445        match extension_type {
446            ExtensionType::Unknown(_) => self.find_by_type(extension_type).and_then(|e| match e {
447                Extension::Unknown(_, e) => Some(e),
448                _ => None,
449            }),
450            _ => None,
451        }
452    }
453}
454
455impl Extension {
456    /// Get a reference to this extension as [`ApplicationIdExtension`].
457    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on an
458    /// [`Extension`] that's not an [`ApplicationIdExtension`].
459    pub fn as_application_id_extension(&self) -> Result<&ApplicationIdExtension, ExtensionError> {
460        match self {
461            Self::ApplicationId(e) => Ok(e),
462            _ => Err(ExtensionError::InvalidExtensionType(
463                "This is not an ApplicationIdExtension".into(),
464            )),
465        }
466    }
467
468    /// Get a reference to this extension as [`RatchetTreeExtension`].
469    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on
470    /// an [`Extension`] that's not a [`RatchetTreeExtension`].
471    pub fn as_ratchet_tree_extension(&self) -> Result<&RatchetTreeExtension, ExtensionError> {
472        match self {
473            Self::RatchetTree(rte) => Ok(rte),
474            _ => Err(ExtensionError::InvalidExtensionType(
475                "This is not a RatchetTreeExtension".into(),
476            )),
477        }
478    }
479
480    /// Get a reference to this extension as [`RequiredCapabilitiesExtension`].
481    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
482    /// an [`Extension`] that's not a [`RequiredCapabilitiesExtension`].
483    pub fn as_required_capabilities_extension(
484        &self,
485    ) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
486        match self {
487            Self::RequiredCapabilities(e) => Ok(e),
488            _ => Err(ExtensionError::InvalidExtensionType(
489                "This is not a RequiredCapabilitiesExtension".into(),
490            )),
491        }
492    }
493
494    /// Get a reference to this extension as [`ExternalPubExtension`].
495    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
496    /// an [`Extension`] that's not a [`ExternalPubExtension`].
497    pub fn as_external_pub_extension(&self) -> Result<&ExternalPubExtension, ExtensionError> {
498        match self {
499            Self::ExternalPub(e) => Ok(e),
500            _ => Err(ExtensionError::InvalidExtensionType(
501                "This is not an ExternalPubExtension".into(),
502            )),
503        }
504    }
505
506    /// Get a reference to this extension as [`ExternalSendersExtension`].
507    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
508    /// an [`Extension`] that's not a [`ExternalSendersExtension`].
509    pub fn as_external_senders_extension(
510        &self,
511    ) -> Result<&ExternalSendersExtension, ExtensionError> {
512        match self {
513            Self::ExternalSenders(e) => Ok(e),
514            _ => Err(ExtensionError::InvalidExtensionType(
515                "This is not an ExternalSendersExtension".into(),
516            )),
517        }
518    }
519
520    /// Returns the [`ExtensionType`]
521    #[inline]
522    pub const fn extension_type(&self) -> ExtensionType {
523        match self {
524            Extension::ApplicationId(_) => ExtensionType::ApplicationId,
525            Extension::RatchetTree(_) => ExtensionType::RatchetTree,
526            Extension::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
527            Extension::ExternalPub(_) => ExtensionType::ExternalPub,
528            Extension::ExternalSenders(_) => ExtensionType::ExternalSenders,
529            Extension::LastResort(_) => ExtensionType::LastResort,
530            Extension::Unknown(kind, _) => ExtensionType::Unknown(*kind),
531        }
532    }
533}
534
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use tls_codec::{Deserialize, Serialize, VLBytes};

    use crate::{ciphersuite::HpkePublicKey, extensions::*};

    /// Adding a second extension of an already present type must fail.
    #[test]
    fn add() {
        let required_capabilities =
            || Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        let mut extensions = Extensions::default();
        extensions.add(required_capabilities()).unwrap();
        assert!(extensions.add(required_capabilities()).is_err());
    }

    /// `add` and `try_from` must both reject exactly the sequences that
    /// contain a repeated extension type.
    #[test]
    fn add_try_from() {
        // Two extensions of different types; only their types matter here.
        let x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let y = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        // (candidate sequence, whether it is free of duplicate types)
        let cases: Vec<(Vec<Extension>, bool)> = vec![
            (vec![], true),
            (vec![x.clone()], true),
            (vec![x.clone(), x.clone()], false),
            (vec![x.clone(), x.clone(), x.clone()], false),
            (vec![y.clone()], true),
            (vec![y.clone(), y.clone()], false),
            (vec![y.clone(), y.clone(), y.clone()], false),
            (vec![x.clone(), y.clone()], true),
            (vec![y.clone(), x.clone()], true),
            (vec![x.clone(), x.clone(), y.clone()], false),
            (vec![y.clone(), y.clone(), x.clone()], false),
            (vec![x.clone(), y.clone(), y.clone()], false),
            (vec![y.clone(), x.clone(), x.clone()], false),
            (vec![x.clone(), y.clone(), x.clone()], false),
            (vec![y.clone(), x, y], false),
        ];

        for (candidate, expected_unique) in cases {
            // Feed the sequence through `add`, one extension at a time.
            let mut extensions = Extensions::default();
            let mut all_added = true;
            for extension in &candidate {
                match extensions.add(extension.clone()) {
                    Ok(()) => {}
                    Err(InvalidExtensionError::Duplicate) => all_added = false,
                    _ => panic!("This should have never happened."),
                }
            }

            println!("{:?}, {:?}", candidate, expected_unique);
            assert_eq!(all_added, expected_unique);

            // `try_from` must agree with `add`.
            assert_eq!(Extensions::try_from(candidate).is_ok(), expected_unique);
        }
    }

    /// Every permutation of distinct extensions must survive a serialization
    /// round trip with its order intact.
    #[test]
    fn ensure_ordering() {
        // Three extensions of different types; their content is irrelevant.
        let x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let y = Extension::ExternalPub(ExternalPubExtension::new(HpkePublicKey::new(vec![])));
        let z = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        for permutation in [x, y, z].into_iter().permutations(3) {
            let original = Extensions::try_from(permutation).unwrap();
            let encoded = original.tls_serialize_detached().unwrap();
            let decoded = Extensions::tls_deserialize(&mut encoded.as_slice()).unwrap();
            assert_eq!(original, decoded);
        }
    }

    /// Unknown extension types must decode to `Extension::Unknown` with their
    /// raw data, and encode back to the identical bytes.
    #[test]
    fn that_unknown_extensions_are_de_serialized_correctly() {
        let extension_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF100, 0xFFFF];
        let extension_datas: [Vec<u8>; 3] = [vec![], vec![0], vec![1, 2, 3]];

        for extension_type in extension_types {
            for extension_data in &extension_datas {
                // Hand-encode `extension_type || extension_data<V>`.
                let mut encoded = Vec::from(extension_type.to_be_bytes());
                encoded.extend(
                    VLBytes::new(extension_data.clone())
                        .tls_serialize_detached()
                        .unwrap(),
                );

                // Decoding must yield an unknown extension with the raw data...
                let decoded = Extension::tls_deserialize_exact(&encoded).unwrap();
                if let Extension::Unknown(got_extension_type, ref got_extension_data) = decoded {
                    assert_eq!(extension_type, got_extension_type);
                    assert_eq!(extension_data, &got_extension_data.0);
                } else {
                    panic!("Expected `Extension::Unknown`, got {decoded:?}");
                }

                // ...and re-encoding must reproduce the input exactly.
                let reencoded = decoded.tls_serialize_detached().unwrap();
                assert_eq!(encoded, reencoded);
            }
        }
    }
}