// openmls/treesync/node/leaf_node/capabilities.rs

1use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
2use serde::{Deserialize, Serialize};
3use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
4
5#[cfg(doc)]
6use super::LeafNode;
7use crate::{
8    credentials::CredentialType,
9    extensions::{Extension, ExtensionType, Extensions, RequiredCapabilitiesExtension},
10    messages::proposals::ProposalType,
11    treesync::errors::LeafNodeValidationError,
12    versions::ProtocolVersion,
13};
14
15/// Capabilities of [`LeafNode`]s.
16///
17/// ```text
18/// struct {
19///     ProtocolVersion versions<V>;
20///     CipherSuite ciphersuites<V>;
21///     ExtensionType extensions<V>;
22///     ProposalType proposals<V>;
23///     CredentialType credentials<V>;
24/// } Capabilities;
25/// ```
26#[derive(
27    Debug,
28    Clone,
29    PartialEq,
30    Eq,
31    Serialize,
32    Deserialize,
33    TlsSerialize,
34    TlsDeserialize,
35    TlsDeserializeBytes,
36    TlsSize,
37)]
38pub struct Capabilities {
39    pub(super) versions: Vec<ProtocolVersion>,
40    pub(super) ciphersuites: Vec<VerifiableCiphersuite>,
41    pub(super) extensions: Vec<ExtensionType>,
42    pub(super) proposals: Vec<ProposalType>,
43    pub(super) credentials: Vec<CredentialType>,
44}
45
46impl Capabilities {
47    /// Create a new [`Capabilities`] struct with the given configuration.
48    /// Any argument that is `None` is filled with the default values from the
49    /// global configuration.
50    // TODO(#1232)
51    pub fn new(
52        versions: Option<&[ProtocolVersion]>,
53        ciphersuites: Option<&[Ciphersuite]>,
54        extensions: Option<&[ExtensionType]>,
55        proposals: Option<&[ProposalType]>,
56        credentials: Option<&[CredentialType]>,
57    ) -> Self {
58        Self {
59            versions: match versions {
60                Some(v) => v.into(),
61                None => default_versions(),
62            },
63            ciphersuites: match ciphersuites {
64                Some(c) => c.iter().map(|c| VerifiableCiphersuite::from(*c)).collect(),
65                None => default_ciphersuites()
66                    .into_iter()
67                    .map(VerifiableCiphersuite::from)
68                    .collect(),
69            },
70            extensions: match extensions {
71                Some(e) => e.into(),
72                None => vec![],
73            },
74            proposals: match proposals {
75                Some(p) => p.into(),
76                None => vec![],
77            },
78            credentials: match credentials {
79                Some(c) => c.into(),
80                None => default_credentials(),
81            },
82        }
83    }
84
85    /// Create new empty [`Capabilities`].
86    pub fn empty() -> Self {
87        Self {
88            versions: Vec::new(),
89            ciphersuites: Vec::new(),
90            extensions: Vec::new(),
91            proposals: Vec::new(),
92            credentials: Vec::new(),
93        }
94    }
95
96    /// Creates a new [`CapabilitiesBuilder`] for constructing [`Capabilities`]
97    pub fn builder() -> CapabilitiesBuilder {
98        CapabilitiesBuilder(Self::default())
99    }
100
101    // ---------------------------------------------------------------------------------------------
102
103    /// Get a reference to the list of versions in this extension.
104    pub fn versions(&self) -> &[ProtocolVersion] {
105        &self.versions
106    }
107
108    /// Get a reference to the list of ciphersuites in this extension.
109    pub fn ciphersuites(&self) -> &[VerifiableCiphersuite] {
110        &self.ciphersuites
111    }
112
113    /// Get a reference to the list of supported extensions.
114    pub fn extensions(&self) -> &[ExtensionType] {
115        &self.extensions
116    }
117
118    /// Get a reference to the list of supported proposals.
119    pub fn proposals(&self) -> &[ProposalType] {
120        &self.proposals
121    }
122
123    /// Get a reference to the list of supported credential types.
124    pub fn credentials(&self) -> &[CredentialType] {
125        &self.credentials
126    }
127
128    // ---------------------------------------------------------------------------------------------
129
130    /// Check if these [`Capabilities`] support all the capabilities required by
131    /// the given [`RequiredCapabilitiesExtension`].
132    ///
133    /// # Errors
134    ///
135    /// Returns a [`LeafNodeValidationError`] error if any of the required
136    /// capabilities is not supported.
137    pub(crate) fn supports_required_capabilities(
138        &self,
139        required_capabilities: &RequiredCapabilitiesExtension,
140    ) -> Result<(), LeafNodeValidationError> {
141        // Check if all required extensions are supported.
142        let unsupported_extension_types = required_capabilities
143            .extension_types()
144            .iter()
145            .filter(|&e| !self.contains_extension(*e))
146            .collect::<Vec<_>>();
147        if !unsupported_extension_types.is_empty() {
148            log::error!(
149                "Leaf node does not support all required extension types\n
150                Supported extensions: {:?}\n
151                Required extensions: {:?}",
152                self.extensions(),
153                required_capabilities.extension_types()
154            );
155            return Err(LeafNodeValidationError::UnsupportedExtensions);
156        }
157        // Check if all required proposals are supported.
158        if required_capabilities
159            .proposal_types()
160            .iter()
161            .any(|p| !self.contains_proposal(*p))
162        {
163            return Err(LeafNodeValidationError::UnsupportedProposals);
164        }
165        // Check if all required credential types are supported.
166        if required_capabilities
167            .credential_types()
168            .iter()
169            .any(|c| !self.contains_credential(*c))
170        {
171            return Err(LeafNodeValidationError::UnsupportedCredentials);
172        }
173        Ok(())
174    }
175
176    /// Check if these [`Capabilities`] contain all the extensions.
177    pub(crate) fn contains_extensions(&self, extension: &Extensions) -> bool {
178        extension
179            .iter()
180            .map(Extension::extension_type)
181            .all(|e| e.is_default() || self.extensions().contains(&e))
182    }
183
184    /// Check if these [`Capabilities`] contains the credential.
185    pub(crate) fn contains_credential(&self, credential_type: CredentialType) -> bool {
186        self.credentials().contains(&credential_type)
187    }
188
189    /// Check if these [`Capabilities`] contain the extension.
190    pub(crate) fn contains_extension(&self, extension_type: ExtensionType) -> bool {
191        extension_type.is_default() || self.extensions().contains(&extension_type)
192    }
193
194    /// Check if these [`Capabilities`] contain the proposal.
195    pub(crate) fn contains_proposal(&self, proposal_type: ProposalType) -> bool {
196        proposal_type.is_default() || self.proposals().contains(&proposal_type)
197    }
198
199    /// Check if these [`Capabilities`] contain the version.
200    pub(crate) fn contains_version(&self, version: ProtocolVersion) -> bool {
201        self.versions().contains(&version)
202    }
203
204    /// Check if these [`Capabilities`] contain the ciphersuite.
205    pub(crate) fn contains_ciphersuite(&self, ciphersuite: VerifiableCiphersuite) -> bool {
206        self.ciphersuites().contains(&ciphersuite)
207    }
208}
209
210/// A helper for building [`Capabilities`]
211#[derive(Debug, Clone)]
212pub struct CapabilitiesBuilder(Capabilities);
213
214impl CapabilitiesBuilder {
215    /// Sets the `versions` field on the [`Capabilities`].
216    pub fn versions(self, versions: Vec<ProtocolVersion>) -> Self {
217        Self(Capabilities { versions, ..self.0 })
218    }
219
220    /// Sets the `ciphersuites` field on the [`Capabilities`].
221    pub fn ciphersuites(self, ciphersuites: Vec<Ciphersuite>) -> Self {
222        let ciphersuites = ciphersuites.into_iter().map(|cs| cs.into()).collect();
223
224        Self(Capabilities {
225            ciphersuites,
226            ..self.0
227        })
228    }
229
230    /// Sets the `extensions` field on the [`Capabilities`].
231    pub fn extensions(self, extensions: Vec<ExtensionType>) -> Self {
232        Self(Capabilities {
233            extensions,
234            ..self.0
235        })
236    }
237
238    /// Sets the `proposals` field on the [`Capabilities`].
239    pub fn proposals(self, proposals: Vec<ProposalType>) -> Self {
240        Self(Capabilities {
241            proposals,
242            ..self.0
243        })
244    }
245
246    /// Sets the `credentials` field on the [`Capabilities`].
247    pub fn credentials(self, credentials: Vec<CredentialType>) -> Self {
248        Self(Capabilities {
249            credentials,
250            ..self.0
251        })
252    }
253
254    /// Builds the [`Capabilities`].
255    pub fn build(self) -> Capabilities {
256        self.0
257    }
258}
259
#[cfg(test)]
impl Capabilities {
    /// Test-only: replace the supported protocol versions.
    pub fn set_versions(&mut self, versions: Vec<ProtocolVersion>) {
        let Self {
            versions: current, ..
        } = self;
        *current = versions;
    }

