Skip to main content

openmls/treesync/node/leaf_node/
capabilities.rs

1use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
2use serde::{Deserialize, Serialize};
3use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
4
5#[cfg(doc)]
6use super::LeafNode;
7use crate::{
8    credentials::CredentialType,
9    extensions::{
10        Extension, ExtensionType, ExtensionValidator, Extensions, RequiredCapabilitiesExtension,
11    },
12    messages::proposals::ProposalType,
13    treesync::errors::LeafNodeValidationError,
14    versions::ProtocolVersion,
15};
16
17/// Capabilities of [`LeafNode`]s.
18///
19/// ```text
20/// struct {
21///     ProtocolVersion versions<V>;
22///     CipherSuite ciphersuites<V>;
23///     ExtensionType extensions<V>;
24///     ProposalType proposals<V>;
25///     CredentialType credentials<V>;
26/// } Capabilities;
27/// ```
28#[derive(
29    Debug,
30    Clone,
31    PartialEq,
32    Eq,
33    Serialize,
34    Deserialize,
35    TlsSerialize,
36    TlsDeserialize,
37    TlsDeserializeBytes,
38    TlsSize,
39)]
40pub struct Capabilities {
41    pub(super) versions: Vec<ProtocolVersion>,
42    pub(super) ciphersuites: Vec<VerifiableCiphersuite>,
43    pub(super) extensions: Vec<ExtensionType>,
44    pub(super) proposals: Vec<ProposalType>,
45    pub(super) credentials: Vec<CredentialType>,
46}
47
48impl Capabilities {
49    /// Create a new [`Capabilities`] struct with the given configuration.
50    /// Any argument that is `None` is filled with the default values from the
51    /// global configuration.
52    // TODO(#1232)
53    pub fn new(
54        versions: Option<&[ProtocolVersion]>,
55        ciphersuites: Option<&[Ciphersuite]>,
56        extensions: Option<&[ExtensionType]>,
57        proposals: Option<&[ProposalType]>,
58        credentials: Option<&[CredentialType]>,
59    ) -> Self {
60        Self {
61            versions: match versions {
62                Some(v) => v.into(),
63                None => default_versions(),
64            },
65            ciphersuites: match ciphersuites {
66                Some(c) => c.iter().map(|c| VerifiableCiphersuite::from(*c)).collect(),
67                None => default_ciphersuites()
68                    .into_iter()
69                    .map(VerifiableCiphersuite::from)
70                    .collect(),
71            },
72            extensions: match extensions {
73                Some(e) => e.into(),
74                None => vec![],
75            },
76            proposals: match proposals {
77                Some(p) => p.into(),
78                None => vec![],
79            },
80            credentials: match credentials {
81                Some(c) => c.into(),
82                None => default_credentials(),
83            },
84        }
85    }
86
87    /// Create new empty [`Capabilities`].
88    pub fn empty() -> Self {
89        Self {
90            versions: Vec::new(),
91            ciphersuites: Vec::new(),
92            extensions: Vec::new(),
93            proposals: Vec::new(),
94            credentials: Vec::new(),
95        }
96    }
97
98    /// Creates a new [`CapabilitiesBuilder`] for constructing [`Capabilities`]
99    pub fn builder() -> CapabilitiesBuilder {
100        CapabilitiesBuilder(Self::default())
101    }
102
103    // ---------------------------------------------------------------------------------------------
104
105    /// Get a reference to the list of versions in this extension.
106    pub fn versions(&self) -> &[ProtocolVersion] {
107        &self.versions
108    }
109
110    /// Get a reference to the list of ciphersuites in this extension.
111    pub fn ciphersuites(&self) -> &[VerifiableCiphersuite] {
112        &self.ciphersuites
113    }
114
115    /// Get a reference to the list of supported extensions.
116    pub fn extensions(&self) -> &[ExtensionType] {
117        &self.extensions
118    }
119
120    /// Get a reference to the list of supported proposals.
121    pub fn proposals(&self) -> &[ProposalType] {
122        &self.proposals
123    }
124
125    /// Get a reference to the list of supported credential types.
126    pub fn credentials(&self) -> &[CredentialType] {
127        &self.credentials
128    }
129
130    // ---------------------------------------------------------------------------------------------
131
132    /// Check if these [`Capabilities`] support all the capabilities required by
133    /// the given [`RequiredCapabilitiesExtension`].
134    ///
135    /// # Errors
136    ///
137    /// Returns a [`LeafNodeValidationError`] error if any of the required
138    /// capabilities is not supported.
139    pub(crate) fn supports_required_capabilities(
140        &self,
141        required_capabilities: &RequiredCapabilitiesExtension,
142    ) -> Result<(), LeafNodeValidationError> {
143        // Check if all required extensions are supported.
144        let unsupported_extension_types = required_capabilities
145            .extension_types()
146            .iter()
147            .filter(|&e| !self.contains_extension(*e))
148            .collect::<Vec<_>>();
149        if !unsupported_extension_types.is_empty() {
150            log::error!(
151                "Leaf node does not support all required extension types\n
152                Supported extensions: {:?}\n
153                Required extensions: {:?}",
154                self.extensions(),
155                required_capabilities.extension_types()
156            );
157            return Err(LeafNodeValidationError::UnsupportedExtensions);
158        }
159        // Check if all required proposals are supported.
160        if required_capabilities
161            .proposal_types()
162            .iter()
163            .any(|p| !self.contains_proposal(*p))
164        {
165            return Err(LeafNodeValidationError::UnsupportedProposals);
166        }
167        // Check if all required credential types are supported.
168        if required_capabilities
169            .credential_types()
170            .iter()
171            .any(|c| !self.contains_credential(*c))
172        {
173            return Err(LeafNodeValidationError::UnsupportedCredentials);
174        }
175        Ok(())
176    }
177
178    /// Check if these [`Capabilities`] contain all the extensions.
179    pub(crate) fn contains_extensions(
180        &self,
181        extensions: &Extensions<impl ExtensionValidator>,
182    ) -> bool {
183        extensions
184            .iter()
185            .map(Extension::extension_type)
186            .all(|e| e.is_default() || self.extensions().contains(&e))
187    }
188
189    /// Check if these [`Capabilities`] contains the credential.
190    pub(crate) fn contains_credential(&self, credential_type: CredentialType) -> bool {
191        self.credentials().contains(&credential_type)
192    }
193
194    /// Check if these [`Capabilities`] contain the extension.
195    pub(crate) fn contains_extension(&self, extension_type: ExtensionType) -> bool {
196        extension_type.is_default() || self.extensions().contains(&extension_type)
197    }
198
199    /// Check if these [`Capabilities`] contain the proposal.
200    pub(crate) fn contains_proposal(&self, proposal_type: ProposalType) -> bool {
201        proposal_type.is_default() || self.proposals().contains(&proposal_type)
202    }
203
204    /// Check if these [`Capabilities`] contain the version.
205    pub(crate) fn contains_version(&self, version: ProtocolVersion) -> bool {
206        self.versions().contains(&version)
207    }
208
209    /// Check if these [`Capabilities`] contain the ciphersuite.
210    pub(crate) fn contains_ciphersuite(&self, ciphersuite: VerifiableCiphersuite) -> bool {
211        self.ciphersuites().contains(&ciphersuite)
212    }
213
214    /// Add random GREASE values to the capabilities to ensure extensibility.
215    ///
216    /// This adds one random GREASE value to each capability list if no GREASE
217    /// value is already present:
218    /// - Ciphersuites
219    /// - Extensions
220    /// - Proposals
221    /// - Credentials
222    ///
223    /// GREASE values are used per [RFC 9420 Section 13.5](https://www.rfc-editor.org/rfc/rfc9420.html#section-13.5)
224    /// to help prevent extensibility failures by ensuring implementations properly
225    /// handle unknown values.
226    ///
227    /// # Example
228    ///
229    /// ```
230    /// use openmls::prelude::*;
231    /// use openmls_rust_crypto::OpenMlsRustCrypto;
232    ///
233    /// let provider = OpenMlsRustCrypto::default();
234    ///
235    /// // Create capabilities with GREASE values injected
236    /// let capabilities = Capabilities::builder()
237    ///     .build()
238    ///     .with_grease(provider.rand());
239    ///
240    /// // Verify GREASE values were added
241    /// assert!(capabilities.ciphersuites().iter().any(|cs| cs.is_grease()));
242    /// assert!(capabilities.extensions().iter().any(|ext| ext.is_grease()));
243    /// assert!(capabilities.proposals().iter().any(|prop| prop.is_grease()));
244    /// assert!(capabilities.credentials().iter().any(|cred| cred.is_grease()));
245    /// ```
246    pub fn with_grease(mut self, rand: &impl openmls_traits::random::OpenMlsRand) -> Self {
247        use crate::credentials::CredentialType;
248        use crate::extensions::ExtensionType;
249        use crate::messages::proposals::ProposalType;
250        use openmls_traits::types::VerifiableCiphersuite;
251
252        // Add GREASE ciphersuite if none present
253        if !self.ciphersuites.iter().any(|cs| cs.is_grease()) {
254            let grease_cs = VerifiableCiphersuite::new(crate::grease::random_grease_value(rand));
255            self.ciphersuites.push(grease_cs);
256        }
257
258        // Add GREASE extension if none present
259        if !self.extensions.iter().any(|ext| ext.is_grease()) {
260            let grease_ext = ExtensionType::Grease(crate::grease::random_grease_value(rand));
261            self.extensions.push(grease_ext);
262        }
263
264        // Add GREASE proposal if none present
265        if !self.proposals.iter().any(|prop| prop.is_grease()) {
266            let grease_prop = ProposalType::Grease(crate::grease::random_grease_value(rand));
267            self.proposals.push(grease_prop);
268        }
269
270        // Add GREASE credential if none present
271        if !self.credentials.iter().any(|cred| cred.is_grease()) {
272            let grease_cred = CredentialType::Grease(crate::grease::random_grease_value(rand));
273            self.credentials.push(grease_cred);
274        }
275
276        self
277    }
278}
279
280/// A helper for building [`Capabilities`]
281#[derive(Debug, Clone)]
282pub struct CapabilitiesBuilder(Capabilities);
283
284impl CapabilitiesBuilder {
285    /// Sets the `versions` field on the [`Capabilities`].
286    pub fn versions(self, versions: Vec<ProtocolVersion>) -> Self {
287        Self(Capabilities { versions, ..self.0 })
288    }
289
290    /// Sets the `ciphersuites` field on the [`Capabilities`].
291    pub fn ciphersuites(self, ciphersuites: Vec<Ciphersuite>) -> Self {
292        let ciphersuites = ciphersuites.into_iter().map(|cs| cs.into()).collect();
293
294        Self(Capabilities {
295            ciphersuites,
296            ..self.0
297        })
298    }
299
300    /// Sets the `extensions` field on the [`Capabilities`].
301    pub fn extensions(self, extensions: Vec<ExtensionType>) -> Self {
302        Self(Capabilities {
303            extensions,
304            ..self.0
305        })
306    }
307
308    /// Sets the `proposals` field on the [`Capabilities`].
309    pub fn proposals(self, proposals: Vec<ProposalType>) -> Self {
310        Self(Capabilities {
311            proposals,
312            ..self.0
313        })
314    }
315
316    /// Sets the `credentials` field on the [`Capabilities`].
317    pub fn credentials(self, credentials: Vec<CredentialType>) -> Self {
318        Self(Capabilities {
319            credentials,
320            ..self.0
321        })
322    }
323
324    /// Adds random GREASE values to the capabilities being built.
325    ///
326    /// This is a convenience method that calls [`Capabilities::with_grease`] on the
327    /// built capabilities. See that method for more details.
328    ///
329    /// # Example
330    ///
331    /// ```
332    /// use openmls::prelude::*;
333    /// use openmls_rust_crypto::OpenMlsRustCrypto;
334    ///
335    /// let provider = OpenMlsRustCrypto::default();
336    ///
337    /// let capabilities = Capabilities::builder()
338    ///     .with_grease(provider.rand())
339    ///     .build();
340    ///
341    /// // GREASE values were added
342    /// assert!(capabilities.ciphersuites().iter().any(|cs| cs.is_grease()));
343    /// ```
344    pub fn with_grease(self, rand: &impl openmls_traits::random::OpenMlsRand) -> Self {
345        Self(self.0.with_grease(rand))
346    }
347
348    /// Builds the [`Capabilities`].
349    pub fn build(self) -> Capabilities {
350        self.0
351    }
352}
353
#[cfg(test)]
impl Capabilities {
    /// Overwrite the advertised protocol versions (test-only helper).
    pub fn set_versions(&mut self, versions: Vec<ProtocolVersion>) {
        let Self { versions: slot, .. } = self;
        *slot = versions;
    }

