// openmls/treesync/node/leaf_node/capabilities.rs

1use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
2use serde::{Deserialize, Serialize};
3use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
4
5#[cfg(doc)]
6use super::LeafNode;
7use crate::{
8    credentials::CredentialType,
9    extensions::{Extension, ExtensionType, Extensions, RequiredCapabilitiesExtension},
10    messages::proposals::ProposalType,
11    treesync::errors::LeafNodeValidationError,
12    versions::ProtocolVersion,
13};
14
15/// Capabilities of [`LeafNode`]s.
16///
17/// ```text
18/// struct {
19///     ProtocolVersion versions<V>;
20///     CipherSuite ciphersuites<V>;
21///     ExtensionType extensions<V>;
22///     ProposalType proposals<V>;
23///     CredentialType credentials<V>;
24/// } Capabilities;
25/// ```
26#[derive(
27    Debug,
28    Clone,
29    PartialEq,
30    Eq,
31    Serialize,
32    Deserialize,
33    TlsSerialize,
34    TlsDeserialize,
35    TlsDeserializeBytes,
36    TlsSize,
37)]
38pub struct Capabilities {
39    pub(super) versions: Vec<ProtocolVersion>,
40    pub(super) ciphersuites: Vec<VerifiableCiphersuite>,
41    pub(super) extensions: Vec<ExtensionType>,
42    pub(super) proposals: Vec<ProposalType>,
43    pub(super) credentials: Vec<CredentialType>,
44}
45
46impl Capabilities {
47    /// Create a new [`Capabilities`] struct with the given configuration.
48    /// Any argument that is `None` is filled with the default values from the
49    /// global configuration.
50    // TODO(#1232)
51    pub fn new(
52        versions: Option<&[ProtocolVersion]>,
53        ciphersuites: Option<&[Ciphersuite]>,
54        extensions: Option<&[ExtensionType]>,
55        proposals: Option<&[ProposalType]>,
56        credentials: Option<&[CredentialType]>,
57    ) -> Self {
58        Self {
59            versions: match versions {
60                Some(v) => v.into(),
61                None => default_versions(),
62            },
63            ciphersuites: match ciphersuites {
64                Some(c) => c.iter().map(|c| VerifiableCiphersuite::from(*c)).collect(),
65                None => default_ciphersuites()
66                    .into_iter()
67                    .map(VerifiableCiphersuite::from)
68                    .collect(),
69            },
70            extensions: match extensions {
71                Some(e) => e.into(),
72                None => vec![],
73            },
74            proposals: match proposals {
75                Some(p) => p.into(),
76                None => vec![],
77            },
78            credentials: match credentials {
79                Some(c) => c.into(),
80                None => default_credentials(),
81            },
82        }
83    }
84
85    /// Create new empty [`Capabilities`].
86    pub fn empty() -> Self {
87        Self {
88            versions: Vec::new(),
89            ciphersuites: Vec::new(),
90            extensions: Vec::new(),
91            proposals: Vec::new(),
92            credentials: Vec::new(),
93        }
94    }
95
96    /// Creates a new [`CapabilitiesBuilder`] for constructing [`Capabilities`]
97    pub fn builder() -> CapabilitiesBuilder {
98        CapabilitiesBuilder(Self::default())
99    }
100
101    // ---------------------------------------------------------------------------------------------
102
103    /// Get a reference to the list of versions in this extension.
104    pub fn versions(&self) -> &[ProtocolVersion] {
105        &self.versions
106    }
107
108    /// Get a reference to the list of ciphersuites in this extension.
109    pub fn ciphersuites(&self) -> &[VerifiableCiphersuite] {
110        &self.ciphersuites
111    }
112
113    /// Get a reference to the list of supported extensions.
114    pub fn extensions(&self) -> &[ExtensionType] {
115        &self.extensions
116    }
117
118    /// Get a reference to the list of supported proposals.
119    pub fn proposals(&self) -> &[ProposalType] {
120        &self.proposals
121    }
122
123    /// Get a reference to the list of supported credential types.
124    pub fn credentials(&self) -> &[CredentialType] {
125        &self.credentials
126    }
127
128    // ---------------------------------------------------------------------------------------------
129
130    /// Check if these [`Capabilities`] support all the capabilities required by
131    /// the given [`RequiredCapabilitiesExtension`].
132    ///
133    /// # Errors
134    ///
135    /// Returns a [`LeafNodeValidationError`] error if any of the required
136    /// capabilities is not supported.
137    pub(crate) fn supports_required_capabilities(
138        &self,
139        required_capabilities: &RequiredCapabilitiesExtension,
140    ) -> Result<(), LeafNodeValidationError> {
141        // Check if all required extensions are supported.
142        let unsupported_extension_types = required_capabilities
143            .extension_types()
144            .iter()
145            .filter(|&e| !self.contains_extension(*e))
146            .collect::<Vec<_>>();
147        if !unsupported_extension_types.is_empty() {
148            log::error!(
149                "Leaf node does not support all required extension types\n
150                Supported extensions: {:?}\n
151                Required extensions: {:?}",
152                self.extensions(),
153                required_capabilities.extension_types()
154            );
155            return Err(LeafNodeValidationError::UnsupportedExtensions);
156        }
157        // Check if all required proposals are supported.
158        if required_capabilities
159            .proposal_types()
160            .iter()
161            .any(|p| !self.contains_proposal(*p))
162        {
163            return Err(LeafNodeValidationError::UnsupportedProposals);
164        }
165        // Check if all required credential types are supported.
166        if required_capabilities
167            .credential_types()
168            .iter()
169            .any(|c| !self.contains_credential(*c))
170        {
171            return Err(LeafNodeValidationError::UnsupportedCredentials);
172        }
173        Ok(())
174    }
175
176    /// Check if these [`Capabilities`] contain all the extensions.
177    pub(crate) fn contains_extensions(&self, extension: &Extensions) -> bool {
178        extension
179            .iter()
180            .map(Extension::extension_type)
181            .all(|e| e.is_default() || self.extensions().contains(&e))
182    }
183
184    /// Check if these [`Capabilities`] contains the credential.
185    pub(crate) fn contains_credential(&self, credential_type: CredentialType) -> bool {
186        self.credentials().contains(&credential_type)
187    }
188
189    /// Check if these [`Capabilities`] contain the extension.
190    pub(crate) fn contains_extension(&self, extension_type: ExtensionType) -> bool {
191        extension_type.is_default() || self.extensions().contains(&extension_type)
192    }
193
194    /// Check if these [`Capabilities`] contain the proposal.
195    pub(crate) fn contains_proposal(&self, proposal_type: ProposalType) -> bool {
196        proposal_type.is_default() || self.proposals().contains(&proposal_type)
197    }
198
199    /// Check if these [`Capabilities`] contain the version.
200    pub(crate) fn contains_version(&self, version: ProtocolVersion) -> bool {
201        self.versions().contains(&version)
202    }
203
204    /// Check if these [`Capabilities`] contain the ciphersuite.
205    pub(crate) fn contains_ciphersuite(&self, ciphersuite: VerifiableCiphersuite) -> bool {
206        self.ciphersuites().contains(&ciphersuite)
207    }
208
209    /// Add random GREASE values to the capabilities to ensure extensibility.
210    ///
211    /// This adds one random GREASE value to each capability list if no GREASE
212    /// value is already present:
213    /// - Ciphersuites
214    /// - Extensions
215    /// - Proposals
216    /// - Credentials
217    ///
218    /// GREASE values are used per [RFC 9420 Section 13.5](https://www.rfc-editor.org/rfc/rfc9420.html#section-13.5)
219    /// to help prevent extensibility failures by ensuring implementations properly
220    /// handle unknown values.
221    ///
222    /// # Example
223    ///
224    /// ```
225    /// use openmls::prelude::*;
226    /// use openmls_rust_crypto::OpenMlsRustCrypto;
227    ///
228    /// let provider = OpenMlsRustCrypto::default();
229    ///
230    /// // Create capabilities with GREASE values injected
231    /// let capabilities = Capabilities::builder()
232    ///     .build()
233    ///     .with_grease(provider.rand());
234    ///
235    /// // Verify GREASE values were added
236    /// assert!(capabilities.ciphersuites().iter().any(|cs| cs.is_grease()));
237    /// assert!(capabilities.extensions().iter().any(|ext| ext.is_grease()));
238    /// assert!(capabilities.proposals().iter().any(|prop| prop.is_grease()));
239    /// assert!(capabilities.credentials().iter().any(|cred| cred.is_grease()));
240    /// ```
241    pub fn with_grease(mut self, rand: &impl openmls_traits::random::OpenMlsRand) -> Self {
242        use crate::credentials::CredentialType;
243        use crate::extensions::ExtensionType;
244        use crate::messages::proposals::ProposalType;
245        use openmls_traits::types::VerifiableCiphersuite;
246
247        // Add GREASE ciphersuite if none present
248        if !self.ciphersuites.iter().any(|cs| cs.is_grease()) {
249            let grease_cs = VerifiableCiphersuite::new(crate::grease::random_grease_value(rand));
250            self.ciphersuites.push(grease_cs);
251        }
252
253        // Add GREASE extension if none present
254        if !self.extensions.iter().any(|ext| ext.is_grease()) {
255            let grease_ext = ExtensionType::Grease(crate::grease::random_grease_value(rand));
256            self.extensions.push(grease_ext);
257        }
258
259        // Add GREASE proposal if none present
260        if !self.proposals.iter().any(|prop| prop.is_grease()) {
261            let grease_prop = ProposalType::Grease(crate::grease::random_grease_value(rand));
262            self.proposals.push(grease_prop);
263        }
264
265        // Add GREASE credential if none present
266        if !self.credentials.iter().any(|cred| cred.is_grease()) {
267            let grease_cred = CredentialType::Grease(crate::grease::random_grease_value(rand));
268            self.credentials.push(grease_cred);
269        }
270
271        self
272    }
273}
274
275/// A helper for building [`Capabilities`]
276#[derive(Debug, Clone)]
277pub struct CapabilitiesBuilder(Capabilities);
278
279impl CapabilitiesBuilder {
280    /// Sets the `versions` field on the [`Capabilities`].
281    pub fn versions(self, versions: Vec<ProtocolVersion>) -> Self {
282        Self(Capabilities { versions, ..self.0 })
283    }
284
285    /// Sets the `ciphersuites` field on the [`Capabilities`].
286    pub fn ciphersuites(self, ciphersuites: Vec<Ciphersuite>) -> Self {
287        let ciphersuites = ciphersuites.into_iter().map(|cs| cs.into()).collect();
288
289        Self(Capabilities {
290            ciphersuites,
291            ..self.0
292        })
293    }
294
295    /// Sets the `extensions` field on the [`Capabilities`].
296    pub fn extensions(self, extensions: Vec<ExtensionType>) -> Self {
297        Self(Capabilities {
298            extensions,
299            ..self.0
300        })
301    }
302
303    /// Sets the `proposals` field on the [`Capabilities`].
304    pub fn proposals(self, proposals: Vec<ProposalType>) -> Self {
305        Self(Capabilities {
306            proposals,
307            ..self.0
308        })
309    }
310
311    /// Sets the `credentials` field on the [`Capabilities`].
312    pub fn credentials(self, credentials: Vec<CredentialType>) -> Self {
313        Self(Capabilities {
314            credentials,
315            ..self.0
316        })
317    }
318
319    /// Adds random GREASE values to the capabilities being built.
320    ///
321    /// This is a convenience method that calls [`Capabilities::with_grease`] on the
322    /// built capabilities. See that method for more details.
323    ///
324    /// # Example
325    ///
326    /// ```
327    /// use openmls::prelude::*;
328    /// use openmls_rust_crypto::OpenMlsRustCrypto;
329    ///
330    /// let provider = OpenMlsRustCrypto::default();
331    ///
332    /// let capabilities = Capabilities::builder()
333    ///     .with_grease(provider.rand())
334    ///     .build();
335    ///
336    /// // GREASE values were added
337    /// assert!(capabilities.ciphersuites().iter().any(|cs| cs.is_grease()));
338    /// ```
339    pub fn with_grease(self, rand: &impl openmls_traits::random::OpenMlsRand) -> Self {
340        Self(self.0.with_grease(rand))
341    }
342
343    /// Builds the [`Capabilities`].
344    pub fn build(self) -> Capabilities {
345        self.0
346    }
347}
348
#[cfg(test)]
impl Capabilities {
    /// Test-only helper: overwrite the advertised ciphersuites.
    pub fn set_ciphersuites(&mut self, ciphersuites: Vec<VerifiableCiphersuite>) {
        self.ciphersuites = ciphersuites;
    }

