openmls/treesync/node/leaf_node/
capabilities.rs1use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
2use serde::{Deserialize, Serialize};
3use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
4
5#[cfg(doc)]
6use super::LeafNode;
7use crate::{
8 credentials::CredentialType,
9 extensions::{
10 Extension, ExtensionType, ExtensionValidator, Extensions, RequiredCapabilitiesExtension,
11 },
12 messages::proposals::ProposalType,
13 treesync::errors::LeafNodeValidationError,
14 versions::ProtocolVersion,
15};
16
17#[derive(
29 Debug,
30 Clone,
31 PartialEq,
32 Eq,
33 Serialize,
34 Deserialize,
35 TlsSerialize,
36 TlsDeserialize,
37 TlsDeserializeBytes,
38 TlsSize,
39)]
40pub struct Capabilities {
41 pub(super) versions: Vec<ProtocolVersion>,
42 pub(super) ciphersuites: Vec<VerifiableCiphersuite>,
43 pub(super) extensions: Vec<ExtensionType>,
44 pub(super) proposals: Vec<ProposalType>,
45 pub(super) credentials: Vec<CredentialType>,
46}
47
48impl Capabilities {
49 pub fn new(
54 versions: Option<&[ProtocolVersion]>,
55 ciphersuites: Option<&[Ciphersuite]>,
56 extensions: Option<&[ExtensionType]>,
57 proposals: Option<&[ProposalType]>,
58 credentials: Option<&[CredentialType]>,
59 ) -> Self {
60 Self {
61 versions: match versions {
62 Some(v) => v.into(),
63 None => default_versions(),
64 },
65 ciphersuites: match ciphersuites {
66 Some(c) => c.iter().map(|c| VerifiableCiphersuite::from(*c)).collect(),
67 None => default_ciphersuites()
68 .into_iter()
69 .map(VerifiableCiphersuite::from)
70 .collect(),
71 },
72 extensions: match extensions {
73 Some(e) => e.into(),
74 None => vec![],
75 },
76 proposals: match proposals {
77 Some(p) => p.into(),
78 None => vec![],
79 },
80 credentials: match credentials {
81 Some(c) => c.into(),
82 None => default_credentials(),
83 },
84 }
85 }
86
87 pub fn empty() -> Self {
89 Self {
90 versions: Vec::new(),
91 ciphersuites: Vec::new(),
92 extensions: Vec::new(),
93 proposals: Vec::new(),
94 credentials: Vec::new(),
95 }
96 }
97
98 pub fn builder() -> CapabilitiesBuilder {
100 CapabilitiesBuilder(Self::default())
101 }
102
103 pub fn versions(&self) -> &[ProtocolVersion] {
107 &self.versions
108 }
109
110 pub fn ciphersuites(&self) -> &[VerifiableCiphersuite] {
112 &self.ciphersuites
113 }
114
115 pub fn extensions(&self) -> &[ExtensionType] {
117 &self.extensions
118 }
119
120 pub fn proposals(&self) -> &[ProposalType] {
122 &self.proposals
123 }
124
125 pub fn credentials(&self) -> &[CredentialType] {
127 &self.credentials
128 }
129
130 pub(crate) fn supports_required_capabilities(
140 &self,
141 required_capabilities: &RequiredCapabilitiesExtension,
142 ) -> Result<(), LeafNodeValidationError> {
143 let unsupported_extension_types = required_capabilities
145 .extension_types()
146 .iter()
147 .filter(|&e| !self.contains_extension(*e))
148 .collect::<Vec<_>>();
149 if !unsupported_extension_types.is_empty() {
150 log::error!(
151 "Leaf node does not support all required extension types\n
152 Supported extensions: {:?}\n
153 Required extensions: {:?}",
154 self.extensions(),
155 required_capabilities.extension_types()
156 );
157 return Err(LeafNodeValidationError::UnsupportedExtensions);
158 }
159 if required_capabilities
161 .proposal_types()
162 .iter()
163 .any(|p| !self.contains_proposal(*p))
164 {
165 return Err(LeafNodeValidationError::UnsupportedProposals);
166 }
167 if required_capabilities
169 .credential_types()
170 .iter()
171 .any(|c| !self.contains_credential(*c))
172 {
173 return Err(LeafNodeValidationError::UnsupportedCredentials);
174 }
175 Ok(())
176 }
177
178 pub(crate) fn contains_extensions(
180 &self,
181 extensions: &Extensions<impl ExtensionValidator>,
182 ) -> bool {
183 extensions
184 .iter()
185 .map(Extension::extension_type)
186 .all(|e| e.is_default() || self.extensions().contains(&e))
187 }
188
189 pub(crate) fn contains_credential(&self, credential_type: CredentialType) -> bool {
191 self.credentials().contains(&credential_type)
192 }
193
194 pub(crate) fn contains_extension(&self, extension_type: ExtensionType) -> bool {
196 extension_type.is_default() || self.extensions().contains(&extension_type)
197 }
198
199 pub(crate) fn contains_proposal(&self, proposal_type: ProposalType) -> bool {
201 proposal_type.is_default() || self.proposals().contains(&proposal_type)
202 }
203
204 pub(crate) fn contains_version(&self, version: ProtocolVersion) -> bool {
206 self.versions().contains(&version)
207 }
208
209 pub(crate) fn contains_ciphersuite(&self, ciphersuite: VerifiableCiphersuite) -> bool {
211 self.ciphersuites().contains(&ciphersuite)
212 }
213
214 pub fn with_grease(mut self, rand: &impl openmls_traits::random::OpenMlsRand) -> Self {
247 use crate::credentials::CredentialType;
248 use crate::extensions::ExtensionType;
249 use crate::messages::proposals::ProposalType;
250 use openmls_traits::types::VerifiableCiphersuite;
251
252 if !self.ciphersuites.iter().any(|cs| cs.is_grease()) {
254 let grease_cs = VerifiableCiphersuite::new(crate::grease::random_grease_value(rand));
255 self.ciphersuites.push(grease_cs);
256 }
257
258 if !self.extensions.iter().any(|ext| ext.is_grease()) {
260 let grease_ext = ExtensionType::Grease(crate::grease::random_grease_value(rand));
261 self.extensions.push(grease_ext);
262 }
263
264 if !self.proposals.iter().any(|prop| prop.is_grease()) {
266 let grease_prop = ProposalType::Grease(crate::grease::random_grease_value(rand));
267 self.proposals.push(grease_prop);
268 }
269
270 if !self.credentials.iter().any(|cred| cred.is_grease()) {
272 let grease_cred = CredentialType::Grease(crate::grease::random_grease_value(rand));
273 self.credentials.push(grease_cred);
274 }
275
276 self
277 }
278}
279
280#[derive(Debug, Clone)]
282pub struct CapabilitiesBuilder(Capabilities);
283
284impl CapabilitiesBuilder {
285 pub fn versions(self, versions: Vec<ProtocolVersion>) -> Self {
287 Self(Capabilities { versions, ..self.0 })
288 }
289
290 pub fn ciphersuites(self, ciphersuites: Vec<Ciphersuite>) -> Self {
292 let ciphersuites = ciphersuites.into_iter().map(|cs| cs.into()).collect();
293
294 Self(Capabilities {
295 ciphersuites,
296 ..self.0
297 })
298 }
299
300 pub fn extensions(self, extensions: Vec<ExtensionType>) -> Self {
302 Self(Capabilities {
303 extensions,
304 ..self.0
305 })
306 }
307
308 pub fn proposals(self, proposals: Vec<ProposalType>) -> Self {
310 Self(Capabilities {
311 proposals,
312 ..self.0
313 })
314 }
315
316 pub fn credentials(self, credentials: Vec<CredentialType>) -> Self {
318 Self(Capabilities {
319 credentials,
320 ..self.0
321 })
322 }
323
324 pub fn with_grease(self, rand: &impl openmls_traits::random::OpenMlsRand) -> Self {
345 Self(self.0.with_grease(rand))
346 }
347
348 pub fn build(self) -> Capabilities {
350 self.0
351 }
352}
353
// Test-only mutators: these let tests overwrite single lists on an existing
// value without going through the builder.
#[cfg(test)]
impl Capabilities {
    /// Overwrites the list of supported protocol versions.
    pub fn set_versions(&mut self, versions: Vec<ProtocolVersion>) {
        self.versions = versions;
    }