    /// Overwrite the advertised ciphersuites (test-only helper).
    pub fn set_ciphersuites(&mut self, ciphersuites: Vec<VerifiableCiphersuite>) {
        let Self {
            ciphersuites: slot, ..
        } = self;
        *slot = ciphersuites;
    }
}
366
367impl Default for Capabilities {
368    fn default() -> Self {
369        Capabilities {
370            versions: default_versions(),
371            ciphersuites: default_ciphersuites()
372                .into_iter()
373                .map(VerifiableCiphersuite::from)
374                .collect(),
375            extensions: vec![],
376            proposals: vec![],
377            credentials: default_credentials(),
378        }
379    }
380}
381
382pub(super) fn default_versions() -> Vec<ProtocolVersion> {
383    vec![ProtocolVersion::Mls10]
384}
385
386pub(super) fn default_ciphersuites() -> Vec<Ciphersuite> {
387    vec![
388        Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
389        Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
390        Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519,
391        Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519,
392    ]
393}
394
395// TODO(#1231)
396pub(super) fn default_credentials() -> Vec<CredentialType> {
397    vec![CredentialType::Basic]
398}
399
#[cfg(test)]
mod tests {
    use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
    use tls_codec::{Deserialize, Serialize};

    use super::Capabilities;
    use crate::{
        credentials::CredentialType, messages::proposals::ProposalType, prelude::ExtensionType,
        versions::ProtocolVersion,
    };

