use openmls_traits::{crypto::OpenMlsCrypto, random::OpenMlsRand, types::Ciphersuite};
use std::io::Write;
use tls_codec::{Serialize, Size, TlsSerialize, TlsSize};
use super::mls_auth_content::AuthenticatedContent;
use crate::{
binary_tree::array_representation::LeafNodeIndex, error::LibraryError,
tree::secret_tree::SecretType,
};
use super::*;
/// An encrypted MLS message as produced by `PrivateMessage::encrypt_content`.
///
/// `group_id`, `epoch`, `content_type` and `authenticated_data` are carried in
/// the clear, while the sender data and the message content are each
/// AEAD-encrypted. The shape appears to mirror the `PrivateMessage` struct of
/// RFC 9420 (§6.3) — confirm against the spec when changing it.
///
/// NOTE: the derived `TlsSerialize`/`TlsSize` impls encode the fields in
/// declaration order, so the field order below *is* the wire format. Do not
/// reorder.
#[derive(Debug, PartialEq, Eq, Clone, TlsSerialize, TlsSize)]
pub struct PrivateMessage {
// Group this message belongs to; also bound into both AEAD AADs.
pub(crate) group_id: GroupId,
// Epoch the message was encrypted in; also bound into both AEAD AADs.
pub(crate) epoch: GroupEpoch,
// Type of the encrypted content; also bound into both AEAD AADs.
pub(crate) content_type: ContentType,
// Plaintext authenticated data, integrity-protected via the content AAD.
pub(crate) authenticated_data: VLBytes,
// AEAD encryption of the serialized sender data (leaf index, generation,
// reuse guard).
pub(crate) encrypted_sender_data: VLBytes,
// AEAD encryption of the padded content and authentication data.
pub(crate) ciphertext: VLBytes,
}
/// Header values used while encrypting a `PrivateMessage`.
///
/// `group_id` and `epoch` are written into both AEAD AADs and into the
/// resulting message; `sender` is the leaf index recorded in the encrypted
/// sender data. Normally these are derived from the content being encrypted;
/// test code may supply a different header to build deliberately inconsistent
/// messages.
pub(crate) struct MlsMessageHeader {
// Group id placed in the AADs and the output message.
pub(crate) group_id: GroupId,
// Epoch placed in the AADs and the output message.
pub(crate) epoch: GroupEpoch,
// Leaf index written into the (encrypted) sender data.
pub(crate) sender: LeafNodeIndex,
}
impl PrivateMessage {
    /// Builds a `PrivateMessage` directly from its raw fields.
    ///
    /// Test-only: no encryption is performed, the byte strings are stored
    /// verbatim.
    #[cfg(test)]
    pub(crate) fn new(
        group_id: GroupId,
        epoch: GroupEpoch,
        content_type: ContentType,
        authenticated_data: VLBytes,
        encrypted_sender_data: VLBytes,
        ciphertext: VLBytes,
    ) -> Self {
        Self {
            group_id,
            epoch,
            content_type,
            authenticated_data,
            encrypted_sender_data,
            ciphertext,
        }
    }

    /// Encrypts `public_message` into a `PrivateMessage`.
    ///
    /// The header (group id, epoch, sender) is derived from `public_message`
    /// itself. `padding_size` controls zero padding of the encrypted content;
    /// `0` disables padding.
    ///
    /// # Errors
    ///
    /// Returns `MessageEncryptionError::WrongWireFormat` if `public_message`
    /// is not marked with `WireFormat::PrivateMessage`; otherwise propagates
    /// any error raised by `encrypt_content`.
    pub(crate) fn try_from_authenticated_content<T>(
        crypto: &impl OpenMlsCrypto,
        rand: &impl OpenMlsRand,
        public_message: &AuthenticatedContent,
        ciphersuite: Ciphersuite,
        message_secrets: &mut MessageSecrets,
        padding_size: usize,
    ) -> Result<PrivateMessage, MessageEncryptionError<T>> {
        log::debug!("PrivateMessage::try_from_authenticated_content");
        log::trace!(" ciphersuite: {}", ciphersuite);
        // The content must have been framed for private-message transport.
        if public_message.wire_format() != WireFormat::PrivateMessage {
            return Err(MessageEncryptionError::WrongWireFormat);
        }
        Self::encrypt_content(
            crypto,
            rand,
            None,
            public_message,
            ciphersuite,
            message_secrets,
            padding_size,
        )
    }

    /// Like `try_from_authenticated_content`, but skips the wire-format check.
    ///
    /// Only compiled with the `test-utils` feature or in tests.
    ///
    /// # Errors
    ///
    /// Propagates any error raised by `encrypt_content`.
    #[cfg(any(feature = "test-utils", test))]
    pub(crate) fn encrypt_without_check<T>(
        crypto: &impl OpenMlsCrypto,
        rand: &impl OpenMlsRand,
        public_message: &AuthenticatedContent,
        ciphersuite: Ciphersuite,
        message_secrets: &mut MessageSecrets,
        padding_size: usize,
    ) -> Result<PrivateMessage, MessageEncryptionError<T>> {
        Self::encrypt_content(
            crypto,
            rand,
            None,
            public_message,
            ciphersuite,
            message_secrets,
            padding_size,
        )
    }

