openmls/treesync/node/leaf_node/
capabilities.rs1use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
2use serde::{Deserialize, Serialize};
3use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
4
5#[cfg(doc)]
6use super::LeafNode;
7use crate::{
8 credentials::CredentialType,
9 extensions::{Extension, ExtensionType, Extensions, RequiredCapabilitiesExtension},
10 messages::proposals::ProposalType,
11 treesync::errors::LeafNodeValidationError,
12 versions::ProtocolVersion,
13};
14
15#[derive(
27 Debug,
28 Clone,
29 PartialEq,
30 Eq,
31 Serialize,
32 Deserialize,
33 TlsSerialize,
34 TlsDeserialize,
35 TlsDeserializeBytes,
36 TlsSize,
37)]
38pub struct Capabilities {
39 pub(super) versions: Vec<ProtocolVersion>,
40 pub(super) ciphersuites: Vec<VerifiableCiphersuite>,
41 pub(super) extensions: Vec<ExtensionType>,
42 pub(super) proposals: Vec<ProposalType>,
43 pub(super) credentials: Vec<CredentialType>,
44}
45
46impl Capabilities {
47 pub fn new(
52 versions: Option<&[ProtocolVersion]>,
53 ciphersuites: Option<&[Ciphersuite]>,
54 extensions: Option<&[ExtensionType]>,
55 proposals: Option<&[ProposalType]>,
56 credentials: Option<&[CredentialType]>,
57 ) -> Self {
58 Self {
59 versions: match versions {
60 Some(v) => v.into(),
61 None => default_versions(),
62 },
63 ciphersuites: match ciphersuites {
64 Some(c) => c.iter().map(|c| VerifiableCiphersuite::from(*c)).collect(),
65 None => default_ciphersuites()
66 .into_iter()
67 .map(VerifiableCiphersuite::from)
68 .collect(),
69 },
70 extensions: match extensions {
71 Some(e) => e.into(),
72 None => vec![],
73 },
74 proposals: match proposals {
75 Some(p) => p.into(),
76 None => vec![],
77 },
78 credentials: match credentials {
79 Some(c) => c.into(),
80 None => default_credentials(),
81 },
82 }
83 }
84
85 pub fn empty() -> Self {
87 Self {
88 versions: Vec::new(),
89 ciphersuites: Vec::new(),
90 extensions: Vec::new(),
91 proposals: Vec::new(),
92 credentials: Vec::new(),
93 }
94 }
95
96 pub fn builder() -> CapabilitiesBuilder {
98 CapabilitiesBuilder(Self::default())
99 }
100
101 pub fn versions(&self) -> &[ProtocolVersion] {
105 &self.versions
106 }
107
108 pub fn ciphersuites(&self) -> &[VerifiableCiphersuite] {
110 &self.ciphersuites
111 }
112
113 pub fn extensions(&self) -> &[ExtensionType] {
115 &self.extensions
116 }
117
118 pub fn proposals(&self) -> &[ProposalType] {
120 &self.proposals
121 }
122
123 pub fn credentials(&self) -> &[CredentialType] {
125 &self.credentials
126 }
127
128 pub(crate) fn supports_required_capabilities(
138 &self,
139 required_capabilities: &RequiredCapabilitiesExtension,
140 ) -> Result<(), LeafNodeValidationError> {
141 let unsupported_extension_types = required_capabilities
143 .extension_types()
144 .iter()
145 .filter(|&e| !self.contains_extension(*e))
146 .collect::<Vec<_>>();
147 if !unsupported_extension_types.is_empty() {
148 log::error!(
149 "Leaf node does not support all required extension types\n
150 Supported extensions: {:?}\n
151 Required extensions: {:?}",
152 self.extensions(),
153 required_capabilities.extension_types()
154 );
155 return Err(LeafNodeValidationError::UnsupportedExtensions);
156 }
157 if required_capabilities
159 .proposal_types()
160 .iter()
161 .any(|p| !self.contains_proposal(*p))
162 {
163 return Err(LeafNodeValidationError::UnsupportedProposals);
164 }
165 if required_capabilities
167 .credential_types()
168 .iter()
169 .any(|c| !self.contains_credential(*c))
170 {
171 return Err(LeafNodeValidationError::UnsupportedCredentials);
172 }
173 Ok(())
174 }
175
176 pub(crate) fn contains_extensions(&self, extension: &Extensions) -> bool {
178 extension
179 .iter()
180 .map(Extension::extension_type)
181 .all(|e| e.is_default() || self.extensions().contains(&e))
182 }
183
184 pub(crate) fn contains_credential(&self, credential_type: CredentialType) -> bool {
186 self.credentials().contains(&credential_type)
187 }
188
189 pub(crate) fn contains_extension(&self, extension_type: ExtensionType) -> bool {
191 extension_type.is_default() || self.extensions().contains(&extension_type)
192 }
193
194 pub(crate) fn contains_proposal(&self, proposal_type: ProposalType) -> bool {
196 proposal_type.is_default() || self.proposals().contains(&proposal_type)
197 }
198
199 pub(crate) fn contains_version(&self, version: ProtocolVersion) -> bool {
201 self.versions().contains(&version)
202 }
203
204 pub(crate) fn contains_ciphersuite(&self, ciphersuite: VerifiableCiphersuite) -> bool {
206 self.ciphersuites().contains(&ciphersuite)
207 }
208}
209
210#[derive(Debug, Clone)]
212pub struct CapabilitiesBuilder(Capabilities);
213
214impl CapabilitiesBuilder {
215 pub fn versions(self, versions: Vec<ProtocolVersion>) -> Self {
217 Self(Capabilities { versions, ..self.0 })
218 }
219
220 pub fn ciphersuites(self, ciphersuites: Vec<Ciphersuite>) -> Self {
222 let ciphersuites = ciphersuites.into_iter().map(|cs| cs.into()).collect();
223
224 Self(Capabilities {
225 ciphersuites,
226 ..self.0
227 })
228 }
229
230 pub fn extensions(self, extensions: Vec<ExtensionType>) -> Self {
232 Self(Capabilities {
233 extensions,
234 ..self.0
235 })
236 }
237
238 pub fn proposals(self, proposals: Vec<ProposalType>) -> Self {
240 Self(Capabilities {
241 proposals,
242 ..self.0
243 })
244 }
245
246 pub fn credentials(self, credentials: Vec<CredentialType>) -> Self {
248 Self(Capabilities {
249 credentials,
250 ..self.0
251 })
252 }
253
254 pub fn build(self) -> Capabilities {
256 self.0
257 }
258}
259
#[cfg(test)]
impl Capabilities {
    /// Test-only setter that overwrites the supported protocol versions.
    pub fn set_versions(&mut self, versions: Vec<ProtocolVersion>) {
        let Self { versions: slot, .. } = self;
        *slot = versions;
    }