    /// Unknown codepoints must survive a TLS serialize/deserialize round-trip.
    ///
    /// Every "unknown" value below deliberately avoids the GREASE pattern
    /// `0x_A_A` (RFC 9420 Section 13.5). Those codepoints are reserved for
    /// GREASE and may be decoded into dedicated `Grease` variants instead of
    /// the unknown/custom variants constructed here, which would break the
    /// equality check.
    #[test]
    fn that_unknown_capabilities_are_de_serialized_correctly() {
        let versions = vec![ProtocolVersion::Mls10, ProtocolVersion::Other(999)];
        let ciphersuites = vec![
            Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519.into(),
            Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256.into(),
            Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519.into(),
            Ciphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448.into(),
            Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521.into(),
            Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448.into(),
            Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384.into(),
            VerifiableCiphersuite::new(0x0000),
            // Use non-GREASE values (GREASE pattern is 0x_A_A)
            VerifiableCiphersuite::new(0x0B0B),
            VerifiableCiphersuite::new(0x7C7C),
            VerifiableCiphersuite::new(0xF000),
            VerifiableCiphersuite::new(0xFFFF),
        ];

        let extensions = vec![
            ExtensionType::Unknown(0x0000),
            // Use a non-GREASE value (0xFAFA would match the GREASE pattern)
            ExtensionType::Unknown(0xFBFB),
        ];

        // Use non-GREASE values
        let proposals = vec![ProposalType::Custom(0x7C7C)];

        let credentials = vec![
            CredentialType::Basic,
            CredentialType::X509,
            CredentialType::Other(0x0000),
            // Use non-GREASE values
            CredentialType::Other(0x7C7C),
            CredentialType::Other(0xFFFF),
        ];

        let expected = Capabilities {
            versions,
            ciphersuites,
            extensions,
            proposals,
            credentials,
        };

        let test_serialized = expected.tls_serialize_detached().unwrap();

        let got = Capabilities::tls_deserialize_exact(test_serialized).unwrap();

        assert_eq!(expected, got);
    }
}