openmls/treesync/node/leaf_node/
capabilities.rsuse openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
use serde::{Deserialize, Serialize};
use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize};
#[cfg(doc)]
use super::LeafNode;
use crate::{
credentials::CredentialType,
extensions::{Extension, ExtensionType, Extensions, RequiredCapabilitiesExtension},
messages::proposals::ProposalType,
treesync::errors::LeafNodeValidationError,
versions::ProtocolVersion,
};
/// The capabilities a client advertises in its [`LeafNode`]: supported
/// protocol versions, ciphersuites, extension types, proposal types and
/// credential types.
///
/// Field order matters: the derived TLS codec and serde implementations
/// (de)serialize the fields in declaration order.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Serialize,
    Deserialize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSize,
)]
pub struct Capabilities {
    /// Supported protocol versions.
    pub(super) versions: Vec<ProtocolVersion>,
    /// Supported ciphersuites. Stored as [`VerifiableCiphersuite`] so that
    /// codepoints without a known [`Ciphersuite`] can still be represented.
    pub(super) ciphersuites: Vec<VerifiableCiphersuite>,
    /// Supported extension types. Extension types reporting `is_default()`
    /// are treated as supported even when they are not listed here.
    pub(super) extensions: Vec<ExtensionType>,
    /// Supported proposal types. Proposal types reporting `is_default()` are
    /// treated as supported even when they are not listed here.
    pub(super) proposals: Vec<ProposalType>,
    /// Supported credential types. Only the listed types count as supported.
    pub(super) credentials: Vec<CredentialType>,
}
impl Capabilities {
    /// Creates a new [`Capabilities`] value.
    ///
    /// Every argument is optional. Passing `None` selects a fallback:
    ///
    /// - `versions`: MLS 1.0 only ([`ProtocolVersion::Mls10`]),
    /// - `ciphersuites`: the crate's default ciphersuite list,
    /// - `extensions`: empty,
    /// - `proposals`: empty,
    /// - `credentials`: [`CredentialType::Basic`] only.
    ///
    /// Passing `Some(slice)` copies the slice as-is; ciphersuites are
    /// converted into [`VerifiableCiphersuite`]s.
    pub fn new(
        versions: Option<&[ProtocolVersion]>,
        ciphersuites: Option<&[Ciphersuite]>,
        extensions: Option<&[ExtensionType]>,
        proposals: Option<&[ProposalType]>,
        credentials: Option<&[CredentialType]>,
    ) -> Self {
        Self {
            versions: match versions {
                Some(v) => v.into(),
                None => default_versions(),
            },
            ciphersuites: match ciphersuites {
                Some(c) => c.iter().map(|c| VerifiableCiphersuite::from(*c)).collect(),
                None => default_ciphersuites()
                    .into_iter()
                    .map(VerifiableCiphersuite::from)
                    .collect(),
            },
            extensions: match extensions {
                Some(e) => e.into(),
                None => vec![],
            },
            proposals: match proposals {
                Some(p) => p.into(),
                None => vec![],
            },
            credentials: match credentials {
                Some(c) => c.into(),
                None => default_credentials(),
            },
        }
    }

    /// Creates a [`Capabilities`] value in which every list is empty.
    ///
    /// Unlike [`Capabilities::default`], no version, ciphersuite or credential
    /// type is advertised.
    pub fn empty() -> Self {
        Self {
            versions: Vec::new(),
            ciphersuites: Vec::new(),
            extensions: Vec::new(),
            proposals: Vec::new(),
            credentials: Vec::new(),
        }
    }

    /// Returns a [`CapabilitiesBuilder`] pre-populated with
    /// [`Capabilities::default`].
    pub fn builder() -> CapabilitiesBuilder {
        CapabilitiesBuilder(Self::default())
    }

    /// Returns the supported protocol versions.
    pub fn versions(&self) -> &[ProtocolVersion] {
        &self.versions
    }

    /// Returns the supported ciphersuites.
    pub fn ciphersuites(&self) -> &[VerifiableCiphersuite] {
        &self.ciphersuites
    }

    /// Returns the explicitly listed extension types.
    pub fn extensions(&self) -> &[ExtensionType] {
        &self.extensions
    }

