1use std::io::Write;
7
8use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
9use tls_codec::{
10 Serialize as TlsSerializeTrait, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
11};
12
13use super::{
14 mls_auth_content::{AuthenticatedContent, FramedContentAuthData},
15 mls_content::{framed_content_tbs_serialized_detached, AuthenticatedContentTbm, FramedContent},
16 *,
17};
18use crate::{error::LibraryError, versions::ProtocolVersion};
19
20#[derive(
22 Debug,
23 PartialEq,
24 Clone,
25 Serialize,
26 Deserialize,
27 TlsSerialize,
28 TlsDeserialize,
29 TlsDeserializeBytes,
30 TlsSize,
31)]
32pub(crate) struct MembershipTag(pub(crate) Mac);
33
34#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
49pub struct PublicMessage {
50 pub(crate) content: FramedContent,
51 pub(crate) auth: FramedContentAuthData,
52 pub(crate) membership_tag: Option<MembershipTag>,
53}
54
/// Accessors and mutators used only by tests to inspect or tamper with a
/// [`PublicMessage`].
#[cfg(test)]
impl PublicMessage {
    /// Returns the body of the framed content.
    pub(crate) fn content(&self) -> &crate::framing::mls_content::FramedContentBody {
        &self.content.body
    }

    /// Overwrites the confirmation tag in the authentication data.
    pub fn set_confirmation_tag(&mut self, confirmation_tag: Option<ConfirmationTag>) {
        self.auth.confirmation_tag = confirmation_tag;
    }

    /// Removes the membership tag, if any.
    pub fn unset_membership_tag(&mut self) {
        self.membership_tag = None;
    }

    /// Replaces the body of the framed content.
    pub(crate) fn set_content(&mut self, content: FramedContentBody) {
        self.content.body = content;
    }

    /// Sets the epoch of the framed content.
    pub fn set_epoch(&mut self, epoch: u64) {
        self.content.epoch = epoch.into();
    }

    /// Returns the confirmation tag, if present.
    pub fn confirmation_tag(&self) -> Option<&ConfirmationTag> {
        self.auth.confirmation_tag.as_ref()
    }

    /// Corrupts the signature by flipping every bit of its first byte.
    ///
    /// # Panics
    ///
    /// Panics if the signature is empty, because there is then no byte that
    /// could be flipped to invalidate it.
    pub(crate) fn invalidate_signature(&mut self) {
        let mut modified_signature = self.auth.signature.as_slice().to_vec();
        // Fail loudly with a clear message rather than an opaque index panic;
        // silently skipping would leave a "valid" signature in a test that
        // expects an invalid one.
        let first_byte = modified_signature
            .first_mut()
            .expect("cannot invalidate an empty signature");
        *first_byte ^= 0xFF;
        self.auth.signature.modify(&modified_signature);
    }

    /// Replaces the sender of the framed content.
    pub(crate) fn set_sender(&mut self, sender: Sender) {
        self.content.sender = sender;
    }

    /// Replaces the group id of the framed content.
    pub(crate) fn set_group_id(&mut self, group_id: GroupId) {
        self.content.group_id = group_id;
    }