    /// Test-only: replace the supported ciphersuites.
    pub fn set_ciphersuites(&mut self, ciphersuites: Vec<VerifiableCiphersuite>) {
        let Self {
            ciphersuites: current,
            ..
        } = self;
        *current = ciphersuites;
    }
}
272
273impl Default for Capabilities {
274    fn default() -> Self {
275        Capabilities {
276            versions: default_versions(),
277            ciphersuites: default_ciphersuites()
278                .into_iter()
279                .map(VerifiableCiphersuite::from)
280                .collect(),
281            extensions: vec![],
282            proposals: vec![],
283            credentials: default_credentials(),
284        }
285    }
286}
287
288pub(super) fn default_versions() -> Vec<ProtocolVersion> {
289    vec![ProtocolVersion::Mls10]
290}
291
292pub(super) fn default_ciphersuites() -> Vec<Ciphersuite> {
293    vec![
294        Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
295        Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
296        Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519,
297        Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519,
298    ]
299}
300
301// TODO(#1231)
302pub(super) fn default_credentials() -> Vec<CredentialType> {
303    vec![CredentialType::Basic]
304}
305
#[cfg(test)]
mod tests {
    use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
    use tls_codec::{Deserialize, Serialize};

    use super::Capabilities;
    use crate::{
        credentials::CredentialType, messages::proposals::ProposalType, prelude::ExtensionType,
        versions::ProtocolVersion,
    };