    /// Returns the explicitly listed proposal types.
    pub fn proposals(&self) -> &[ProposalType] {
        &self.proposals
    }

    /// Returns the supported credential types.
    pub fn credentials(&self) -> &[CredentialType] {
        &self.credentials
    }

    /// Checks that these capabilities cover everything demanded by
    /// `required_capabilities`.
    ///
    /// Extension and proposal types that report `is_default()` count as
    /// supported even when not listed; credential types must be listed.
    ///
    /// # Errors
    ///
    /// The checks run in the order below and the first failure is returned:
    ///
    /// - [`LeafNodeValidationError::UnsupportedExtensions`] if a required
    ///   extension type is not supported (also logged at error level),
    /// - [`LeafNodeValidationError::UnsupportedProposals`] if a required
    ///   proposal type is not supported,
    /// - [`LeafNodeValidationError::UnsupportedCredentials`] if a required
    ///   credential type is not supported.
    pub(crate) fn supports_required_capabilities(
        &self,
        required_capabilities: &RequiredCapabilitiesExtension,
    ) -> Result<(), LeafNodeValidationError> {
        let unsupported_extension_types = required_capabilities
            .extension_types()
            .iter()
            .filter(|&e| !self.contains_extension(*e))
            .collect::<Vec<_>>();
        if !unsupported_extension_types.is_empty() {
            // The trailing `\` continuations keep the message free of the
            // source indentation and of doubled blank lines; the offending
            // types are reported explicitly to ease debugging.
            log::error!(
                "Leaf node does not support all required extension types\n\
                 Supported extensions: {:?}\n\
                 Required extensions: {:?}\n\
                 Unsupported extensions: {:?}",
                self.extensions(),
                required_capabilities.extension_types(),
                unsupported_extension_types
            );
            return Err(LeafNodeValidationError::UnsupportedExtensions);
        }
        if required_capabilities
            .proposal_types()
            .iter()
            .any(|p| !self.contains_proposal(*p))
        {
            return Err(LeafNodeValidationError::UnsupportedProposals);
        }
        if required_capabilities
            .credential_types()
            .iter()
            .any(|c| !self.contains_credential(*c))
        {
            return Err(LeafNodeValidationError::UnsupportedCredentials);
        }
        Ok(())
    }

    /// Returns `true` if the type of every extension in `extension` is
    /// supported, as decided by [`Capabilities::contains_extension`].
    pub(crate) fn contains_extensions(&self, extension: &Extensions) -> bool {
        // Delegate so the "default or listed" rule lives in one place.
        extension
            .iter()
            .map(Extension::extension_type)
            .all(|e| self.contains_extension(e))
    }

    /// Returns `true` if `credential_type` is listed. No credential type is
    /// implicitly supported.
    pub(crate) fn contains_credential(&self, credential_type: CredentialType) -> bool {
        self.credentials().contains(&credential_type)
    }

    /// Returns `true` if `extension_type` reports `is_default()` or is listed.
    pub(crate) fn contains_extension(&self, extension_type: ExtensionType) -> bool {
        extension_type.is_default() || self.extensions().contains(&extension_type)
    }

    /// Returns `true` if `proposal_type` reports `is_default()` or is listed.
    pub(crate) fn contains_proposal(&self, proposal_type: ProposalType) -> bool {
        proposal_type.is_default() || self.proposals().contains(&proposal_type)
    }

    /// Returns `true` if `version` is listed.
    pub(crate) fn contains_version(&self, version: ProtocolVersion) -> bool {
        self.versions().contains(&version)
    }

