1use crate::{error::LibraryError, group::errors::ValidationError, versions::ProtocolVersion};
7
8use super::{
9 mls_auth_content::FramedContentAuthData,
10 mls_auth_content_in::{AuthenticatedContentIn, VerifiableAuthenticatedContentIn},
11 mls_content::{framed_content_tbs_serialized_detached, AuthenticatedContentTbm},
12 mls_content_in::FramedContentIn,
13 *,
14};
15
16use openmls_traits::types::Ciphersuite;
17use std::io::{Read, Write};
18use tls_codec::{Deserialize as TlsDeserializeTrait, Serialize as TlsSerializeTrait};
19
20#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
35pub struct PublicMessageIn {
36 pub(crate) content: FramedContentIn,
37 pub(crate) auth: FramedContentAuthData,
38 pub(crate) membership_tag: Option<MembershipTag>,
39}
40
#[cfg(any(test, feature = "test-utils"))]
impl PublicMessageIn {
    /// Borrows the body of the framed content.
    pub(crate) fn content(&self) -> &crate::framing::mls_content_in::FramedContentBodyIn {
        let framed = &self.content;
        &framed.body
    }
}
47
/// Test-only mutators that overwrite individual parts of the message in place.
#[cfg(test)]
impl PublicMessageIn {
    /// Overwrites the confirmation tag held in the authentication data.
    pub fn set_confirmation_tag(&mut self, confirmation_tag: Option<ConfirmationTag>) {
        let auth = &mut self.auth;
        auth.confirmation_tag = confirmation_tag;
    }

    /// Removes the membership tag, leaving `None` in its place.
    pub fn unset_membership_tag(&mut self) {
        drop(self.membership_tag.take());
    }

    /// Overwrites the body of the framed content.
    pub(crate) fn set_content(&mut self, content: FramedContentBodyIn) {
        let framed = &mut self.content;
        framed.body = content;
    }

    /// Overwrites the epoch of the framed content with `epoch`.
    pub fn set_epoch(&mut self, epoch: u64) {
        let framed = &mut self.content;
        framed.epoch = epoch.into();
    }