    /// Test-only helper: overwrite the advertised protocol versions.
    pub fn set_versions(&mut self, versions: Vec<ProtocolVersion>) {
        self.versions = versions;
    }
}
361
362impl Default for Capabilities {
363    fn default() -> Self {
364        Capabilities {
365            versions: default_versions(),
366            ciphersuites: default_ciphersuites()
367                .into_iter()
368                .map(VerifiableCiphersuite::from)
369                .collect(),
370            extensions: vec![],
371            proposals: vec![],
372            credentials: default_credentials(),
373        }
374    }
375}
376
377pub(super) fn default_versions() -> Vec<ProtocolVersion> {
378    vec![ProtocolVersion::Mls10]
379}
380
381pub(super) fn default_ciphersuites() -> Vec<Ciphersuite> {
382    vec![
383        Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
384        Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
385        Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519,
386        Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519,
387    ]
388}
389
390// TODO(#1231)
391pub(super) fn default_credentials() -> Vec<CredentialType> {
392    vec![CredentialType::Basic]
393}
394
#[cfg(test)]
mod tests {
    use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
    use tls_codec::{Deserialize, Serialize};

    use super::Capabilities;
    use crate::{
        credentials::CredentialType, messages::proposals::ProposalType, prelude::ExtensionType,
        versions::ProtocolVersion,
    };

