1use std::io::Write;
7
8use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
9use tls_codec::{
10 Serialize as TlsSerializeTrait, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
11};
12
13use super::{
14 mls_auth_content::{AuthenticatedContent, FramedContentAuthData},
15 mls_content::{framed_content_tbs_serialized_detached, AuthenticatedContentTbm, FramedContent},
16 *,
17};
18use crate::{error::LibraryError, versions::ProtocolVersion};
19
/// Newtype around a [`Mac`] that holds the membership tag of a
/// [`PublicMessage`].
///
/// Wrapping the raw MAC in its own type keeps it from being mixed up with
/// other MAC values at compile time. The wrapper adds no data of its own: it
/// has a single field and derives its TLS (de)serialization and size.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSize,
)]
pub(crate) struct MembershipTag(pub(crate) Mac);
33
/// Framing structure carrying framed content, its authentication data and an
/// optional membership tag.
///
/// The TLS `Size`/`Serialize` implementations for this type are written by
/// hand in this file rather than derived: an absent membership tag
/// contributes zero bytes to the encoding (no presence marker is written).
///
/// NOTE(review): this presumably mirrors the `PublicMessage` struct of the
/// MLS specification (RFC 9420, Section 6.2), where the membership tag is
/// only present for member senders; confirm against the spec.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct PublicMessage {
    /// The framed content (group id, epoch, sender, body, ...).
    pub(crate) content: FramedContent,
    /// Signature and optional confirmation tag authenticating `content`.
    pub(crate) auth: FramedContentAuthData,
    /// MAC over the message; `None` until a tag has been set.
    pub(crate) membership_tag: Option<MembershipTag>,
}
54
/// Test-only accessors and mutators used to inspect a [`PublicMessage`] or to
/// tamper with individual parts of it.
#[cfg(test)]
impl PublicMessage {
    /// Returns a reference to the body of the framed content.
    pub(crate) fn content(&self) -> &crate::framing::mls_content::FramedContentBody {
        let Self { content, .. } = self;
        &content.body
    }

    /// Overwrites the confirmation tag stored in the authentication data.
    pub fn set_confirmation_tag(&mut self, confirmation_tag: Option<ConfirmationTag>) {
        let auth_data = &mut self.auth;
        auth_data.confirmation_tag = confirmation_tag;
    }

    /// Removes the membership tag, leaving the message without one.
    pub fn unset_membership_tag(&mut self) {
        let tag_slot = &mut self.membership_tag;
        *tag_slot = None;
    }

    /// Replaces the body of the framed content.
    pub(crate) fn set_content(&mut self, content: FramedContentBody) {
        let framed = &mut self.content;
        framed.body = content;
    }

    /// Sets the epoch of the framed content.
    pub fn set_epoch(&mut self, epoch: u64) {
        let framed = &mut self.content;
        framed.epoch = epoch.into();
    }

    /// Returns the confirmation tag, if the message carries one.
    pub fn confirmation_tag(&self) -> Option<&ConfirmationTag> {
        let auth_data = &self.auth;
        auth_data.confirmation_tag.as_ref()
    }

    /// Corrupts the signature by inverting every bit of its first byte.
    ///
    /// # Panics
    ///
    /// Panics if the signature is empty, because the first byte is accessed
    /// by index.
    pub(crate) fn invalidate_signature(&mut self) {
        let mut corrupted = Vec::from(self.auth.signature.as_slice());
        // Inverting all bits of the byte is the same as XOR-ing it with 0xFF.
        corrupted[0] = !corrupted[0];
        self.auth.signature.modify(&corrupted);
    }

    /// Replaces the sender of the framed content.
    pub(crate) fn set_sender(&mut self, sender: Sender) {
        let framed = &mut self.content;
        framed.sender = sender;
    }

    /// Replaces the group id of the framed content.
    pub(crate) fn set_group_id(&mut self, group_id: GroupId) {
        let framed = &mut self.content;
        framed.group_id = group_id;
    }