    /// Overwrites the sender of the framed content.
    pub(crate) fn set_sender(&mut self, sender: Sender) {
        let framed = &mut self.content;
        framed.sender = sender;
    }
}
71
72impl From<AuthenticatedContentIn> for PublicMessageIn {
73 fn from(v: AuthenticatedContentIn) -> Self {
74 Self {
75 content: v.content,
76 auth: v.auth,
77 membership_tag: None,
78 }
79 }
80}
81
82impl PublicMessageIn {
83 pub(crate) fn new(
85 content: FramedContentIn,
86 auth: FramedContentAuthData,
87 membership_tag: Option<MembershipTag>,
88 ) -> Self {
89 Self {
90 content,
91 auth,
92 membership_tag,
93 }
94 }
95
96 pub fn content_type(&self) -> ContentType {
98 self.content.body.content_type()
99 }
100
101 pub fn sender(&self) -> &Sender {
103 &self.content.sender
104 }
105
106 #[cfg(test)]
107 pub(crate) fn set_membership_tag(
108 &mut self,
109 provider: &impl openmls_traits::OpenMlsProvider,
110 ciphersuite: Ciphersuite,
111 membership_key: &MembershipKey,
112 serialized_context: &[u8],
113 ) -> Result<(), LibraryError> {
114 let tbs_payload = framed_content_tbs_serialized_detached(
115 ProtocolVersion::default(),
116 WireFormat::PublicMessage,
117 &self.content,
118 &self.content.sender,
119 serialized_context,
120 )
121 .map_err(LibraryError::missing_bound_check)?;
122 let tbm_payload = AuthenticatedContentTbm::new(&tbs_payload, &self.auth)?;
123 let membership_tag =
124 membership_key.tag_message(provider.crypto(), ciphersuite, tbm_payload)?;
125
126 self.membership_tag = Some(membership_tag);
127 Ok(())
128 }
129
130 pub(crate) fn verify_membership(
135 &self,
136 crypto: &impl openmls_traits::crypto::OpenMlsCrypto,
137 ciphersuite: Ciphersuite,
138 membership_key: &MembershipKey,
139 serialized_context: &[u8],
140 ) -> Result<(), ValidationError> {
141 log::debug!("Verifying membership tag.");
142 log_crypto!(trace, " Membership key: {:x?}", membership_key);
143 log_crypto!(trace, " Serialized context: {:x?}", serialized_context);
144 let tbs_payload = framed_content_tbs_serialized_detached(
145 ProtocolVersion::default(),
146 WireFormat::PublicMessage,
147 &self.content,
148 &self.content.sender,
149 serialized_context,
150 )
151 .map_err(LibraryError::missing_bound_check)?;
152 let tbm_payload = AuthenticatedContentTbm::new(&tbs_payload, &self.auth)?;
153 let expected_membership_tag =
154 &membership_key.tag_message(crypto, ciphersuite, tbm_payload)?;
155
156 if let Some(membership_tag) = &self.membership_tag {
159 if membership_tag != expected_membership_tag {
161 return Err(ValidationError::InvalidMembershipTag);
162 }
163 } else {
164 return Err(ValidationError::MissingMembershipTag);
165 }
166 Ok(())
167 }
168
169 pub fn epoch(&self) -> GroupEpoch {
171 self.content.epoch
172 }
173
174 pub fn group_id(&self) -> &GroupId {
176 &self.content.group_id
177 }
178
179 pub(crate) fn into_verifiable_content(
181 self,
182 serialized_context: impl Into<Option<Vec<u8>>>,
183 ) -> VerifiableAuthenticatedContentIn {
184 VerifiableAuthenticatedContentIn::new(
185 WireFormat::PublicMessage,
186 self.content,
187 serialized_context,
188 self.auth,
189 )
190 }
191
192 pub(crate) fn membership_tag(&self) -> Option<&MembershipTag> {
194 self.membership_tag.as_ref()
195 }
196
197 pub fn confirmation_tag(&self) -> Option<&ConfirmationTag> {
199 self.auth.confirmation_tag.as_ref()
200 }
201}
202
#[cfg(test)]
impl From<PublicMessageIn> for FramedContentTbsIn {
    /// Keeps only the framed content of `v`, pairing it with the default
    /// protocol version, the public-message wire format and no serialized
    /// context.
    fn from(v: PublicMessageIn) -> Self {
        let PublicMessageIn { content, .. } = v;
        Self {
            version: ProtocolVersion::default(),
            wire_format: WireFormat::PublicMessage,
            content,
            serialized_context: None,
        }
    }
}
214
215impl<'a> TryFrom<&'a PublicMessageIn> for InterimTranscriptHashInput<'a> {
216 type Error = &'static str;
217
218 fn try_from(public_message: &'a PublicMessageIn) -> Result<Self, Self::Error> {
219 match public_message.auth.confirmation_tag.as_ref() {
220 Some(confirmation_tag) => Ok(InterimTranscriptHashInput { confirmation_tag }),
221 None => Err("PublicMessage needs to contain a confirmation tag."),
222 }
223 }
224}
225
226impl TlsDeserializeTrait for PublicMessageIn {
227 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error> {
228 let content = FramedContentIn::tls_deserialize(bytes)?;
229 let auth = FramedContentAuthData::deserialize(bytes, content.body.content_type())?;
230 let membership_tag = if content.sender.is_member() {
231 Some(MembershipTag::tls_deserialize(bytes)?)
232 } else {
233 None
234 };
235
236 Ok(PublicMessageIn::new(content, auth, membership_tag))
237 }
238}
239
240impl DeserializeBytes for PublicMessageIn {
241 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
242 where
243 Self: Sized,
244 {
245 let mut bytes_ref = bytes;
246 let message = PublicMessageIn::tls_deserialize(&mut bytes_ref)?;
247 let remainder = &bytes[message.tls_serialized_len()..];
248 Ok((message, remainder))
249 }
250}
251
252impl Size for PublicMessageIn {
253 #[inline]
254 fn tls_serialized_len(&self) -> usize {
255 self.content.tls_serialized_len()
256 + self.auth.tls_serialized_len()
257 + if let Some(membership_tag) = &self.membership_tag {
258 membership_tag.tls_serialized_len()
259 } else {
260 0
261 }
262 }
263}
264
265impl TlsSerializeTrait for PublicMessageIn {
266 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
267 let mut written = self.content.tls_serialize(writer)?;
269 written += self.auth.tls_serialize(writer)?;
270 written += if let Some(membership_tag) = &self.membership_tag {
271 membership_tag.tls_serialize(writer)?
272 } else {
273 0
274 };
275 Ok(written)
276 }
277}
278
#[cfg(any(feature = "test-utils", test))]
impl From<PublicMessageIn> for PublicMessage {
    /// Converts the content via `Into` and moves the auth data and membership
    /// tag over unchanged.
    fn from(v: PublicMessageIn) -> Self {
        let PublicMessageIn {
            content,
            auth,
            membership_tag,
        } = v;
        Self {
            content: content.into(),
            auth,
            membership_tag,
        }
    }
}
291
#[cfg(any(feature = "test-utils", test))]
impl From<PublicMessage> for PublicMessageIn {
    /// Converts the content via `Into` and moves the auth data and membership
    /// tag over unchanged.
    fn from(v: PublicMessage) -> Self {
        let PublicMessage {
            content,
            auth,
            membership_tag,
        } = v;
        Self {
            content: content.into(),
            auth,
            membership_tag,
        }
    }
}