// openmls/test_utils/frankenstein/framing.rs

1use openmls_traits::{crypto::OpenMlsCrypto, signatures::Signer, types::Ciphersuite};
2use tls_codec::*;
3
4use crate::{
5    binary_tree::LeafNodeIndex,
6    extensions::SenderExtensionIndex,
7    framing::{
8        mls_content::{AuthenticatedContentTbm, FramedContentBody, FramedContentTbs},
9        mls_content_in::FramedContentBodyIn,
10        MlsMessageIn, MlsMessageOut, PrivateMessage, PrivateMessageIn, PublicMessage,
11        PublicMessageIn, Sender, WireFormat,
12    },
13    group::GroupContext,
14    messages::{ConfirmationTag, Welcome},
15    prelude_test::signable::Signable,
16    schedule::{ConfirmationKey, MembershipKey},
17};
18
19use super::{
20    commit::FrankenCommit,
21    compute_membership_tag,
22    group_info::{FrankenGroupContext, FrankenGroupInfo},
23    sign_with_label, FrankenKeyPackage, FrankenProposal,
24};
25
26#[derive(
27    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
28)]
29pub struct FrankenMlsMessage {
30    pub version: u16,
31    pub body: FrankenMlsMessageBody,
32}
33
34#[allow(clippy::large_enum_variant)]
35#[derive(
36    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
37)]
38#[repr(u16)]
39pub enum FrankenMlsMessageBody {
40    #[tls_codec(discriminant = 1)]
41    PublicMessage(FrankenPublicMessage),
42    #[tls_codec(discriminant = 2)]
43    PrivateMessage(FrankenPrivateMessage),
44    #[tls_codec(discriminant = 3)]
45    Welcome(FrankenWelcome),
46    #[tls_codec(discriminant = 4)]
47    GroupInfo(FrankenGroupInfo),
48    #[tls_codec(discriminant = 5)]
49    KeyPackage(FrankenKeyPackage),
50}
51
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub struct FrankenPublicMessage {
54    pub content: FrankenFramedContent,
55    pub auth: FrankenFramedContentAuthData,
56    pub membership_tag: Option<VLBytes>,
57}
58
59impl tls_codec::Size for FrankenPublicMessage {
60    fn tls_serialized_len(&self) -> usize {
61        let tag_len = self
62            .membership_tag
63            .as_ref()
64            .map_or(0, |tag| tag.tls_serialized_len());
65
66        self.content.tls_serialized_len() + self.auth.tls_serialized_len() + tag_len
67    }
68}
69
70impl Deserialize for FrankenPublicMessage {
71    fn tls_deserialize<R: std::io::prelude::Read>(bytes: &mut R) -> Result<Self, Error>
72    where
73        Self: Sized,
74    {
75        let content = FrankenFramedContent::tls_deserialize(bytes)?;
76        let auth = if matches!(content.body, FrankenFramedContentBody::Commit(_)) {
77            FrankenFramedContentAuthData::tls_deserialize_with_tag(bytes)?
78        } else {
79            FrankenFramedContentAuthData::tls_deserialize_without_tag(bytes)?
80        };
81
82        let membership_tag = if matches!(content.sender, FrankenSender::Member(_)) {
83            Some(VLBytes::tls_deserialize(bytes)?)
84        } else {
85            None
86        };
87
88        Ok(Self {
89            content,
90            auth,
91            membership_tag,
92        })
93    }
94}
95
96impl DeserializeBytes for FrankenPublicMessage {
97    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
98    where
99        Self: Sized,
100    {
101        let (content, bytes) = FrankenFramedContent::tls_deserialize_bytes(bytes)?;
102        let (auth, bytes) = match content.body {
103            FrankenFramedContentBody::Commit(_) => {
104                FrankenFramedContentAuthData::tls_deserialize_bytes_with_tag(bytes)
105            }
106            _ => FrankenFramedContentAuthData::tls_deserialize_bytes_without_tag(bytes),
107        }?;
108        let (membership_tag, bytes) = match content.sender {
109            FrankenSender::Member(_) => {
110                let (tag, bytes) = VLBytes::tls_deserialize_bytes(bytes)?;
111                (Some(tag), bytes)
112            }
113            _ => (None, bytes),
114        };
115
116        Ok((
117            Self {
118                content,
119                auth,
120                membership_tag,
121            },
122            bytes,
123        ))
124    }
125}
126
127impl Serialize for FrankenPublicMessage {
128    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
129        let mut written = 0;
130        written += self.content.tls_serialize(writer)?;
131        written += self.auth.tls_serialize(writer)?;
132        if let Some(tag) = &self.membership_tag {
133            written += tag.tls_serialize(writer)?;
134        }
135
136        Ok(written)
137    }
138}
139
140impl FrankenPublicMessage {
141    /// auth builds a mostly(!) valid fake public message. However, it does not compute a correct
142    /// confirmation_tag. If the caller wants to process a message that requires a
143    /// confirmation_tag, they have two options:
144    ///
145    /// 1. build a valid tag themselves and provide it through the option
146    /// 2. provide a dummy tag and disable the verification of confirmation tags using
147    ///    [`crate::disable_confirmation_tag_verification`].
148    ///    NB: Usually, confirmation tag verification should be turned back on after the call that
149    ///    needs to be tricked!
150    pub(crate) fn auth(
151        provider: &impl crate::storage::OpenMlsProvider,
152        ciphersuite: openmls_traits::types::Ciphersuite,
153        signer: &impl Signer,
154        content: FrankenFramedContent,
155        group_context: Option<&FrankenGroupContext>,
156        membership_key: Option<&[u8]>,
157        confirmation_tag: Option<VLBytes>,
158    ) -> Self {
159        let version = 1; // MLS 1.0
160        let wire_format = 1; // PublicMessage
161
162        let franken_tbs = FrankenFramedContentTbs {
163            version,
164            wire_format,
165            content: &content,
166            group_context,
167        };
168
169        let auth = FrankenFramedContentAuthData::build(
170            signer,
171            version,
172            wire_format,
173            &content,
174            group_context,
175            confirmation_tag,
176        );
177
178        let tbm = FrankenAuthenticatedContentTbm {
179            content_tbs: franken_tbs,
180            auth: auth.clone(),
181        };
182
183        let membership_tag = membership_key.map(|membership_key| {
184            compute_membership_tag(provider.crypto(), ciphersuite, membership_key, &tbm)
185        });
186
187        FrankenPublicMessage {
188            content,
189            auth,
190            membership_tag,
191        }
192    }
193}
194
195#[derive(
196    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
197)]
198pub struct FrankenFramedContent {
199    pub group_id: VLBytes,
200    pub epoch: u64,
201    pub sender: FrankenSender,
202    pub authenticated_data: VLBytes,
203    pub body: FrankenFramedContentBody,
204}
205
206#[derive(
207    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
208)]
209#[repr(u8)]
210pub enum FrankenSender {
211    #[tls_codec(discriminant = 1)]
212    Member(u32),
213    External(u32),
214    NewMemberProposal,
215    NewMemberCommit,
216}
217
218#[derive(
219    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
220)]
221#[repr(u8)]
222pub enum FrankenFramedContentBody {
223    #[tls_codec(discriminant = 1)]
224    Application(VLBytes),
225    #[tls_codec(discriminant = 2)]
226    Proposal(FrankenProposal),
227    #[tls_codec(discriminant = 3)]
228    Commit(FrankenCommit),
229}
230
231#[derive(
232    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
233)]
234pub struct FrankenPrivateMessage {
235    pub group_id: VLBytes,
236    pub epoch: VLBytes,
237    pub content_type: FrankenContentType,
238    pub authenticated_data: VLBytes,
239    pub encrypted_sender_data: VLBytes,
240    pub ciphertext: VLBytes,
241}
242
243#[derive(
244    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
245)]
246pub struct FrankenWelcome {
247    pub cipher_suite: u16,
248    pub secrets: Vec<FrankenEncryptedGroupSecrets>,
249    pub encrypted_group_info: VLBytes,
250}
251
252#[derive(Debug, Clone, PartialEq, Eq)]
253pub struct FrankenFramedContentAuthData {
254    pub signature: VLBytes,
255    pub confirmation_tag: Option<VLBytes>,
256}
257
258impl FrankenFramedContentAuthData {
259    pub fn tls_deserialize_with_tag<R: std::io::Read>(
260        bytes: &mut R,
261    ) -> Result<Self, tls_codec::Error> {
262        let signature = VLBytes::tls_deserialize(bytes)?;
263        let confirmation_tag = VLBytes::tls_deserialize(bytes)?;
264
265        Ok(Self {
266            signature,
267            confirmation_tag: Some(confirmation_tag),
268        })
269    }
270
271    pub fn tls_deserialize_bytes_with_tag(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error> {
272        let (signature, bytes) = VLBytes::tls_deserialize_bytes(bytes)?;
273        let (confirmation_tag, bytes) = VLBytes::tls_deserialize_bytes(bytes)?;
274
275        Ok((
276            Self {
277                signature,
278                confirmation_tag: Some(confirmation_tag),
279            },
280            bytes,
281        ))
282    }
283
284    pub fn tls_deserialize_without_tag<R: std::io::Read>(
285        bytes: &mut R,
286    ) -> Result<Self, tls_codec::Error> {
287        let signature = VLBytes::tls_deserialize(bytes)?;
288
289        Ok(Self {
290            signature,
291            confirmation_tag: None,
292        })
293    }
294
295    pub fn tls_deserialize_bytes_without_tag(
296        bytes: &[u8],
297    ) -> Result<(Self, &[u8]), tls_codec::Error> {
298        let (signature, bytes) = VLBytes::tls_deserialize_bytes(bytes)?;
299
300        Ok((
301            Self {
302                signature,
303                confirmation_tag: None,
304            },
305            bytes,
306        ))
307    }
308}
309
310impl tls_codec::Size for FrankenFramedContentAuthData {
311    fn tls_serialized_len(&self) -> usize {
312        if let Some(tag) = &self.confirmation_tag {
313            self.signature.tls_serialized_len() + tag.tls_serialized_len()
314        } else {
315            self.signature.tls_serialized_len()
316        }
317    }
318}
319
320impl Serialize for FrankenFramedContentAuthData {
321    fn tls_serialize<W: std::io::prelude::Write>(&self, writer: &mut W) -> Result<usize, Error> {
322        let mut written = 0;
323        written += self.signature.tls_serialize(writer)?;
324        if let Some(confirmation_tag) = &self.confirmation_tag {
325            written += confirmation_tag.tls_serialize(writer)?;
326        }
327        Ok(written)
328    }
329}
330
331impl FrankenFramedContentAuthData {
332    pub fn build(
333        signer: &impl Signer,
334        version: u16,
335        wire_format: u16,
336        content: &FrankenFramedContent,
337        group_context: Option<&FrankenGroupContext>,
338        confirmation_tag: Option<VLBytes>,
339    ) -> Self {
340        let content_tbs = FrankenFramedContentTbs {
341            version,
342            wire_format,
343            content,
344            group_context,
345        };
346
347        let content_tbs_serialized = content_tbs.tls_serialize_detached().unwrap();
348
349        let signature =
350            sign_with_label(signer, b"FramedContentTBS", &content_tbs_serialized).into();
351
352        Self {
353            signature,
354            confirmation_tag,
355        }
356    }
357}
358
359#[derive(Debug, Clone, PartialEq, Eq)]
360pub struct FrankenFramedContentTbs<'a> {
361    version: u16,
362    wire_format: u16,
363    content: &'a FrankenFramedContent,
364    group_context: Option<&'a FrankenGroupContext>,
365}
366
367impl tls_codec::Size for FrankenFramedContentTbs<'_> {
368    fn tls_serialized_len(&self) -> usize {
369        if let Some(ctx) = self.group_context {
370            4 + self.content.tls_serialized_len() + ctx.tls_serialized_len()
371        } else {
372            4 + self.content.tls_serialized_len()
373        }
374    }
375}
376
377impl Serialize for FrankenFramedContentTbs<'_> {
378    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
379        writer.write_all(&self.version.to_be_bytes())?;
380        writer.write_all(&self.wire_format.to_be_bytes())?;
381
382        let mut written = 4; // contains the two u16 version and wire_format
383        written += self.content.tls_serialize(writer)?;
384        if let Some(group_context) = &self.group_context {
385            written += group_context.tls_serialize(writer)?;
386        }
387
388        Ok(written)
389    }
390}
391
392#[derive(Debug, Clone, PartialEq, Eq, TlsSerialize, TlsSize)]
393pub struct FrankenAuthenticatedContentTbm<'a> {
394    content_tbs: FrankenFramedContentTbs<'a>,
395    auth: FrankenFramedContentAuthData,
396}
397
398#[derive(
399    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
400)]
401#[repr(u8)]
402pub enum FrankenContentType {
403    Application = 1,
404    Proposal = 2,
405    Commit = 3,
406}
407
408#[derive(
409    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
410)]
411pub struct FrankenEncryptedGroupSecrets {
412    pub new_member: VLBytes,
413    pub encrypted_group_secrets: VLBytes,
414}
415
416impl From<MlsMessageOut> for FrankenMlsMessage {
417    fn from(ln: MlsMessageOut) -> Self {
418        FrankenMlsMessage::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice())
419            .unwrap()
420    }
421}
422
423impl From<FrankenMlsMessage> for MlsMessageOut {
424    fn from(fln: FrankenMlsMessage) -> Self {
425        MlsMessageIn::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice())
426            .unwrap()
427            .into()
428    }
429}
430
431impl From<PublicMessage> for FrankenPublicMessage {
432    fn from(ln: PublicMessage) -> Self {
433        FrankenPublicMessage::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice())
434            .unwrap()
435    }
436}
437
438impl From<FrankenPublicMessage> for PublicMessage {
439    fn from(fln: FrankenPublicMessage) -> Self {
440        PublicMessageIn::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice())
441            .unwrap()
442            .into()
443    }
444}
445
446impl From<PrivateMessage> for FrankenPrivateMessage {
447    fn from(ln: PrivateMessage) -> Self {
448        FrankenPrivateMessage::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice())
449            .unwrap()
450    }
451}
452
453impl From<FrankenPrivateMessage> for PrivateMessage {
454    fn from(fln: FrankenPrivateMessage) -> Self {
455        PrivateMessageIn::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice())
456            .unwrap()
457            .into()
458    }
459}
460
461impl From<Welcome> for FrankenWelcome {
462    fn from(ln: Welcome) -> Self {
463        FrankenWelcome::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice())
464            .unwrap()
465    }
466}
467
468impl From<FrankenWelcome> for Welcome {
469    fn from(fln: FrankenWelcome) -> Self {
470        Welcome::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice()).unwrap()
471    }
472}
473
474impl From<FrankenFramedContentBody> for FramedContentBodyIn {
475    fn from(value: FrankenFramedContentBody) -> Self {
476        FramedContentBodyIn::tls_deserialize(
477            &mut value.tls_serialize_detached().unwrap().as_slice(),
478        )
479        .unwrap()
480    }
481}
482
483impl From<FrankenFramedContentBody> for FramedContentBody {
484    fn from(value: FrankenFramedContentBody) -> Self {
485        FramedContentBodyIn::from(value).into()
486    }
487}
488
489impl From<Sender> for FrankenSender {
490    fn from(value: Sender) -> Self {
491        match value {
492            Sender::Member(i) => FrankenSender::Member(i.u32()),
493            // this cast is safe, because the index method casts it from u32 to usize for some
494            // reason, so it's known to fit u32
495            Sender::External(i) => FrankenSender::External(i.index() as u32),
496            Sender::NewMemberProposal => FrankenSender::NewMemberProposal,
497            Sender::NewMemberCommit => FrankenSender::NewMemberCommit,
498        }
499    }
500}
501
502impl From<FrankenSender> for Sender {
503    fn from(value: FrankenSender) -> Self {
504        match value {
505            FrankenSender::Member(i) => Sender::Member(LeafNodeIndex::new(i)),
506            FrankenSender::External(i) => Sender::External(SenderExtensionIndex::new(i)),
507            FrankenSender::NewMemberProposal => Sender::NewMemberProposal,
508            FrankenSender::NewMemberCommit => Sender::NewMemberCommit,
509        }
510    }
511}