    /// Overwrites the list of supported ciphersuites. The values are taken
    /// as [`VerifiableCiphersuite`]s directly; no conversion is applied.
    pub fn set_ciphersuites(&mut self, ciphersuites: Vec<VerifiableCiphersuite>) {
        self.ciphersuites = ciphersuites;
    }
}
366
367impl Default for Capabilities {
368 fn default() -> Self {
369 Capabilities {
370 versions: default_versions(),
371 ciphersuites: default_ciphersuites()
372 .into_iter()
373 .map(VerifiableCiphersuite::from)
374 .collect(),
375 extensions: vec![],
376 proposals: vec![],
377 credentials: default_credentials(),
378 }
379 }
380}
381
382pub(super) fn default_versions() -> Vec<ProtocolVersion> {
383 vec![ProtocolVersion::Mls10]
384}
385
386pub(super) fn default_ciphersuites() -> Vec<Ciphersuite> {
387 vec![
388 Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
389 Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
390 Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519,
391 Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519,
392 ]
393}
394
395pub(super) fn default_credentials() -> Vec<CredentialType> {
397 vec![CredentialType::Basic]
398}
399
#[cfg(test)]
mod tests {
    use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
    use tls_codec::{Deserialize, Serialize};

    use super::Capabilities;
    use crate::{
        credentials::CredentialType, messages::proposals::ProposalType, prelude::ExtensionType,
        versions::ProtocolVersion,
    };

    /// A `Capabilities` value that mixes known identifiers with unknown and
    /// boundary values (`0x0000`, `0xFFFF`, ...) must come out of a TLS
    /// serialize/deserialize round trip unchanged.
    #[test]
    fn that_unknown_capabilities_are_de_serialized_correctly() {
        let original = Capabilities {
            versions: vec![ProtocolVersion::Mls10, ProtocolVersion::Other(999)],
            ciphersuites: vec![
                Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519.into(),
                Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256.into(),
                Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519.into(),
                Ciphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448.into(),
                Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521.into(),
                Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448.into(),
                Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384.into(),
                VerifiableCiphersuite::new(0x0000),
                VerifiableCiphersuite::new(0x0B0B),
                VerifiableCiphersuite::new(0x7C7C),
                VerifiableCiphersuite::new(0xF000),
                VerifiableCiphersuite::new(0xFFFF),
            ],
            extensions: vec![ExtensionType::Unknown(0x0000), ExtensionType::Unknown(0xFAFA)],
            proposals: vec![ProposalType::Custom(0x7C7C)],
            credentials: vec![
                CredentialType::Basic,
                CredentialType::X509,
                CredentialType::Other(0x0000),
                CredentialType::Other(0x7C7C),
                CredentialType::Other(0xFFFF),
            ],
        };

        let encoded = original.tls_serialize_detached().unwrap();
        let decoded = Capabilities::tls_deserialize_exact(encoded).unwrap();

        assert_eq!(original, decoded);
    }
}