openmls/framing/public_message_in.rs

1//! # PublicMessageIn
2//!
3//! A PublicMessageIn is a framing structure for MLS messages. It can contain
4//! Proposals, Commits and application messages.
5
6use crate::{error::LibraryError, group::errors::ValidationError, versions::ProtocolVersion};
7
8use super::{
9    mls_auth_content::FramedContentAuthData,
10    mls_auth_content_in::{AuthenticatedContentIn, VerifiableAuthenticatedContentIn},
11    mls_content::{framed_content_tbs_serialized_detached, AuthenticatedContentTbm},
12    mls_content_in::FramedContentIn,
13    *,
14};
15
16use openmls_traits::types::Ciphersuite;
17use std::io::{Read, Write};
18use tls_codec::{Deserialize as TlsDeserializeTrait, Serialize as TlsSerializeTrait};
19
20/// [`PublicMessageIn`] is a framing structure for MLS messages. It can contain
21/// Proposals, Commits and application messages.
22///
23/// 9. Message framing
24///
25/// ```c
26/// // draft-ietf-mls-protocol-17
27///
28/// struct {
29///     FramedContent content;
30///     FramedContentAuthData auth;
31///     optional<MAC> membership_tag;
32/// } PublicMessage;
33/// ```
34#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
35pub struct PublicMessageIn {
36    pub(crate) content: FramedContentIn,
37    pub(crate) auth: FramedContentAuthData,
38    pub(crate) membership_tag: Option<MembershipTag>,
39}
40
#[cfg(any(test, feature = "test-utils"))]
impl PublicMessageIn {
    /// Borrow the body of the framed content carried by this message.
    ///
    /// Only compiled for tests and the `test-utils` feature.
    pub(crate) fn content(&self) -> &crate::framing::mls_content_in::FramedContentBodyIn {
        let framed = &self.content;
        &framed.body
    }
}
47
#[cfg(test)]
impl PublicMessageIn {
    /// Replace the confirmation tag stored in the authentication data.
    pub fn set_confirmation_tag(&mut self, confirmation_tag: Option<ConfirmationTag>) {
        let auth = &mut self.auth;
        auth.confirmation_tag = confirmation_tag;
    }

    /// Drop the membership tag so the message carries none.
    pub fn unset_membership_tag(&mut self) {
        let tag_slot = &mut self.membership_tag;
        *tag_slot = None;
    }

    /// Replace the body of the framed content.
    pub(crate) fn set_content(&mut self, content: FramedContentBodyIn) {
        let framed = &mut self.content;
        framed.body = content;
    }

    /// Overwrite the epoch of the framed content.
    pub fn set_epoch(&mut self, epoch: u64) {
        let framed = &mut self.content;
        framed.epoch = epoch.into();
    }