    /// Test-only: encrypts `public_message` but uses `header` for the AADs,
    /// the sender data and the outer message fields.
    ///
    /// The content key is still looked up for the real sender of
    /// `public_message`, so this can produce messages whose header disagrees
    /// with the key that encrypted them.
    ///
    /// # Errors
    ///
    /// Propagates any error raised by `encrypt_content`.
    #[cfg(test)]
    pub(crate) fn encrypt_with_different_header<T>(
        crypto: &impl OpenMlsCrypto,
        rand: &impl OpenMlsRand,
        public_message: &AuthenticatedContent,
        ciphersuite: Ciphersuite,
        header: MlsMessageHeader,
        message_secrets: &mut MessageSecrets,
        padding_size: usize,
    ) -> Result<PrivateMessage, MessageEncryptionError<T>> {
        Self::encrypt_content(
            crypto,
            rand,
            Some(header),
            public_message,
            ciphersuite,
            message_secrets,
            padding_size,
        )
    }

    /// Shared implementation behind `try_from_authenticated_content`,
    /// `encrypt_without_check` and `encrypt_with_different_header`.
    ///
    /// The steps are:
    /// 1. Fetch the per-generation key/nonce for the sender from the secret tree.
    /// 2. Encrypt the padded content under an AAD of
    ///    (group id, epoch, content type, authenticated data).
    /// 3. Derive the sender-data key/nonce from the resulting ciphertext.
    /// 4. Encrypt (sender, generation, reuse guard) under an AAD of
    ///    (group id, epoch, content type).
    ///
    /// `test_header` is honoured only in `test`/`test-utils` builds; otherwise
    /// the header is always derived from `public_message`.
    ///
    /// # Errors
    ///
    /// * A `LibraryError` (converted into `MessageEncryptionError`) if the
    ///   sender is not a member, if serialization fails, or if a crypto
    ///   primitive fails.
    /// * Any error returned by the secret tree's `secret_for_encryption`.
    fn encrypt_content<T>(
        crypto: &impl OpenMlsCrypto,
        rand: &impl OpenMlsRand,
        test_header: Option<MlsMessageHeader>,
        public_message: &AuthenticatedContent,
        ciphersuite: Ciphersuite,
        message_secrets: &mut MessageSecrets,
        padding_size: usize,
    ) -> Result<PrivateMessage, MessageEncryptionError<T>> {
        // Encryption secrets are looked up by leaf index, so the sender must be
        // a member.
        let sender_index = public_message
            .sender()
            .as_member()
            .ok_or_else(|| LibraryError::custom("Sender is not a member."))?;

        // A caller-supplied header is only honoured in test builds, as defence
        // in depth against it leaking into production code paths.
        let header = match test_header {
            Some(header) if cfg!(any(feature = "test-utils", test)) => header,
            _ => MlsMessageHeader {
                group_id: public_message.group_id().clone(),
                epoch: public_message.epoch(),
                sender: sender_index,
            },
        };

        // AAD for the content encryption.
        let private_message_content_aad = PrivateContentAad {
            group_id: header.group_id.clone(),
            epoch: header.epoch,
            content_type: public_message.content().content_type(),
            authenticated_data: VLByteSlice(public_message.authenticated_data()),
        };
        let private_message_content_aad_bytes = private_message_content_aad
            .tls_serialize_detached()
            .map_err(LibraryError::missing_bound_check)?;

        // Key and nonce for this sender and content type. The tree is borrowed
        // mutably — presumably because the sender's ratchet advances; see
        // `secret_for_encryption`.
        let secret_type = SecretType::from(&public_message.content().content_type());
        let (generation, (ratchet_key, ratchet_nonce)) = message_secrets
            .secret_tree_mut()
            .secret_for_encryption(ciphersuite, crypto, sender_index, secret_type)?;

        // Mix a fresh random reuse guard into the nonce; the guard itself is
        // transmitted to recipients inside the encrypted sender data below.
        let reuse_guard: ReuseGuard =
            ReuseGuard::try_from_random(rand).map_err(LibraryError::unexpected_crypto_error)?;
        let prepared_nonce = ratchet_nonce.xor_with_reuse_guard(&reuse_guard);
        log_crypto!(trace, "Encryption key for private message: {ratchet_key:x?}");
        log_crypto!(trace, "Encryption of private message private_message_content_aad_bytes: {private_message_content_aad_bytes:x?} - ratchet_nonce: {prepared_nonce:x?}");

        let padded_content = Self::encode_padded_ciphertext_content_detached(
            public_message,
            padding_size,
            ciphersuite.mac_length(),
        )
        .map_err(LibraryError::missing_bound_check)?;
        let ciphertext = ratchet_key
            .aead_seal(
                crypto,
                &padded_content,
                &private_message_content_aad_bytes,
                &prepared_nonce,
            )
            .map_err(LibraryError::unexpected_crypto_error)?;
        log::trace!("Encrypted ciphertext {:x?}", ciphertext);

        // The sender-data key and nonce are derived from the content ciphertext.
        let sender_data_key = message_secrets
            .sender_data_secret()
            .derive_aead_key(crypto, ciphersuite, &ciphertext)
            .map_err(LibraryError::unexpected_crypto_error)?;
        let sender_data_nonce = message_secrets
            .sender_data_secret()
            .derive_aead_nonce(ciphersuite, crypto, &ciphertext)
            .map_err(LibraryError::unexpected_crypto_error)?;

        // AAD for the sender-data encryption.
        let mls_sender_data_aad = MlsSenderDataAad::new(
            header.group_id.clone(),
            header.epoch,
            public_message.content().content_type(),
        );
        let mls_sender_data_aad_bytes = mls_sender_data_aad
            .tls_serialize_detached()
            .map_err(LibraryError::missing_bound_check)?;
        let sender_data = MlsSenderData::from_sender(header.sender, generation, reuse_guard);
        log_crypto!(trace, "Encryption key for sender data: {sender_data_key:x?}");
        log_crypto!(trace, "Encryption of sender data mls_sender_data_aad_bytes: {mls_sender_data_aad_bytes:x?} - sender_data_nonce: {sender_data_nonce:x?}");
        let encrypted_sender_data = sender_data_key
            .aead_seal(
                crypto,
                &sender_data
                    .tls_serialize_detached()
                    .map_err(LibraryError::missing_bound_check)?,
                &mls_sender_data_aad_bytes,
                &sender_data_nonce,
            )
            .map_err(LibraryError::unexpected_crypto_error)?;

        Ok(PrivateMessage {
            // `header` is not used after this point, so move instead of cloning.
            group_id: header.group_id,
            epoch: header.epoch,
            content_type: public_message.content().content_type(),
            authenticated_data: public_message.authenticated_data().into(),
            encrypted_sender_data: encrypted_sender_data.into(),
            ciphertext: ciphertext.into(),
        })
    }