    /// A [`Capabilities`] value that mixes registered and unregistered code
    /// points must survive a TLS serialize/deserialize round trip unchanged.
    ///
    /// The raw code points below deliberately avoid the GREASE values
    /// reserved by RFC 9420, Section 13.5 (0x0A0A, 0x1A1A, ..., 0xEAEA).
    #[test]
    fn that_unknown_capabilities_are_de_serialized_correctly() {
        let registered_ciphersuites = [
            Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
            Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
            Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519,
            Ciphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448,
            Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521,
            Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448,
            Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384,
        ];
        // Unregistered, non-GREASE ciphersuite code points.
        let raw_ciphersuites = [0x0000, 0x0B0B, 0x7C7C, 0xF000, 0xFFFF];

        let expected = Capabilities {
            versions: vec![ProtocolVersion::Mls10, ProtocolVersion::Other(999)],
            ciphersuites: registered_ciphersuites
                .iter()
                .copied()
                .map(VerifiableCiphersuite::from)
                .chain(
                    raw_ciphersuites
                        .iter()
                        .copied()
                        .map(VerifiableCiphersuite::new),
                )
                .collect(),
            extensions: vec![
                ExtensionType::Unknown(0x0000),
                ExtensionType::Unknown(0xFAFA),
            ],
            proposals: vec![ProposalType::Custom(0x7C7C)],
            credentials: vec![
                CredentialType::Basic,
                CredentialType::X509,
                CredentialType::Other(0x0000),
                CredentialType::Other(0x7C7C),
                CredentialType::Other(0xFFFF),
            ],
        };

        let encoded = expected.tls_serialize_detached().unwrap();
        let decoded = Capabilities::tls_deserialize_exact(encoded).unwrap();

        assert_eq!(expected, decoded);
    }
}