openmls/treesync/node/leaf_node/
capabilities.rs1use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
2use serde::{Deserialize, Serialize};
3use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
4
5#[cfg(doc)]
6use super::LeafNode;
7use crate::{
8 credentials::CredentialType,
9 extensions::{Extension, ExtensionType, Extensions, RequiredCapabilitiesExtension},
10 messages::proposals::ProposalType,
11 treesync::errors::LeafNodeValidationError,
12 versions::ProtocolVersion,
13};
14
15#[derive(
27 Debug,
28 Clone,
29 PartialEq,
30 Eq,
31 Serialize,
32 Deserialize,
33 TlsSerialize,
34 TlsDeserialize,
35 TlsDeserializeBytes,
36 TlsSize,
37)]
38pub struct Capabilities {
39 pub(super) versions: Vec<ProtocolVersion>,
40 pub(super) ciphersuites: Vec<VerifiableCiphersuite>,
41 pub(super) extensions: Vec<ExtensionType>,
42 pub(super) proposals: Vec<ProposalType>,
43 pub(super) credentials: Vec<CredentialType>,
44}
45
46impl Capabilities {
47 pub fn new(
52 versions: Option<&[ProtocolVersion]>,
53 ciphersuites: Option<&[Ciphersuite]>,
54 extensions: Option<&[ExtensionType]>,
55 proposals: Option<&[ProposalType]>,
56 credentials: Option<&[CredentialType]>,
57 ) -> Self {
58 Self {
59 versions: match versions {
60 Some(v) => v.into(),
61 None => default_versions(),
62 },
63 ciphersuites: match ciphersuites {
64 Some(c) => c.iter().map(|c| VerifiableCiphersuite::from(*c)).collect(),
65 None => default_ciphersuites()
66 .into_iter()
67 .map(VerifiableCiphersuite::from)
68 .collect(),
69 },
70 extensions: match extensions {
71 Some(e) => e.into(),
72 None => vec![],
73 },
74 proposals: match proposals {
75 Some(p) => p.into(),
76 None => vec![],
77 },
78 credentials: match credentials {
79 Some(c) => c.into(),
80 None => default_credentials(),
81 },
82 }
83 }
84
85 pub fn empty() -> Self {
87 Self {
88 versions: Vec::new(),
89 ciphersuites: Vec::new(),
90 extensions: Vec::new(),
91 proposals: Vec::new(),
92 credentials: Vec::new(),
93 }
94 }
95
96 pub fn builder() -> CapabilitiesBuilder {
98 CapabilitiesBuilder(Self::default())
99 }
100
101 pub fn versions(&self) -> &[ProtocolVersion] {
105 &self.versions
106 }
107
108 pub fn ciphersuites(&self) -> &[VerifiableCiphersuite] {
110 &self.ciphersuites
111 }
112
113 pub fn extensions(&self) -> &[ExtensionType] {
115 &self.extensions
116 }
117
118 pub fn proposals(&self) -> &[ProposalType] {
120 &self.proposals
121 }
122
123 pub fn credentials(&self) -> &[CredentialType] {
125 &self.credentials
126 }
127
128 pub(crate) fn supports_required_capabilities(
138 &self,
139 required_capabilities: &RequiredCapabilitiesExtension,
140 ) -> Result<(), LeafNodeValidationError> {
141 let unsupported_extension_types = required_capabilities
143 .extension_types()
144 .iter()
145 .filter(|&e| !self.contains_extension(*e))
146 .collect::<Vec<_>>();
147 if !unsupported_extension_types.is_empty() {
148 log::error!(
149 "Leaf node does not support all required extension types\n
150 Supported extensions: {:?}\n
151 Required extensions: {:?}",
152 self.extensions(),
153 required_capabilities.extension_types()
154 );
155 return Err(LeafNodeValidationError::UnsupportedExtensions);
156 }
157 if required_capabilities
159 .proposal_types()
160 .iter()
161 .any(|p| !self.contains_proposal(*p))
162 {
163 return Err(LeafNodeValidationError::UnsupportedProposals);
164 }
165 if required_capabilities
167 .credential_types()
168 .iter()
169 .any(|c| !self.contains_credential(*c))
170 {
171 return Err(LeafNodeValidationError::UnsupportedCredentials);
172 }
173 Ok(())
174 }
175
176 pub(crate) fn contains_extensions(&self, extension: &Extensions) -> bool {
178 extension
179 .iter()
180 .map(Extension::extension_type)
181 .all(|e| e.is_default() || self.extensions().contains(&e))
182 }
183
184 pub(crate) fn contains_credential(&self, credential_type: CredentialType) -> bool {
186 self.credentials().contains(&credential_type)
187 }
188
189 pub(crate) fn contains_extension(&self, extension_type: ExtensionType) -> bool {
191 extension_type.is_default() || self.extensions().contains(&extension_type)
192 }
193
194 pub(crate) fn contains_proposal(&self, proposal_type: ProposalType) -> bool {
196 proposal_type.is_default() || self.proposals().contains(&proposal_type)
197 }
198
199 pub(crate) fn contains_version(&self, version: ProtocolVersion) -> bool {
201 self.versions().contains(&version)
202 }
203
204 pub(crate) fn contains_ciphersuite(&self, ciphersuite: VerifiableCiphersuite) -> bool {
206 self.ciphersuites().contains(&ciphersuite)
207 }
208
209 pub fn with_grease(mut self, rand: &impl openmls_traits::random::OpenMlsRand) -> Self {
242 use crate::credentials::CredentialType;
243 use crate::extensions::ExtensionType;
244 use crate::messages::proposals::ProposalType;
245 use openmls_traits::types::VerifiableCiphersuite;
246
247 if !self.ciphersuites.iter().any(|cs| cs.is_grease()) {
249 let grease_cs = VerifiableCiphersuite::new(crate::grease::random_grease_value(rand));
250 self.ciphersuites.push(grease_cs);
251 }
252
253 if !self.extensions.iter().any(|ext| ext.is_grease()) {
255 let grease_ext = ExtensionType::Grease(crate::grease::random_grease_value(rand));
256 self.extensions.push(grease_ext);
257 }
258
259 if !self.proposals.iter().any(|prop| prop.is_grease()) {
261 let grease_prop = ProposalType::Grease(crate::grease::random_grease_value(rand));
262 self.proposals.push(grease_prop);
263 }
264
265 if !self.credentials.iter().any(|cred| cred.is_grease()) {
267 let grease_cred = CredentialType::Grease(crate::grease::random_grease_value(rand));
268 self.credentials.push(grease_cred);
269 }
270
271 self
272 }
273}
274
275#[derive(Debug, Clone)]
277pub struct CapabilitiesBuilder(Capabilities);
278
279impl CapabilitiesBuilder {
280 pub fn versions(self, versions: Vec<ProtocolVersion>) -> Self {
282 Self(Capabilities { versions, ..self.0 })
283 }
284
285 pub fn ciphersuites(self, ciphersuites: Vec<Ciphersuite>) -> Self {
287 let ciphersuites = ciphersuites.into_iter().map(|cs| cs.into()).collect();
288
289 Self(Capabilities {
290 ciphersuites,
291 ..self.0
292 })
293 }
294
295 pub fn extensions(self, extensions: Vec<ExtensionType>) -> Self {
297 Self(Capabilities {
298 extensions,
299 ..self.0
300 })
301 }
302
303 pub fn proposals(self, proposals: Vec<ProposalType>) -> Self {
305 Self(Capabilities {
306 proposals,
307 ..self.0
308 })
309 }
310
311 pub fn credentials(self, credentials: Vec<CredentialType>) -> Self {
313 Self(Capabilities {
314 credentials,
315 ..self.0
316 })
317 }
318
319 pub fn with_grease(self, rand: &impl openmls_traits::random::OpenMlsRand) -> Self {
340 Self(self.0.with_grease(rand))
341 }
342
343 pub fn build(self) -> Capabilities {
345 self.0
346 }
347}
348
#[cfg(test)]
impl Capabilities {
    /// Test-only: overwrites the protocol versions.
    pub fn set_versions(&mut self, versions: Vec<ProtocolVersion>) {
        self.versions = versions;
    }