    /// Test-only: whether this message carries handshake content, as reported
    /// by its `content_type`.
    #[cfg(test)]
    pub(crate) fn is_handshake_message(&self) -> bool {
        self.content_type.is_handshake_message()
    }

    /// Serializes the content (without its content type) followed by the
    /// authentication data, then appends zero bytes.
    ///
    /// The number of zero bytes is chosen so that
    /// `plaintext + padding + mac_len` is a multiple of `padding_size`. Here
    /// `mac_len` is presumably the AEAD tag length, which makes the resulting
    /// ciphertext length aligned. A `padding_size` of `0` disables padding.
    ///
    /// # Errors
    ///
    /// Returns a `tls_codec::Error` if serialization or writing the padding
    /// fails.
    fn encode_padded_ciphertext_content_detached(
        authenticated_content: &AuthenticatedContent,
        padding_size: usize,
        mac_len: usize,
    ) -> Result<Vec<u8>, tls_codec::Error> {
        let plaintext_length = authenticated_content
            .content()
            .serialized_len_without_type()
            + authenticated_content.auth.tls_serialized_len();
        let padding_length = if padding_size > 0 {
            let padding_offset = plaintext_length + mac_len;
            // The outer `% padding_size` maps an already-aligned length to 0
            // padding instead of a full extra block.
            (padding_size - (padding_offset % padding_size)) % padding_size
        } else {
            0
        };
        // Build into an owned buffer and hand it back directly; the previous
        // `to_vec()` made a second full copy of the plaintext.
        let mut buffer = Vec::with_capacity(plaintext_length + padding_length);
        authenticated_content
            .content()
            .serialize_without_type(&mut buffer)?;
        authenticated_content.auth.tls_serialize(&mut buffer)?;
        buffer
            .write_all(&vec![0u8; padding_length])
            .map_err(|_| Error::EncodingError("Failed to write padding.".into()))?;
        Ok(buffer)
    }

    /// Test-only: the raw encrypted content bytes.
    #[cfg(test)]
    pub(crate) fn ciphertext(&self) -> &[u8] {
        self.ciphertext.as_slice()
    }
}
/// Additional authenticated data for the AEAD encryption of a
/// `PrivateMessage`'s content.
///
/// `PrivateMessage::encrypt_content` TLS-serializes this value and passes the
/// bytes as the AAD when sealing the padded content. The derived
/// `TlsSerialize` encodes fields in declaration order, so the field order is
/// part of the AAD format — do not reorder.
#[derive(TlsSerialize, TlsSize)]
pub(crate) struct PrivateContentAad<'a> {
// Group id from the message header.
pub(crate) group_id: GroupId,
// Epoch from the message header.
pub(crate) epoch: GroupEpoch,
// Content type of the message being encrypted.
pub(crate) content_type: ContentType,
// Borrowed plaintext authenticated data of the message.
pub(crate) authenticated_data: VLByteSlice<'a>,
}