    /// Returns whether the content type is classified as a handshake message.
    pub(crate) fn is_handshake_message(&self) -> bool {
        self.content_type().is_handshake_message()
    }
}
102
103impl From<AuthenticatedContent> for PublicMessage {
104 fn from(v: AuthenticatedContent) -> Self {
105 Self {
106 content: v.content,
107 auth: v.auth,
108 membership_tag: None,
109 }
110 }
111}
112
113impl PublicMessage {
114 pub fn content_type(&self) -> ContentType {
116 self.content.body.content_type()
117 }
118
119 pub fn epoch(&self) -> GroupEpoch {
121 self.content.epoch
122 }
123
124 pub(crate) fn sender(&self) -> &Sender {
126 &self.content.sender
127 }
128
129 pub(crate) fn set_membership_tag(
134 &mut self,
135 crypto: &impl OpenMlsCrypto,
136 ciphersuite: Ciphersuite,
137 membership_key: &MembershipKey,
138 serialized_context: &[u8],
139 ) -> Result<(), LibraryError> {
140 let tbs_payload = framed_content_tbs_serialized_detached(
141 ProtocolVersion::default(),
142 WireFormat::PublicMessage,
143 &self.content,
144 &self.content.sender,
145 serialized_context,
146 )
147 .map_err(LibraryError::missing_bound_check)?;
148 let tbm_payload = AuthenticatedContentTbm::new(&tbs_payload, &self.auth)?;
149 let membership_tag = membership_key.tag_message(crypto, ciphersuite, tbm_payload)?;
150
151 self.membership_tag = Some(membership_tag);
152 Ok(())
153 }
154
155 #[cfg(test)]
156 pub(crate) fn set_membership_tag_test(&mut self, membership_tag: MembershipTag) {
157 self.membership_tag = Some(membership_tag);
158 }
159}
160
/// Test-only conversion of a [`PublicMessage`] into its to-be-signed form,
/// using the default protocol version and no serialized group context.
#[cfg(test)]
impl From<PublicMessage> for FramedContentTbs {
    fn from(public_message: PublicMessage) -> Self {
        let PublicMessage { content, .. } = public_message;
        Self {
            version: ProtocolVersion::default(),
            wire_format: WireFormat::PublicMessage,
            content,
            serialized_context: None,
        }
    }
}
172
173#[derive(TlsSerialize, TlsSize)]
184pub(crate) struct ConfirmedTranscriptHashInput<'a> {
185 pub(super) wire_format: WireFormat,
186 pub(super) mls_content: &'a FramedContent,
187 pub(super) signature: &'a Signature,
188}
189
190impl ConfirmedTranscriptHashInput<'_> {
191 pub(crate) fn calculate_confirmed_transcript_hash(
192 self,
193 crypto: &impl OpenMlsCrypto,
194 ciphersuite: Ciphersuite,
195 interim_transcript_hash: &[u8],
196 ) -> Result<Vec<u8>, LibraryError> {
197 let serialized: Vec<u8> = self
198 .tls_serialize_detached()
199 .map_err(LibraryError::missing_bound_check)?;
200
201 crypto
202 .hash(
203 ciphersuite.hash_algorithm(),
204 &[interim_transcript_hash, &serialized].concat(),
205 )
206 .map_err(LibraryError::unexpected_crypto_error)
207 }
208}
209
210impl<'a> TryFrom<&'a AuthenticatedContent> for ConfirmedTranscriptHashInput<'a> {
211 type Error = &'static str;
212
213 fn try_from(mls_content: &'a AuthenticatedContent) -> Result<Self, Self::Error> {
214 if !matches!(mls_content.content().content_type(), ContentType::Commit) {
215 return Err("PublicMessage needs to contain a Commit.");
216 }
217
218 Ok(ConfirmedTranscriptHashInput {
219 wire_format: mls_content.wire_format(),
220 mls_content: &mls_content.content,
221 signature: mls_content.signature(),
222 })
223 }
224}
225
226#[derive(TlsSerialize, TlsSize)]
235pub(crate) struct InterimTranscriptHashInput<'a> {
236 pub(crate) confirmation_tag: &'a ConfirmationTag,
237}
238
239impl InterimTranscriptHashInput<'_> {
240 pub fn calculate_interim_transcript_hash(
241 self,
242 crypto: &impl OpenMlsCrypto,
243 ciphersuite: Ciphersuite,
244 confirmed_transcript_hash: &[u8],
245 ) -> Result<Vec<u8>, LibraryError> {
246 let serialized = self
247 .tls_serialize_detached()
248 .map_err(LibraryError::missing_bound_check)?;
249
250 crypto
251 .hash(
252 ciphersuite.hash_algorithm(),
253 &[confirmed_transcript_hash, &serialized].concat(),
254 )
255 .map_err(LibraryError::unexpected_crypto_error)
256 }
257}
258
259impl<'a> TryFrom<&'a PublicMessage> for InterimTranscriptHashInput<'a> {
260 type Error = &'static str;
261
262 fn try_from(public_message: &'a PublicMessage) -> Result<Self, Self::Error> {
263 match public_message.auth.confirmation_tag.as_ref() {
264 Some(confirmation_tag) => Ok(InterimTranscriptHashInput { confirmation_tag }),
265 None => Err("PublicMessage needs to contain a confirmation tag."),
266 }
267 }
268}
269
270impl<'a> From<&'a ConfirmationTag> for InterimTranscriptHashInput<'a> {
271 fn from(confirmation_tag: &'a ConfirmationTag) -> Self {
272 InterimTranscriptHashInput { confirmation_tag }
273 }
274}
275
276impl Size for PublicMessage {
279 #[inline]
280 fn tls_serialized_len(&self) -> usize {
281 self.content.tls_serialized_len()
282 + self.auth.tls_serialized_len()
283 + if let Some(membership_tag) = &self.membership_tag {
284 membership_tag.tls_serialized_len()
285 } else {
286 0
287 }
288 }
289}
290
291impl TlsSerializeTrait for PublicMessage {
292 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
293 let mut written = self.content.tls_serialize(writer)?;
295 written += self.auth.tls_serialize(writer)?;
296 written += if let Some(membership_tag) = &self.membership_tag {
297 membership_tag.tls_serialize(writer)?
298 } else {
299 0
300 };
301 Ok(written)
302 }
303}