    /// Overwrite the sender of the framed content.
    pub(crate) fn set_sender(&mut self, sender: Sender) {
        let framed = &mut self.content;
        framed.sender = sender;
    }
}
71
72impl From<AuthenticatedContentIn> for PublicMessageIn {
73    fn from(v: AuthenticatedContentIn) -> Self {
74        Self {
75            content: v.content,
76            auth: v.auth,
77            membership_tag: None,
78        }
79    }
80}
81
82impl PublicMessageIn {
83    /// Build an [`PublicMessageIn`].
84    pub(crate) fn new(
85        content: FramedContentIn,
86        auth: FramedContentAuthData,
87        membership_tag: Option<MembershipTag>,
88    ) -> Self {
89        Self {
90            content,
91            auth,
92            membership_tag,
93        }
94    }
95
96    /// Returns the [`ContentType`] of the message.
97    pub fn content_type(&self) -> ContentType {
98        self.content.body.content_type()
99    }
100
101    /// Get the sender of this message.
102    pub fn sender(&self) -> &Sender {
103        &self.content.sender
104    }
105
106    #[cfg(test)]
107    pub(crate) fn set_membership_tag(
108        &mut self,
109        provider: &impl openmls_traits::OpenMlsProvider,
110        ciphersuite: Ciphersuite,
111        membership_key: &MembershipKey,
112        serialized_context: &[u8],
113    ) -> Result<(), LibraryError> {
114        let tbs_payload = framed_content_tbs_serialized_detached(
115            ProtocolVersion::default(),
116            WireFormat::PublicMessage,
117            &self.content,
118            &self.content.sender,
119            serialized_context,
120        )
121        .map_err(LibraryError::missing_bound_check)?;
122        let tbm_payload = AuthenticatedContentTbm::new(&tbs_payload, &self.auth)?;
123        let membership_tag =
124            membership_key.tag_message(provider.crypto(), ciphersuite, tbm_payload)?;
125
126        self.membership_tag = Some(membership_tag);
127        Ok(())
128    }
129
130    /// Verify the membership tag of a [`PublicMessage`] sent from a group
131    /// member. Returns `Ok(())` if successful or [`ValidationError`] otherwise.
132    /// Note, that the context must have been set before calling this function.
133    // TODO #133: Include this in the validation
134    pub(crate) fn verify_membership(
135        &self,
136        crypto: &impl openmls_traits::crypto::OpenMlsCrypto,
137        ciphersuite: Ciphersuite,
138        membership_key: &MembershipKey,
139        serialized_context: &[u8],
140    ) -> Result<(), ValidationError> {
141        log::debug!("Verifying membership tag.");
142        log_crypto!(trace, "  Membership key: {:x?}", membership_key);
143        log_crypto!(trace, "  Serialized context: {:x?}", serialized_context);
144        let tbs_payload = framed_content_tbs_serialized_detached(
145            ProtocolVersion::default(),
146            WireFormat::PublicMessage,
147            &self.content,
148            &self.content.sender,
149            serialized_context,
150        )
151        .map_err(LibraryError::missing_bound_check)?;
152        let tbm_payload = AuthenticatedContentTbm::new(&tbs_payload, &self.auth)?;
153        let expected_membership_tag =
154            &membership_key.tag_message(crypto, ciphersuite, tbm_payload)?;
155
156        // Verify the membership tag
157        // https://validation.openmls.tech/#valn1302
158        if let Some(membership_tag) = &self.membership_tag {
159            // TODO #133: make this a constant-time comparison
160            if membership_tag != expected_membership_tag {
161                return Err(ValidationError::InvalidMembershipTag);
162            }
163        } else {
164            return Err(ValidationError::MissingMembershipTag);
165        }
166        Ok(())
167    }
168
169    /// Get the group epoch.
170    pub fn epoch(&self) -> GroupEpoch {
171        self.content.epoch
172    }
173
174    /// Get the [`GroupId`].
175    pub fn group_id(&self) -> &GroupId {
176        &self.content.group_id
177    }
178
179    /// Turn this [`PublicMessageIn`] into a [`VerifiableAuthenticatedContent`].
180    pub(crate) fn into_verifiable_content(
181        self,
182        serialized_context: impl Into<Option<Vec<u8>>>,
183    ) -> VerifiableAuthenticatedContentIn {
184        VerifiableAuthenticatedContentIn::new(
185            WireFormat::PublicMessage,
186            self.content,
187            serialized_context,
188            self.auth,
189        )
190    }
191
192    /// Get the [`MembershipTag`].
193    pub(crate) fn membership_tag(&self) -> Option<&MembershipTag> {
194        self.membership_tag.as_ref()
195    }
196
197    /// Get the [`ConfirmationTag`].
198    pub fn confirmation_tag(&self) -> Option<&ConfirmationTag> {
199        self.auth.confirmation_tag.as_ref()
200    }
201}
202
#[cfg(test)]
impl From<PublicMessageIn> for FramedContentTbsIn {
    /// Build a TBS structure for the message's content, using the default
    /// protocol version, the `PublicMessage` wire format and no serialized
    /// context.
    fn from(v: PublicMessageIn) -> Self {
        let content = v.content;
        Self {
            version: ProtocolVersion::default(),
            wire_format: WireFormat::PublicMessage,
            content,
            serialized_context: None,
        }
    }
}
214
215impl<'a> TryFrom<&'a PublicMessageIn> for InterimTranscriptHashInput<'a> {
216    type Error = &'static str;
217
218    fn try_from(public_message: &'a PublicMessageIn) -> Result<Self, Self::Error> {
219        match public_message.auth.confirmation_tag.as_ref() {
220            Some(confirmation_tag) => Ok(InterimTranscriptHashInput { confirmation_tag }),
221            None => Err("PublicMessage needs to contain a confirmation tag."),
222        }
223    }
224}
225
226impl TlsDeserializeTrait for PublicMessageIn {
227    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error> {
228        let content = FramedContentIn::tls_deserialize(bytes)?;
229        let auth = FramedContentAuthData::deserialize(bytes, content.body.content_type())?;
230        let membership_tag = if content.sender.is_member() {
231            Some(MembershipTag::tls_deserialize(bytes)?)
232        } else {
233            None
234        };
235
236        Ok(PublicMessageIn::new(content, auth, membership_tag))
237    }
238}
239
240impl DeserializeBytes for PublicMessageIn {
241    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
242    where
243        Self: Sized,
244    {
245        let mut bytes_ref = bytes;
246        let message = PublicMessageIn::tls_deserialize(&mut bytes_ref)?;
247        let remainder = &bytes[message.tls_serialized_len()..];
248        Ok((message, remainder))
249    }
250}
251
252impl Size for PublicMessageIn {
253    #[inline]
254    fn tls_serialized_len(&self) -> usize {
255        self.content.tls_serialized_len()
256            + self.auth.tls_serialized_len()
257            + if let Some(membership_tag) = &self.membership_tag {
258                membership_tag.tls_serialized_len()
259            } else {
260                0
261            }
262    }
263}
264
265impl TlsSerializeTrait for PublicMessageIn {
266    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
267        // Serialize the content, not the TBS.
268        let mut written = self.content.tls_serialize(writer)?;
269        written += self.auth.tls_serialize(writer)?;
270        written += if let Some(membership_tag) = &self.membership_tag {
271            membership_tag.tls_serialize(writer)?
272        } else {
273            0
274        };
275        Ok(written)
276    }
277}
278
// The following `From` implementation breaks abstraction layers and MUST
// NOT be made available outside of tests or "test-utils".
#[cfg(any(feature = "test-utils", test))]
impl From<PublicMessageIn> for PublicMessage {
    /// Reinterpret an incoming message as an outgoing one, field by field.
    fn from(v: PublicMessageIn) -> Self {
        let PublicMessageIn {
            content,
            auth,
            membership_tag,
        } = v;
        PublicMessage {
            content: content.into(),
            auth,
            membership_tag,
        }
    }
}
291
#[cfg(any(feature = "test-utils", test))]
impl From<PublicMessage> for PublicMessageIn {
    /// Reinterpret an outgoing message as an incoming one, field by field.
    /// Only available in tests and with "test-utils".
    fn from(v: PublicMessage) -> Self {
        let PublicMessage {
            content,
            auth,
            membership_tag,
        } = v;
        PublicMessageIn::new(content.into(), auth, membership_tag)
    }
}