use crate::{error::LibraryError, group::errors::ValidationError, versions::ProtocolVersion};
use super::{
mls_auth_content::FramedContentAuthData,
mls_auth_content_in::{AuthenticatedContentIn, VerifiableAuthenticatedContentIn},
mls_content::{framed_content_tbs_serialized_detached, AuthenticatedContentTbm},
mls_content_in::FramedContentIn,
*,
};
use openmls_traits::types::Ciphersuite;
use std::io::{Read, Write};
use tls_codec::{Deserialize as TlsDeserializeTrait, Serialize as TlsSerializeTrait};
/// A public (plaintext-framed) message on the input side, as produced by the TLS
/// decoder implemented for this type in this module.
///
/// It bundles the framed content, its authentication data and an optional
/// membership tag.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct PublicMessageIn {
    /// The framed content (group id, epoch, sender, body, ...).
    pub(crate) content: FramedContentIn,
    /// Authentication data for the content; holds the optional confirmation tag.
    pub(crate) auth: FramedContentAuthData,
    /// Membership tag. When decoded from the wire, this is only read (and thus only
    /// `Some`) if the sender is a group member.
    pub(crate) membership_tag: Option<MembershipTag>,
}
// Accessors only compiled for tests or with the `test-utils` feature.
#[cfg(any(test, feature = "test-utils"))]
impl PublicMessageIn {
    /// Returns a reference to the body of the framed content.
    pub(crate) fn content(&self) -> &crate::framing::mls_content_in::FramedContentBodyIn {
        &self.content.body
    }
}
// Test-only mutators. They overwrite individual parts of the message without
// recomputing any tags or signatures, so the result may no longer authenticate
// (presumably intended for negative tests — confirm against callers).
#[cfg(test)]
impl PublicMessageIn {
    /// Replaces the confirmation tag in the authentication data.
    pub fn set_confirmation_tag(&mut self, confirmation_tag: Option<ConfirmationTag>) {
        self.auth.confirmation_tag = confirmation_tag;
    }
    /// Removes the membership tag.
    pub fn unset_membership_tag(&mut self) {
        self.membership_tag = None;
    }
    /// Replaces the framed content body.
    pub(crate) fn set_content(&mut self, content: FramedContentBodyIn) {
        self.content.body = content;
    }
    /// Sets the epoch of the framed content from a raw `u64`.
    pub fn set_epoch(&mut self, epoch: u64) {
        self.content.epoch = epoch.into();
    }
    /// Replaces the sender of the framed content.
    pub(crate) fn set_sender(&mut self, sender: Sender) {
        self.content.sender = sender;
    }
}
impl From<AuthenticatedContentIn> for PublicMessageIn {
fn from(v: AuthenticatedContentIn) -> Self {
Self {
content: v.content,
auth: v.auth,
membership_tag: None,
}
}
}
impl PublicMessageIn {
pub(crate) fn new(
content: FramedContentIn,
auth: FramedContentAuthData,
membership_tag: Option<MembershipTag>,
) -> Self {
Self {
content,
auth,
membership_tag,
}
}
pub fn content_type(&self) -> ContentType {
self.content.body.content_type()
}
pub fn sender(&self) -> &Sender {
&self.content.sender
}
#[cfg(test)]
pub(crate) fn set_membership_tag(
&mut self,
provider: &impl openmls_traits::OpenMlsProvider,
ciphersuite: Ciphersuite,
membership_key: &MembershipKey,
serialized_context: &[u8],
) -> Result<(), LibraryError> {
let tbs_payload = framed_content_tbs_serialized_detached(
ProtocolVersion::default(),
WireFormat::PublicMessage,
&self.content,
&self.content.sender,
serialized_context,
)
.map_err(LibraryError::missing_bound_check)?;
let tbm_payload = AuthenticatedContentTbm::new(&tbs_payload, &self.auth)?;
let membership_tag =
membership_key.tag_message(provider.crypto(), ciphersuite, tbm_payload)?;
self.membership_tag = Some(membership_tag);
Ok(())
}
pub(crate) fn verify_membership(
&self,
crypto: &impl openmls_traits::crypto::OpenMlsCrypto,
ciphersuite: Ciphersuite,
membership_key: &MembershipKey,
serialized_context: &[u8],
) -> Result<(), ValidationError> {
log::debug!("Verifying membership tag.");
log_crypto!(trace, " Membership key: {:x?}", membership_key);
log_crypto!(trace, " Serialized context: {:x?}", serialized_context);
let tbs_payload = framed_content_tbs_serialized_detached(
ProtocolVersion::default(),
WireFormat::PublicMessage,
&self.content,
&self.content.sender,
serialized_context,
)
.map_err(LibraryError::missing_bound_check)?;
let tbm_payload = AuthenticatedContentTbm::new(&tbs_payload, &self.auth)?;
let expected_membership_tag =
&membership_key.tag_message(crypto, ciphersuite, tbm_payload)?;
if let Some(membership_tag) = &self.membership_tag {
if membership_tag != expected_membership_tag {
return Err(ValidationError::InvalidMembershipTag);
}
} else {
return Err(ValidationError::MissingMembershipTag);
}
Ok(())
}
pub fn epoch(&self) -> GroupEpoch {
self.content.epoch
}
pub fn group_id(&self) -> &GroupId {
&self.content.group_id
}
pub(crate) fn into_verifiable_content(
self,
serialized_context: impl Into<Option<Vec<u8>>>,
) -> VerifiableAuthenticatedContentIn {
VerifiableAuthenticatedContentIn::new(
WireFormat::PublicMessage,
self.content,
serialized_context,
self.auth,
)
}
pub(crate) fn membership_tag(&self) -> Option<&MembershipTag> {
self.membership_tag.as_ref()
}
pub fn confirmation_tag(&self) -> Option<&ConfirmationTag> {
self.auth.confirmation_tag.as_ref()
}
}
#[cfg(test)]
impl From<PublicMessageIn> for FramedContentTbsIn {
    /// Test-only: builds the to-be-signed structure for a public message.
    ///
    /// Uses the default protocol version and the `PublicMessage` wire format, and
    /// attaches no serialized group context. Authentication data and membership tag
    /// are discarded.
    fn from(public_message: PublicMessageIn) -> Self {
        let PublicMessageIn { content, .. } = public_message;
        Self {
            version: ProtocolVersion::default(),
            wire_format: WireFormat::PublicMessage,
            content,
            serialized_context: None,
        }
    }
}
impl<'a> TryFrom<&'a PublicMessageIn> for InterimTranscriptHashInput<'a> {
    type Error = &'static str;

    /// Borrows the message's confirmation tag as interim transcript hash input.
    ///
    /// Fails with a static error message if the message has no confirmation tag.
    fn try_from(public_message: &'a PublicMessageIn) -> Result<Self, Self::Error> {
        public_message
            .auth
            .confirmation_tag
            .as_ref()
            .map(|confirmation_tag| InterimTranscriptHashInput { confirmation_tag })
            .ok_or("PublicMessage needs to contain a confirmation tag.")
    }
}
impl TlsDeserializeTrait for PublicMessageIn {
    /// Decodes a public message in wire order: framed content, authentication data,
    /// and then a membership tag only if the sender is a group member.
    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error> {
        let content = FramedContentIn::tls_deserialize(bytes)?;
        // The auth-data decoder is parameterized by the content type of the body
        // that was just read.
        let auth = FramedContentAuthData::deserialize(bytes, content.body.content_type())?;
        // Non-member senders carry no membership tag, so nothing more is read for them.
        let membership_tag = content
            .sender
            .is_member()
            .then(|| MembershipTag::tls_deserialize(bytes))
            .transpose()?;
        Ok(Self::new(content, auth, membership_tag))
    }
}
impl DeserializeBytes for PublicMessageIn {
    /// Decodes a public message from the front of `bytes`. Returns the message
    /// together with the bytes that follow it.
    ///
    /// The remainder is the reader position after decoding; it is not recomputed
    /// from `tls_serialized_len()`. It therefore reflects exactly what the decoder
    /// consumed, even if re-serialization would yield a different length. It also
    /// avoids an out-of-range slice panic and a second walk over the message.
    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
    where
        Self: Sized,
    {
        // `&[u8]` implements `Read` by advancing the slice, so after decoding
        // `remainder` starts right after the last consumed byte.
        let mut remainder = bytes;
        let message = PublicMessageIn::tls_deserialize(&mut remainder)?;
        Ok((message, remainder))
    }
}
impl Size for PublicMessageIn {
    /// Length in bytes of the encoding written by `tls_serialize`: framed content
    /// plus authentication data, plus the membership tag when one is set.
    #[inline]
    fn tls_serialized_len(&self) -> usize {
        let membership_tag_len = self
            .membership_tag
            .as_ref()
            .map_or(0, |membership_tag| membership_tag.tls_serialized_len());
        self.content.tls_serialized_len() + self.auth.tls_serialized_len() + membership_tag_len
    }
}
impl TlsSerializeTrait for PublicMessageIn {
    /// Writes the framed content, then the authentication data, then the membership
    /// tag if one is set. Returns the total number of bytes written.
    ///
    /// Note: the tag is written whenever it is present, independent of the sender
    /// type, whereas the decoder only reads a tag for member senders.
    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
        let content_len = self.content.tls_serialize(writer)?;
        let auth_len = self.auth.tls_serialize(writer)?;
        let membership_tag_len = match self.membership_tag.as_ref() {
            Some(membership_tag) => membership_tag.tls_serialize(writer)?,
            None => 0,
        };
        Ok(content_len + auth_len + membership_tag_len)
    }
}
#[cfg(any(feature = "test-utils", test))]
impl From<PublicMessageIn> for PublicMessage {
    /// Converts an input-side message into its output-side counterpart, converting
    /// the framed content and moving auth data and membership tag over unchanged.
    fn from(message_in: PublicMessageIn) -> Self {
        // Exhaustive destructuring: adding a field to `PublicMessageIn` forces this
        // conversion to be revisited.
        let PublicMessageIn {
            content,
            auth,
            membership_tag,
        } = message_in;
        Self {
            content: content.into(),
            auth,
            membership_tag,
        }
    }
}
#[cfg(any(feature = "test-utils", test))]
impl From<PublicMessage> for PublicMessageIn {
    /// Converts an output-side message into the input-side representation,
    /// converting the framed content and moving auth data and membership tag over
    /// unchanged.
    fn from(message: PublicMessage) -> Self {
        // Exhaustive destructuring: adding a field to `PublicMessage` forces this
        // conversion to be revisited.
        let PublicMessage {
            content,
            auth,
            membership_tag,
        } = message;
        Self {
            content: content.into(),
            auth,
            membership_tag,
        }
    }
}