    /// Test-only setter that overwrites the supported ciphersuites.
    pub fn set_ciphersuites(&mut self, ciphersuites: Vec<VerifiableCiphersuite>) {
        let Self { ciphersuites: slot, .. } = self;
        *slot = ciphersuites;
    }
}
272
273impl Default for Capabilities {
274 fn default() -> Self {
275 Capabilities {
276 versions: default_versions(),
277 ciphersuites: default_ciphersuites()
278 .into_iter()
279 .map(VerifiableCiphersuite::from)
280 .collect(),
281 extensions: vec![],
282 proposals: vec![],
283 credentials: default_credentials(),
284 }
285 }
286}
287
288pub(super) fn default_versions() -> Vec<ProtocolVersion> {
289 vec![ProtocolVersion::Mls10]
290}
291
292pub(super) fn default_ciphersuites() -> Vec<Ciphersuite> {
293 vec![
294 Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
295 Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
296 Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519,
297 Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519,
298 ]
299}
300
301pub(super) fn default_credentials() -> Vec<CredentialType> {
303 vec![CredentialType::Basic]
304}
305
#[cfg(test)]
mod tests {
    use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
    use tls_codec::{Deserialize, Serialize};

    use super::Capabilities;
    use crate::{
        credentials::CredentialType, messages::proposals::ProposalType, prelude::ExtensionType,
        versions::ProtocolVersion,
    };

    /// Capabilities listing unknown or unassigned code points must survive a
    /// TLS serialize/deserialize round trip unchanged.
    #[test]
    fn that_unknown_capabilities_are_de_serialized_correctly() {
        let known_ciphersuites = [
            Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
            Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
            Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519,
            Ciphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448,
            Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521,
            Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448,
            Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384,
        ];
        let unknown_ciphersuites = [0x0000, 0x0A0A, 0x7A7A, 0xF000, 0xFFFF];

        let ciphersuites = known_ciphersuites
            .into_iter()
            .map(VerifiableCiphersuite::from)
            .chain(unknown_ciphersuites.into_iter().map(VerifiableCiphersuite::new))
            .collect();

        let expected = Capabilities {
            versions: vec![ProtocolVersion::Mls10, ProtocolVersion::Other(999)],
            ciphersuites,
            extensions: vec![ExtensionType::Unknown(0x0000), ExtensionType::Unknown(0xFAFA)],
            proposals: vec![ProposalType::Custom(0x7A7A)],
            credentials: vec![
                CredentialType::Basic,
                CredentialType::X509,
                CredentialType::Other(0x0000),
                CredentialType::Other(0x7A7A),
                CredentialType::Other(0xFFFF),
            ],
        };

        let encoded = expected
            .tls_serialize_detached()
            .expect("serializing capabilities should succeed");
        let decoded = Capabilities::tls_deserialize_exact(encoded)
            .expect("deserializing capabilities should succeed");

        assert_eq!(decoded, expected);
    }
}