    /// Returns `true` if `ciphersuite` is listed.
    pub(crate) fn contains_ciphersuite(&self, ciphersuite: VerifiableCiphersuite) -> bool {
        self.ciphersuites().contains(&ciphersuite)
    }
}
/// Builder for [`Capabilities`], obtained from [`Capabilities::builder`].
///
/// It starts from [`Capabilities::default`]. Each setter consumes the builder,
/// replaces one field and returns the builder; [`CapabilitiesBuilder::build`]
/// yields the finished value.
#[must_use = "a `CapabilitiesBuilder` has no effect until `build()` is called"]
#[derive(Debug, Clone)]
pub struct CapabilitiesBuilder(Capabilities);
impl CapabilitiesBuilder {
pub fn versions(self, versions: Vec<ProtocolVersion>) -> Self {
Self(Capabilities { versions, ..self.0 })
}
pub fn ciphersuites(self, ciphersuites: Vec<Ciphersuite>) -> Self {
let ciphersuites = ciphersuites.into_iter().map(|cs| cs.into()).collect();
Self(Capabilities {
ciphersuites,
..self.0
})
}
pub fn extensions(self, extensions: Vec<ExtensionType>) -> Self {
Self(Capabilities {
extensions,
..self.0
})
}
pub fn proposals(self, proposals: Vec<ProposalType>) -> Self {
Self(Capabilities {
proposals,
..self.0
})
}
pub fn credentials(self, credentials: Vec<CredentialType>) -> Self {
Self(Capabilities {
credentials,
..self.0
})
}
pub fn build(self) -> Capabilities {
self.0
}
}
// Test-only mutators, compiled only under `cfg(test)`.
#[cfg(test)]
impl Capabilities {
    /// Replaces the list of supported protocol versions.
    pub fn set_versions(&mut self, versions: Vec<ProtocolVersion>) {
        self.versions = versions;
    }
    /// Replaces the list of supported ciphersuites. Takes
    /// [`VerifiableCiphersuite`]s directly, whereas the builder takes
    /// [`Ciphersuite`]s.
    pub fn set_ciphersuites(&mut self, ciphersuites: Vec<VerifiableCiphersuite>) {
        self.ciphersuites = ciphersuites;
    }
}
impl Default for Capabilities {
fn default() -> Self {
Capabilities {
versions: default_versions(),
ciphersuites: default_ciphersuites()
.into_iter()
.map(VerifiableCiphersuite::from)
.collect(),
extensions: vec![],
proposals: vec![],
credentials: default_credentials(),
}
}
}
/// Protocol versions used when none are supplied: MLS 1.0 only.
pub(super) fn default_versions() -> Vec<ProtocolVersion> {
    Vec::from([ProtocolVersion::Mls10])
}
/// Ciphersuites used when none are supplied, in advertisement order.
pub(super) fn default_ciphersuites() -> Vec<Ciphersuite> {
    let suites = [
        Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
        Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
        Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519,
        Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519,
    ];
    Vec::from(suites)
}
/// Credential types used when none are supplied: basic credentials only.
pub(super) fn default_credentials() -> Vec<CredentialType> {
    Vec::from([CredentialType::Basic])
}
#[cfg(test)]
mod tests {
    use openmls_traits::types::{Ciphersuite, VerifiableCiphersuite};
    use tls_codec::{Deserialize, Serialize};

    use super::Capabilities;
    use crate::{
        credentials::CredentialType, messages::proposals::ProposalType, prelude::ExtensionType,
        versions::ProtocolVersion,
    };

    /// Round-trips capabilities that mix known values with codepoints this
    /// implementation does not know through the TLS codec and checks that the
    /// decoded value equals the original.
    #[test]
    fn that_unknown_capabilities_are_de_serialized_correctly() {
        let known_suites = [
            Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
            Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
            Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519,
            Ciphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448,
            Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521,
            Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448,
            Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384,
        ];
        let raw_suites: [u16; 5] = [0x0000, 0x0A0A, 0x7A7A, 0xF000, 0xFFFF];
        let ciphersuites: Vec<VerifiableCiphersuite> = known_suites
            .into_iter()
            .map(VerifiableCiphersuite::from)
            .chain(raw_suites.into_iter().map(VerifiableCiphersuite::new))
            .collect();

        let expected = Capabilities {
            versions: vec![ProtocolVersion::Mls10, ProtocolVersion::Other(999)],
            ciphersuites,
            extensions: vec![ExtensionType::Unknown(0x0000), ExtensionType::Unknown(0xFAFA)],
            proposals: vec![ProposalType::Custom(0x7A7A)],
            credentials: vec![
                CredentialType::Basic,
                CredentialType::X509,
                CredentialType::Other(0x0000),
                CredentialType::Other(0x7A7A),
                CredentialType::Other(0xFFFF),
            ],
        };

        let wire = expected.tls_serialize_detached().unwrap();
        let decoded = Capabilities::tls_deserialize_exact(wire).unwrap();
        assert_eq!(decoded, expected);
    }
}