    /// Unknown or unassigned code points in every list must survive a TLS
    /// serialization round trip unchanged.
    #[test]
    fn that_unknown_capabilities_are_de_serialized_correctly() {
        let known_suites = [
            Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
            Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
            Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519,
            Ciphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448,
            Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521,
            Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448,
            Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384,
        ];
        let raw_suites = [0x0000, 0x0A0A, 0x7A7A, 0xF000, 0xFFFF];

        // Known suites first, then the raw code points, in their listed order.
        let ciphersuites: Vec<VerifiableCiphersuite> = known_suites
            .iter()
            .copied()
            .map(VerifiableCiphersuite::from)
            .chain(raw_suites.iter().copied().map(VerifiableCiphersuite::new))
            .collect();

        let original = Capabilities {
            versions: vec![ProtocolVersion::Mls10, ProtocolVersion::Other(999)],
            ciphersuites,
            extensions: vec![ExtensionType::Unknown(0x0000), ExtensionType::Unknown(0xFAFA)],
            proposals: vec![ProposalType::Custom(0x7A7A)],
            credentials: vec![
                CredentialType::Basic,
                CredentialType::X509,
                CredentialType::Other(0x0000),
                CredentialType::Other(0x7A7A),
                CredentialType::Other(0xFFFF),
            ],
        };

        let encoded = original.tls_serialize_detached().unwrap();
        let decoded = Capabilities::tls_deserialize_exact(encoded).unwrap();

        assert_eq!(original, decoded);
    }
}