1use std::io::Write;
7
8use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
9use tls_codec::{
10 Serialize as TlsSerializeTrait, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
11};
12
13use super::{
14 mls_auth_content::{AuthenticatedContent, FramedContentAuthData},
15 mls_content::{framed_content_tbs_serialized_detached, AuthenticatedContentTbm, FramedContent},
16 *,
17};
18use crate::{error::LibraryError, versions::ProtocolVersion};
19
20#[derive(
22 Debug,
23 PartialEq,
24 Clone,
25 Serialize,
26 Deserialize,
27 TlsSerialize,
28 TlsDeserialize,
29 TlsDeserializeBytes,
30 TlsSize,
31)]
32pub(crate) struct MembershipTag(pub(crate) Mac);
33
34#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
49pub struct PublicMessage {
50 pub(crate) content: FramedContent,
51 pub(crate) auth: FramedContentAuthData,
52 pub(crate) membership_tag: Option<MembershipTag>,
53}
54
// Helpers that are only compiled for tests. They expose and overwrite
// individual parts of a `PublicMessage` so tests can build tampered messages.
#[cfg(test)]
impl PublicMessage {
    /// Borrows the body of the framed content.
    pub(crate) fn content(&self) -> &crate::framing::mls_content::FramedContentBody {
        let framed = &self.content;
        &framed.body
    }

    /// Overwrites the confirmation tag stored in the authentication data.
    pub fn set_confirmation_tag(&mut self, confirmation_tag: Option<ConfirmationTag>) {
        let auth = &mut self.auth;
        auth.confirmation_tag = confirmation_tag;
    }

    /// Clears the membership tag, leaving `None` in its place.
    pub fn unset_membership_tag(&mut self) {
        self.membership_tag = None;
    }

    /// Overwrites the body of the framed content.
    pub(crate) fn set_content(&mut self, content: FramedContentBody) {
        let framed = &mut self.content;
        framed.body = content;
    }

    /// Overwrites the epoch of the framed content with the converted `epoch`.
    pub fn set_epoch(&mut self, epoch: u64) {
        let new_epoch: GroupEpoch = epoch.into();
        self.content.epoch = new_epoch;
    }

    /// Borrows the confirmation tag from the authentication data, if any.
    pub fn confirmation_tag(&self) -> Option<&ConfirmationTag> {
        let auth = &self.auth;
        auth.confirmation_tag.as_ref()
    }

    /// Corrupts the signature by XOR-ing its first element with `0xFF`.
    ///
    /// Indexes element 0 directly, so this panics when the signature is empty.
    pub(crate) fn invalidate_signature(&mut self) {
        let mut tampered = self.auth.signature.as_slice().to_vec();
        tampered[0] ^= 0xFF;
        self.auth.signature.modify(&tampered);
    }

    /// Overwrites the sender of the framed content.
    pub(crate) fn set_sender(&mut self, sender: Sender) {
        let framed = &mut self.content;
        framed.sender = sender;
    }

    /// Overwrites the group id of the framed content.
    pub(crate) fn set_group_id(&mut self, group_id: GroupId) {
        let framed = &mut self.content;
        framed.group_id = group_id;
    }