    /// Returns whether the content type of this message is a handshake
    /// message type.
    pub(crate) fn is_handshake_message(&self) -> bool {
        let kind = self.content_type();
        kind.is_handshake_message()
    }
}
102
103impl From<AuthenticatedContent> for PublicMessage {
104 fn from(v: AuthenticatedContent) -> Self {
105 Self {
106 content: v.content,
107 auth: v.auth,
108 membership_tag: None,
109 }
110 }
111}
112
113impl PublicMessage {
114 pub fn content_type(&self) -> ContentType {
116 self.content.body.content_type()
117 }
118
119 pub(crate) fn sender(&self) -> &Sender {
121 &self.content.sender
122 }
123
124 pub(crate) fn set_membership_tag(
129 &mut self,
130 crypto: &impl OpenMlsCrypto,
131 ciphersuite: Ciphersuite,
132 membership_key: &MembershipKey,
133 serialized_context: &[u8],
134 ) -> Result<(), LibraryError> {
135 let tbs_payload = framed_content_tbs_serialized_detached(
136 ProtocolVersion::default(),
137 WireFormat::PublicMessage,
138 &self.content,
139 &self.content.sender,
140 serialized_context,
141 )
142 .map_err(LibraryError::missing_bound_check)?;
143 let tbm_payload = AuthenticatedContentTbm::new(&tbs_payload, &self.auth)?;
144 let membership_tag = membership_key.tag_message(crypto, ciphersuite, tbm_payload)?;
145
146 self.membership_tag = Some(membership_tag);
147 Ok(())
148 }
149
150 #[cfg(test)]
151 pub(crate) fn set_membership_tag_test(&mut self, membership_tag: MembershipTag) {
152 self.membership_tag = Some(membership_tag);
153 }
154}
155
/// Test-only conversion of a [`PublicMessage`] into a to-be-signed structure.
///
/// Only the framed content is carried over; the protocol version is the
/// default one, the wire format is `PublicMessage`, and no serialized context
/// is attached.
#[cfg(test)]
impl From<PublicMessage> for FramedContentTbs {
    fn from(v: PublicMessage) -> Self {
        // The authentication data and membership tag are discarded.
        let PublicMessage { content, .. } = v;
        Self {
            version: ProtocolVersion::default(),
            wire_format: WireFormat::PublicMessage,
            content,
            serialized_context: None,
        }
    }
}
167
/// Borrowed input that is TLS-serialized and hashed, together with the
/// interim transcript hash, to obtain the confirmed transcript hash.
///
/// The fields are serialized in declaration order by the derived
/// `TlsSerialize` implementation.
///
/// NOTE(review): presumably corresponds to `ConfirmedTranscriptHashInput`
/// in RFC 9420, Section 8.2; confirm against the spec.
#[derive(TlsSerialize, TlsSize)]
pub(crate) struct ConfirmedTranscriptHashInput<'a> {
    /// Wire format of the message the content was framed in.
    pub(super) wire_format: WireFormat,
    /// The framed content; the `TryFrom` conversion in this file only
    /// accepts content of type `Commit`.
    pub(super) mls_content: &'a FramedContent,
    /// Signature over the content.
    pub(super) signature: &'a Signature,
}
184
185impl ConfirmedTranscriptHashInput<'_> {
186 pub(crate) fn calculate_confirmed_transcript_hash(
187 self,
188 crypto: &impl OpenMlsCrypto,
189 ciphersuite: Ciphersuite,
190 interim_transcript_hash: &[u8],
191 ) -> Result<Vec<u8>, LibraryError> {
192 let serialized: Vec<u8> = self
193 .tls_serialize_detached()
194 .map_err(LibraryError::missing_bound_check)?;
195
196 crypto
197 .hash(
198 ciphersuite.hash_algorithm(),
199 &[interim_transcript_hash, &serialized].concat(),
200 )
201 .map_err(LibraryError::unexpected_crypto_error)
202 }
203}
204
205impl<'a> TryFrom<&'a AuthenticatedContent> for ConfirmedTranscriptHashInput<'a> {
206 type Error = &'static str;
207
208 fn try_from(mls_content: &'a AuthenticatedContent) -> Result<Self, Self::Error> {
209 if !matches!(mls_content.content().content_type(), ContentType::Commit) {
210 return Err("PublicMessage needs to contain a Commit.");
211 }
212
213 Ok(ConfirmedTranscriptHashInput {
214 wire_format: mls_content.wire_format(),
215 mls_content: &mls_content.content,
216 signature: mls_content.signature(),
217 })
218 }
219}
220
/// Borrowed input that is TLS-serialized and hashed, together with the
/// confirmed transcript hash, to obtain the interim transcript hash.
///
/// NOTE(review): presumably corresponds to `InterimTranscriptHashInput` in
/// RFC 9420, Section 8.2; confirm against the spec.
#[derive(TlsSerialize, TlsSize)]
pub(crate) struct InterimTranscriptHashInput<'a> {
    /// The confirmation tag that makes up the whole serialized input.
    pub(crate) confirmation_tag: &'a ConfirmationTag,
}
233
234impl InterimTranscriptHashInput<'_> {
235 pub fn calculate_interim_transcript_hash(
236 self,
237 crypto: &impl OpenMlsCrypto,
238 ciphersuite: Ciphersuite,
239 confirmed_transcript_hash: &[u8],
240 ) -> Result<Vec<u8>, LibraryError> {
241 let serialized = self
242 .tls_serialize_detached()
243 .map_err(LibraryError::missing_bound_check)?;
244
245 crypto
246 .hash(
247 ciphersuite.hash_algorithm(),
248 &[confirmed_transcript_hash, &serialized].concat(),
249 )
250 .map_err(LibraryError::unexpected_crypto_error)
251 }
252}
253
254impl<'a> TryFrom<&'a PublicMessage> for InterimTranscriptHashInput<'a> {
255 type Error = &'static str;
256
257 fn try_from(public_message: &'a PublicMessage) -> Result<Self, Self::Error> {
258 match public_message.auth.confirmation_tag.as_ref() {
259 Some(confirmation_tag) => Ok(InterimTranscriptHashInput { confirmation_tag }),
260 None => Err("PublicMessage needs to contain a confirmation tag."),
261 }
262 }
263}
264
265impl<'a> From<&'a ConfirmationTag> for InterimTranscriptHashInput<'a> {
266 fn from(confirmation_tag: &'a ConfirmationTag) -> Self {
267 InterimTranscriptHashInput { confirmation_tag }
268 }
269}
270
271impl Size for PublicMessage {
274 #[inline]
275 fn tls_serialized_len(&self) -> usize {
276 self.content.tls_serialized_len()
277 + self.auth.tls_serialized_len()
278 + if let Some(membership_tag) = &self.membership_tag {
279 membership_tag.tls_serialized_len()
280 } else {
281 0
282 }
283 }
284}
285
286impl TlsSerializeTrait for PublicMessage {
287 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
288 let mut written = self.content.tls_serialize(writer)?;
290 written += self.auth.tls_serialize(writer)?;
291 written += if let Some(membership_tag) = &self.membership_tag {
292 membership_tag.tls_serialize(writer)?
293 } else {
294 0
295 };
296 Ok(written)
297 }
298}