openmls/extensions/
mod.rs

//! # Extensions
//!
//! In MLS, extensions appear in the following places:
//!
//! - In [`KeyPackages`](`crate::key_packages`), to describe client capabilities
//!   and aspects of their participation in the group.
//!
//! - In `GroupInfo`, to inform new members of the group's parameters and to
//!   provide any additional information required to join the group.
//!
//! - In the `GroupContext` object, to ensure that all members of the group have
//!   a consistent view of the parameters in use.
//!
//! Note that `GroupInfo` and `GroupContext` are not exposed via OpenMLS' public
//! API.
//!
//! OpenMLS supports the following extensions:
//!
//! - [`ApplicationIdExtension`] (LeafNode extension)
//! - [`RatchetTreeExtension`] (GroupInfo extension)
//! - [`RequiredCapabilitiesExtension`] (GroupContext extension)
//! - [`ExternalPubExtension`] (GroupInfo extension)
//! - [`ExternalSendersExtension`] (GroupContext extension)
//! - [`LastResortExtension`] (KeyPackage extension)
23
24use std::{
25    fmt::Debug,
26    io::{Read, Write},
27};
28
29use serde::{Deserialize, Serialize};
30
31// Private
32mod application_id_extension;
33mod codec;
34mod external_pub_extension;
35mod external_sender_extension;
36mod last_resort;
37mod ratchet_tree_extension;
38mod required_capabilities;
39use errors::*;
40
41// Public
42pub mod errors;
43
44// Public re-exports
45pub use application_id_extension::ApplicationIdExtension;
46pub use external_pub_extension::ExternalPubExtension;
47pub use external_sender_extension::{
48    ExternalSender, ExternalSendersExtension, SenderExtensionIndex,
49};
50pub use last_resort::LastResortExtension;
51pub use ratchet_tree_extension::RatchetTreeExtension;
52pub use required_capabilities::RequiredCapabilitiesExtension;
53use tls_codec::{
54    Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
55    Size, TlsSize,
56};
57
58#[cfg(test)]
59mod tests;
60
61/// MLS Extension Types
62///
63/// Copied from draft-ietf-mls-protocol-16:
64///
65/// | Value            | Name                     | Message(s) | Recommended | Reference |
66/// |:-----------------|:-------------------------|:-----------|:------------|:----------|
67/// | 0x0000           | RESERVED                 | N/A        | N/A         | RFC XXXX  |
68/// | 0x0001           | application_id           | LN         | Y           | RFC XXXX  |
69/// | 0x0002           | ratchet_tree             | GI         | Y           | RFC XXXX  |
70/// | 0x0003           | required_capabilities    | GC         | Y           | RFC XXXX  |
71/// | 0x0004           | external_pub             | GI         | Y           | RFC XXXX  |
72/// | 0x0005           | external_senders         | GC         | Y           | RFC XXXX  |
73/// | 0xff00  - 0xffff | Reserved for Private Use | N/A        | N/A         | RFC XXXX  |
74///
75/// Note: OpenMLS does not provide a `Reserved` variant in [ExtensionType].
76#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)]
77pub enum ExtensionType {
78    /// The application id extension allows applications to add an explicit,
79    /// application-defined identifier to a KeyPackage.
80    ApplicationId,
81
82    /// The ratchet tree extensions provides the whole public state of the
83    /// ratchet tree.
84    RatchetTree,
85
86    /// The required capabilities extension defines the configuration of a group
87    /// that imposes certain requirements on clients in the group.
88    RequiredCapabilities,
89
90    /// To join a group via an External Commit, a new member needs a GroupInfo
91    /// with an ExternalPub extension present in its extensions field.
92    ExternalPub,
93
94    /// Group context extension that contains the credentials and signature keys
95    /// of senders that are permitted to send external proposals to the group.
96    ExternalSenders,
97
98    /// KeyPackage extension that marks a KeyPackage for use in a last resort
99    /// scenario.
100    LastResort,
101
102    /// A currently unknown extension type.
103    Unknown(u16),
104}
105
106impl ExtensionType {
107    /// Returns true for all extension types that are considered "default" by the spec.
108    pub(crate) fn is_default(self) -> bool {
109        match self {
110            ExtensionType::ApplicationId
111            | ExtensionType::RatchetTree
112            | ExtensionType::RequiredCapabilities
113            | ExtensionType::ExternalPub
114            | ExtensionType::ExternalSenders => true,
115            ExtensionType::LastResort | ExtensionType::Unknown(_) => false,
116        }
117    }
118
119    /// Returns whether an extension type is valid when used in leaf nodes.
120    /// Returns None if validity can not be determined.
121    /// This is the case for unknown extensions.
122    pub(crate) fn is_valid_in_leaf_node(self) -> Option<bool> {
123        match self {
124            ExtensionType::LastResort
125            | ExtensionType::RatchetTree
126            | ExtensionType::RequiredCapabilities
127            | ExtensionType::ExternalPub
128            | ExtensionType::ExternalSenders => Some(false),
129            ExtensionType::ApplicationId => Some(true),
130            ExtensionType::Unknown(_) => None,
131        }
132    }
133}
134
135impl Size for ExtensionType {
136    fn tls_serialized_len(&self) -> usize {
137        2
138    }
139}
140
141impl TlsDeserializeTrait for ExtensionType {
142    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
143    where
144        Self: Sized,
145    {
146        let mut extension_type = [0u8; 2];
147        bytes.read_exact(&mut extension_type)?;
148
149        Ok(ExtensionType::from(u16::from_be_bytes(extension_type)))
150    }
151}
152
153impl DeserializeBytes for ExtensionType {
154    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
155    where
156        Self: Sized,
157    {
158        let mut bytes_ref = bytes;
159        let extension_type = ExtensionType::tls_deserialize(&mut bytes_ref)?;
160        let remainder = &bytes[extension_type.tls_serialized_len()..];
161        Ok((extension_type, remainder))
162    }
163}
164
165impl TlsSerializeTrait for ExtensionType {
166    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
167        writer.write_all(&u16::from(*self).to_be_bytes())?;
168
169        Ok(2)
170    }
171}
172
173impl From<u16> for ExtensionType {
174    fn from(a: u16) -> Self {
175        match a {
176            1 => ExtensionType::ApplicationId,
177            2 => ExtensionType::RatchetTree,
178            3 => ExtensionType::RequiredCapabilities,
179            4 => ExtensionType::ExternalPub,
180            5 => ExtensionType::ExternalSenders,
181            10 => ExtensionType::LastResort,
182            unknown => ExtensionType::Unknown(unknown),
183        }
184    }
185}
186
187impl From<ExtensionType> for u16 {
188    fn from(value: ExtensionType) -> Self {
189        match value {
190            ExtensionType::ApplicationId => 1,
191            ExtensionType::RatchetTree => 2,
192            ExtensionType::RequiredCapabilities => 3,
193            ExtensionType::ExternalPub => 4,
194            ExtensionType::ExternalSenders => 5,
195            ExtensionType::LastResort => 10,
196            ExtensionType::Unknown(unknown) => unknown,
197        }
198    }
199}
200
201/// # Extension
202///
203/// An extension is one of the [`Extension`] enum values.
204/// The enum provides a set of common functionality for all extensions.
205///
206/// See the individual extensions for more details on each extension.
207///
208/// ```c
209/// // draft-ietf-mls-protocol-16
210/// struct {
211///     ExtensionType extension_type;
212///     opaque extension_data<V>;
213/// } Extension;
214/// ```
215#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
216pub enum Extension {
217    /// An [`ApplicationIdExtension`]
218    ApplicationId(ApplicationIdExtension),
219
220    /// A [`RatchetTreeExtension`]
221    RatchetTree(RatchetTreeExtension),
222
223    /// A [`RequiredCapabilitiesExtension`]
224    RequiredCapabilities(RequiredCapabilitiesExtension),
225
226    /// An [`ExternalPubExtension`]
227    ExternalPub(ExternalPubExtension),
228
229    /// An [`ExternalSendersExtension`]
230    ExternalSenders(ExternalSendersExtension),
231
232    /// A [`LastResortExtension`]
233    LastResort(LastResortExtension),
234
235    /// A currently unknown extension.
236    Unknown(u16, UnknownExtension),
237}
238
239/// A unknown/unparsed extension represented by raw bytes.
240#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize)]
241pub struct UnknownExtension(pub Vec<u8>);
242
/// A list of extensions with unique extension types.
///
/// No two entries share the same [`ExtensionType`]. The field is private, so
/// outside code has to go through this module's constructors and mutators
/// (`empty`, `single`, `from_vec` / `TryFrom`, `add`, `add_or_replace`) or
/// TLS deserialization, all of which uphold that invariant. Entries keep the
/// order in which they were inserted, and that order is used for TLS
/// serialization.
///
/// NOTE(review): the derived serde `Deserialize` fills `unique` directly and
/// does not re-check uniqueness; only feed it trusted data.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TlsSize)]
pub struct Extensions {
    // Invariant: extension types are pairwise distinct.
    unique: Vec<Extension>,
}
248
249impl TlsSerializeTrait for Extensions {
250    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
251        self.unique.tls_serialize(writer)
252    }
253}
254
255impl TlsDeserializeTrait for Extensions {
256    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
257    where
258        Self: Sized,
259    {
260        let candidate: Vec<Extension> = Vec::tls_deserialize(bytes)?;
261        Extensions::try_from(candidate)
262            .map_err(|_| Error::DecodingError("Found duplicate extensions".into()))
263    }
264}
265
266impl DeserializeBytes for Extensions {
267    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
268    where
269        Self: Sized,
270    {
271        let mut bytes_ref = bytes;
272        let extensions = Extensions::tls_deserialize(&mut bytes_ref)?;
273        let remainder = &bytes[extensions.tls_serialized_len()..];
274        Ok((extensions, remainder))
275    }
276}
277
278impl Extensions {
279    /// Create an empty extension list.
280    pub fn empty() -> Self {
281        Self { unique: vec![] }
282    }
283
284    /// Create an extension list with a single extension.
285    pub fn single(extension: Extension) -> Self {
286        Self {
287            unique: vec![extension],
288        }
289    }
290
291    /// Create an extension list with multiple extensions.
292    ///
293    /// This function will fail when the list of extensions contains duplicate
294    /// extension types.
295    pub fn from_vec(extensions: Vec<Extension>) -> Result<Self, InvalidExtensionError> {
296        extensions.try_into()
297    }
298
299    /// Returns an iterator over the extension list.
300    pub fn iter(&self) -> impl Iterator<Item = &Extension> {
301        self.unique.iter()
302    }
303
304    /// Add an extension to the extension list.
305    ///
306    /// Returns an error when there already is an extension with the same
307    /// extension type.
308    pub fn add(&mut self, extension: Extension) -> Result<(), InvalidExtensionError> {
309        if self.contains(extension.extension_type()) {
310            return Err(InvalidExtensionError::Duplicate);
311        }
312
313        self.unique.push(extension);
314
315        Ok(())
316    }
317
318    /// Add an extension to the extension list (or replace an existing one.)
319    ///
320    /// Returns the replaced extension (if any).
321    pub fn add_or_replace(&mut self, extension: Extension) -> Option<Extension> {
322        let replaced = self.remove(extension.extension_type());
323        self.unique.push(extension);
324        replaced
325    }
326
327    /// Remove an extension from the extension list.
328    ///
329    /// Returns the removed extension or `None` when there is no extension with
330    /// the given extension type.
331    pub fn remove(&mut self, extension_type: ExtensionType) -> Option<Extension> {
332        if let Some(pos) = self
333            .unique
334            .iter()
335            .position(|ext| ext.extension_type() == extension_type)
336        {
337            Some(self.unique.remove(pos))
338        } else {
339            None
340        }
341    }
342
343    /// Returns `true` iff the extension list contains an extension with the
344    /// given extension type.
345    pub fn contains(&self, extension_type: ExtensionType) -> bool {
346        self.unique
347            .iter()
348            .any(|ext| ext.extension_type() == extension_type)
349    }
350}
351
352impl TryFrom<Vec<Extension>> for Extensions {
353    type Error = InvalidExtensionError;
354
355    fn try_from(candidate: Vec<Extension>) -> Result<Self, Self::Error> {
356        let mut unique: Vec<Extension> = Vec::new();
357
358        for extension in candidate.into_iter() {
359            if unique
360                .iter()
361                .any(|ext| ext.extension_type() == extension.extension_type())
362            {
363                return Err(InvalidExtensionError::Duplicate);
364            } else {
365                unique.push(extension);
366            }
367        }
368
369        Ok(Self { unique })
370    }
371}
372
373impl Extensions {
374    fn find_by_type(&self, extension_type: ExtensionType) -> Option<&Extension> {
375        self.unique
376            .iter()
377            .find(|ext| ext.extension_type() == extension_type)
378    }
379
380    /// Get a reference to the [`ApplicationIdExtension`] if there is any.
381    pub fn application_id(&self) -> Option<&ApplicationIdExtension> {
382        self.find_by_type(ExtensionType::ApplicationId)
383            .and_then(|e| match e {
384                Extension::ApplicationId(e) => Some(e),
385                _ => None,
386            })
387    }
388
389    /// Get a reference to the [`RatchetTreeExtension`] if there is any.
390    pub fn ratchet_tree(&self) -> Option<&RatchetTreeExtension> {
391        self.find_by_type(ExtensionType::RatchetTree)
392            .and_then(|e| match e {
393                Extension::RatchetTree(e) => Some(e),
394                _ => None,
395            })
396    }
397
398    /// Get a reference to the [`RequiredCapabilitiesExtension`] if there is
399    /// any.
400    pub fn required_capabilities(&self) -> Option<&RequiredCapabilitiesExtension> {
401        self.find_by_type(ExtensionType::RequiredCapabilities)
402            .and_then(|e| match e {
403                Extension::RequiredCapabilities(e) => Some(e),
404                _ => None,
405            })
406    }
407
408    /// Get a reference to the [`ExternalPubExtension`] if there is any.
409    pub fn external_pub(&self) -> Option<&ExternalPubExtension> {
410        self.find_by_type(ExtensionType::ExternalPub)
411            .and_then(|e| match e {
412                Extension::ExternalPub(e) => Some(e),
413                _ => None,
414            })
415    }
416
417    /// Get a reference to the [`ExternalSendersExtension`] if there is any.
418    pub fn external_senders(&self) -> Option<&ExternalSendersExtension> {
419        self.find_by_type(ExtensionType::ExternalSenders)
420            .and_then(|e| match e {
421                Extension::ExternalSenders(e) => Some(e),
422                _ => None,
423            })
424    }
425
426    /// Get a reference to the [`UnknownExtension`] with the given type id, if there is any.
427    pub fn unknown(&self, extension_type_id: u16) -> Option<&UnknownExtension> {
428        let extension_type: ExtensionType = extension_type_id.into();
429
430        match extension_type {
431            ExtensionType::Unknown(_) => self.find_by_type(extension_type).and_then(|e| match e {
432                Extension::Unknown(_, e) => Some(e),
433                _ => None,
434            }),
435            _ => None,
436        }
437    }
438}
439
440impl Extension {
441    /// Get a reference to this extension as [`ApplicationIdExtension`].
442    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on an
443    /// [`Extension`] that's not an [`ApplicationIdExtension`].
444    pub fn as_application_id_extension(&self) -> Result<&ApplicationIdExtension, ExtensionError> {
445        match self {
446            Self::ApplicationId(e) => Ok(e),
447            _ => Err(ExtensionError::InvalidExtensionType(
448                "This is not an ApplicationIdExtension".into(),
449            )),
450        }
451    }
452
453    /// Get a reference to this extension as [`RatchetTreeExtension`].
454    /// Returns an [`ExtensionError::InvalidExtensionType`] if called on
455    /// an [`Extension`] that's not a [`RatchetTreeExtension`].
456    pub fn as_ratchet_tree_extension(&self) -> Result<&RatchetTreeExtension, ExtensionError> {
457        match self {
458            Self::RatchetTree(rte) => Ok(rte),
459            _ => Err(ExtensionError::InvalidExtensionType(
460                "This is not a RatchetTreeExtension".into(),
461            )),
462        }
463    }
464
465    /// Get a reference to this extension as [`RequiredCapabilitiesExtension`].
466    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
467    /// an [`Extension`] that's not a [`RequiredCapabilitiesExtension`].
468    pub fn as_required_capabilities_extension(
469        &self,
470    ) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
471        match self {
472            Self::RequiredCapabilities(e) => Ok(e),
473            _ => Err(ExtensionError::InvalidExtensionType(
474                "This is not a RequiredCapabilitiesExtension".into(),
475            )),
476        }
477    }
478
479    /// Get a reference to this extension as [`ExternalPubExtension`].
480    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
481    /// an [`Extension`] that's not a [`ExternalPubExtension`].
482    pub fn as_external_pub_extension(&self) -> Result<&ExternalPubExtension, ExtensionError> {
483        match self {
484            Self::ExternalPub(e) => Ok(e),
485            _ => Err(ExtensionError::InvalidExtensionType(
486                "This is not an ExternalPubExtension".into(),
487            )),
488        }
489    }
490
491    /// Get a reference to this extension as [`ExternalSendersExtension`].
492    /// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
493    /// an [`Extension`] that's not a [`ExternalSendersExtension`].
494    pub fn as_external_senders_extension(
495        &self,
496    ) -> Result<&ExternalSendersExtension, ExtensionError> {
497        match self {
498            Self::ExternalSenders(e) => Ok(e),
499            _ => Err(ExtensionError::InvalidExtensionType(
500                "This is not an ExternalSendersExtension".into(),
501            )),
502        }
503    }
504
505    /// Returns the [`ExtensionType`]
506    #[inline]
507    pub const fn extension_type(&self) -> ExtensionType {
508        match self {
509            Extension::ApplicationId(_) => ExtensionType::ApplicationId,
510            Extension::RatchetTree(_) => ExtensionType::RatchetTree,
511            Extension::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
512            Extension::ExternalPub(_) => ExtensionType::ExternalPub,
513            Extension::ExternalSenders(_) => ExtensionType::ExternalSenders,
514            Extension::LastResort(_) => ExtensionType::LastResort,
515            Extension::Unknown(kind, _) => ExtensionType::Unknown(*kind),
516        }
517    }
518}
519
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use tls_codec::{Deserialize, Serialize, VLBytes};

    use crate::{ciphersuite::HpkePublicKey, extensions::*};

    #[test]
    fn add() {
        let mut extensions = Extensions::default();
        extensions
            .add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default(),
            ))
            .unwrap();
        // A second extension of the same type must be rejected.
        assert!(extensions
            .add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default()
            ))
            .is_err());
    }

    #[test]
    fn add_try_from() {
        // Create some extensions with different extension types and test that
        // duplicates are rejected. The extension content does not matter in this test.
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        let tests = [
            (vec![], true),
            (vec![ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_y.clone()], true),
            (vec![ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_x.clone(), ext_y.clone()], true),
            (vec![ext_y.clone(), ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_y.clone(), ext_x, ext_y], false),
        ];

        for (test, should_work) in tests.into_iter() {
            // Test `add`: adding the extensions one by one must report a
            // duplicate exactly when the list contains one.
            {
                let mut extensions = Extensions::default();

                let mut works = true;
                for ext in test.iter() {
                    match extensions.add(ext.clone()) {
                        Ok(_) => {}
                        Err(InvalidExtensionError::Duplicate) => {
                            works = false;
                        }
                        _ => panic!("This should have never happened."),
                    }
                }

                assert_eq!(works, should_work, "`add` mismatch for {:?}", test);
            }

            // Test `try_from`.
            let accepted = Extensions::try_from(test.clone()).is_ok();
            assert_eq!(accepted, should_work, "`try_from` mismatch for {:?}", test);
        }
    }

    #[test]
    fn ensure_ordering() {
        // Create some extensions with different extension types and test
        // that all permutations keep their order after being (de)serialized.
        // The extension content does not matter in this test.
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::ExternalPub(ExternalPubExtension::new(HpkePublicKey::new(vec![])));
        let ext_z = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        for candidate in [ext_x, ext_y, ext_z].into_iter().permutations(3) {
            let candidate: Extensions = Extensions::try_from(candidate).unwrap();
            let bytes = candidate.tls_serialize_detached().unwrap();
            let got = Extensions::tls_deserialize(&mut bytes.as_slice()).unwrap();
            assert_eq!(candidate, got);
        }
    }

    #[test]
    fn duplicate_extensions_are_rejected_when_decoding() {
        let ext = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        // Serialize a raw vector so the uniqueness check of `Extensions` is
        // bypassed on the encoding side.
        let bytes = vec![ext.clone(), ext].tls_serialize_detached().unwrap();
        assert!(Extensions::tls_deserialize(&mut bytes.as_slice()).is_err());
    }

    #[test]
    fn deserialize_bytes_returns_unconsumed_remainder() {
        let extensions = Extensions::single(Extension::ApplicationId(
            ApplicationIdExtension::new(b"Test"),
        ));
        let mut bytes = extensions.tls_serialize_detached().unwrap();
        let trailer = [0xAAu8, 0xBB];
        bytes.extend_from_slice(&trailer);

        let (got, remainder) =
            <Extensions as tls_codec::DeserializeBytes>::tls_deserialize_bytes(&bytes).unwrap();
        assert_eq!(got, extensions);
        assert_eq!(remainder, trailer.as_slice());
    }

    #[test]
    fn that_unknown_extensions_are_de_serialized_correctly() {
        let extension_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF100, 0xFFFF];
        let extension_datas = [vec![], vec![0], vec![1, 2, 3]];

        for extension_type in extension_types.into_iter() {
            for extension_data in extension_datas.iter() {
                // Construct an unknown extension manually.
                let test = {
                    let mut buf = extension_type.to_be_bytes().to_vec();
                    buf.append(
                        &mut VLBytes::new(extension_data.clone())
                            .tls_serialize_detached()
                            .unwrap(),
                    );
                    buf
                };

                // Test deserialization.
                let got = Extension::tls_deserialize_exact(&test).unwrap();

                match got {
                    Extension::Unknown(got_extension_type, ref got_extension_data) => {
                        assert_eq!(extension_type, got_extension_type);
                        assert_eq!(extension_data, &got_extension_data.0);
                    }
                    other => panic!("Expected `Extension::Unknown`, got {:?}", other),
                }

                // Test serialization.
                let got_serialized = got.tls_serialize_detached().unwrap();
                assert_eq!(test, got_serialized);
            }
        }
    }
}