    /// Test-only: overwrites the ciphersuites.
    pub fn set_ciphersuites(&mut self, ciphersuites: Vec<VerifiableCiphersuite>) {
        self.ciphersuites = ciphersuites;
    }
}
361
362impl Default for Capabilities {
363 fn default() -> Self {
364 Capabilities {
365 versions: default_versions(),
366 ciphersuites: default_ciphersuites()
367 .into_iter()
368 .map(VerifiableCiphersuite::from)
369 .collect(),
370 extensions: vec![],
371 proposals: vec![],
372 credentials: default_credentials(),
373 }
374 }
375}
376
377pub(super) fn default_versions() -> Vec<ProtocolVersion> {
378 vec![ProtocolVersion::Mls10]
379}
380
381pub(super) fn default_ciphersuites() -> Vec<Ciphersuite> {
382 vec![
383 Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
384 Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
385 Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519,
386 Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519,
387 ]
388}
389
390pub(super) fn default_credentials() -> Vec<CredentialType> {
392 vec![CredentialType::Basic]
393}
394
#[cfg(test)]
mod tests {
    use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
    use tls_codec::{Deserialize, Serialize};

    use super::Capabilities;
    use crate::{
        credentials::CredentialType, messages::proposals::ProposalType, prelude::ExtensionType,
        versions::ProtocolVersion,
    };

    /// Round-trips a `Capabilities` value that mixes named codepoints with
    /// raw/unknown ones through the TLS codec and checks that decoding yields
    /// exactly the value that was encoded.
    #[test]
    fn that_unknown_capabilities_are_de_serialized_correctly() {
        // Named ciphersuites first, then raw numeric codepoints built directly
        // via `VerifiableCiphersuite::new`.
        let named_suites = [
            Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
            Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
            Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519,
            Ciphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448,
            Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521,
            Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448,
            Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384,
        ];
        let raw_suites = [0x0000, 0x0B0B, 0x7C7C, 0xF000, 0xFFFF];
        let ciphersuites: Vec<VerifiableCiphersuite> = named_suites
            .iter()
            .copied()
            .map(VerifiableCiphersuite::from)
            .chain(raw_suites.iter().copied().map(VerifiableCiphersuite::new))
            .collect();

        let original = Capabilities {
            versions: vec![ProtocolVersion::Mls10, ProtocolVersion::Other(999)],
            ciphersuites,
            extensions: vec![
                ExtensionType::Unknown(0x0000),
                ExtensionType::Unknown(0xFAFA),
            ],
            proposals: vec![ProposalType::Custom(0x7C7C)],
            credentials: vec![
                CredentialType::Basic,
                CredentialType::X509,
                CredentialType::Other(0x0000),
                CredentialType::Other(0x7C7C),
                CredentialType::Other(0xFFFF),
            ],
        };

        let encoded = original.tls_serialize_detached().unwrap();
        let decoded = Capabilities::tls_deserialize_exact(encoded).unwrap();

        assert_eq!(original, decoded);
    }
}