    /// Reports whether the content type of this message is a handshake message.
    pub(crate) fn is_handshake_message(&self) -> bool {
        let content_type = self.content_type();
        content_type.is_handshake_message()
    }
}
102
103impl From<AuthenticatedContent> for PublicMessage {
104 fn from(v: AuthenticatedContent) -> Self {
105 Self {
106 content: v.content,
107 auth: v.auth,
108 membership_tag: None,
109 }
110 }
111}
112
113impl PublicMessage {
114 pub fn content_type(&self) -> ContentType {
116 self.content.body.content_type()
117 }
118
119 pub fn group_id(&self) -> &GroupId {
121 &self.content.group_id
122 }
123
124 pub fn epoch(&self) -> GroupEpoch {
126 self.content.epoch
127 }
128
129 pub(crate) fn sender(&self) -> &Sender {
131 &self.content.sender
132 }
133
134 pub(crate) fn set_membership_tag(
139 &mut self,
140 crypto: &impl OpenMlsCrypto,
141 ciphersuite: Ciphersuite,
142 membership_key: &MembershipKey,
143 serialized_context: &[u8],
144 ) -> Result<(), LibraryError> {
145 let tbs_payload = framed_content_tbs_serialized_detached(
146 ProtocolVersion::default(),
147 WireFormat::PublicMessage,
148 &self.content,
149 &self.content.sender,
150 serialized_context,
151 )
152 .map_err(LibraryError::missing_bound_check)?;
153 let tbm_payload = AuthenticatedContentTbm::new(&tbs_payload, &self.auth)?;
154 let membership_tag = membership_key.tag_message(crypto, ciphersuite, tbm_payload)?;
155
156 self.membership_tag = Some(membership_tag);
157 Ok(())
158 }
159
160 #[cfg(test)]
161 pub(crate) fn set_membership_tag_test(&mut self, membership_tag: MembershipTag) {
162 self.membership_tag = Some(membership_tag);
163 }
164}
165
// Test-only conversion: wraps the framed content of a `PublicMessage` in a TBS
// structure with the default protocol version and no serialized context.
#[cfg(test)]
impl From<PublicMessage> for FramedContentTbs {
    /// Moves the framed content out of `v`; the remaining fields are dropped.
    fn from(v: PublicMessage) -> Self {
        let version = ProtocolVersion::default();
        let PublicMessage { content, .. } = v;
        Self {
            version,
            wire_format: WireFormat::PublicMessage,
            content,
            serialized_context: None,
        }
    }
}
177
178#[derive(TlsSerialize, TlsSize)]
189pub(crate) struct ConfirmedTranscriptHashInput<'a> {
190 pub(super) wire_format: WireFormat,
191 pub(super) mls_content: &'a FramedContent,
192 pub(super) signature: &'a Signature,
193}
194
195impl ConfirmedTranscriptHashInput<'_> {
196 pub(crate) fn calculate_confirmed_transcript_hash(
197 self,
198 crypto: &impl OpenMlsCrypto,
199 ciphersuite: Ciphersuite,
200 interim_transcript_hash: &[u8],
201 ) -> Result<Vec<u8>, LibraryError> {
202 let serialized: Vec<u8> = self
203 .tls_serialize_detached()
204 .map_err(LibraryError::missing_bound_check)?;
205
206 crypto
207 .hash(
208 ciphersuite.hash_algorithm(),
209 &[interim_transcript_hash, &serialized].concat(),
210 )
211 .map_err(LibraryError::unexpected_crypto_error)
212 }
213}
214
215impl<'a> TryFrom<&'a AuthenticatedContent> for ConfirmedTranscriptHashInput<'a> {
216 type Error = &'static str;
217
218 fn try_from(mls_content: &'a AuthenticatedContent) -> Result<Self, Self::Error> {
219 if !matches!(mls_content.content().content_type(), ContentType::Commit) {
220 return Err("PublicMessage needs to contain a Commit.");
221 }
222
223 Ok(ConfirmedTranscriptHashInput {
224 wire_format: mls_content.wire_format(),
225 mls_content: &mls_content.content,
226 signature: mls_content.signature(),
227 })
228 }
229}
230
231#[derive(TlsSerialize, TlsSize)]
240pub(crate) struct InterimTranscriptHashInput<'a> {
241 pub(crate) confirmation_tag: &'a ConfirmationTag,
242}
243
244impl InterimTranscriptHashInput<'_> {
245 pub fn calculate_interim_transcript_hash(
246 self,
247 crypto: &impl OpenMlsCrypto,
248 ciphersuite: Ciphersuite,
249 confirmed_transcript_hash: &[u8],
250 ) -> Result<Vec<u8>, LibraryError> {
251 let serialized = self
252 .tls_serialize_detached()
253 .map_err(LibraryError::missing_bound_check)?;
254
255 crypto
256 .hash(
257 ciphersuite.hash_algorithm(),
258 &[confirmed_transcript_hash, &serialized].concat(),
259 )
260 .map_err(LibraryError::unexpected_crypto_error)
261 }
262}
263
264impl<'a> TryFrom<&'a PublicMessage> for InterimTranscriptHashInput<'a> {
265 type Error = &'static str;
266
267 fn try_from(public_message: &'a PublicMessage) -> Result<Self, Self::Error> {
268 match public_message.auth.confirmation_tag.as_ref() {
269 Some(confirmation_tag) => Ok(InterimTranscriptHashInput { confirmation_tag }),
270 None => Err("PublicMessage needs to contain a confirmation tag."),
271 }
272 }
273}
274
275impl<'a> From<&'a ConfirmationTag> for InterimTranscriptHashInput<'a> {
276 fn from(confirmation_tag: &'a ConfirmationTag) -> Self {
277 InterimTranscriptHashInput { confirmation_tag }
278 }
279}
280
281impl Size for PublicMessage {
284 #[inline]
285 fn tls_serialized_len(&self) -> usize {
286 self.content.tls_serialized_len()
287 + self.auth.tls_serialized_len()
288 + if let Some(membership_tag) = &self.membership_tag {
289 membership_tag.tls_serialized_len()
290 } else {
291 0
292 }
293 }
294}
295
296impl TlsSerializeTrait for PublicMessage {
297 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
298 let mut written = self.content.tls_serialize(writer)?;
300 written += self.auth.tls_serialize(writer)?;
301 written += if let Some(membership_tag) = &self.membership_tag {
302 membership_tag.tls_serialize(writer)?
303 } else {
304 0
305 };
306 Ok(written)
307 }
308}