openmls/test_utils/frankenstein/codec.rs

1use std::io::{Read, Write};
2
3use tls_codec::*;
4
5use super::{
6    extensions::{
7        FrankenApplicationIdExtension, FrankenExtension, FrankenExtensionType,
8        FrankenExternalPubExtension, FrankenExternalSendersExtension, FrankenRatchetTreeExtension,
9        FrankenRequiredCapabilitiesExtension,
10    },
11    FrankenAddProposal, FrankenCustomProposal, FrankenExternalInitProposal,
12    FrankenPreSharedKeyProposal, FrankenProposal, FrankenProposalType, FrankenReInitProposal,
13    FrankenRemoveProposal, FrankenUpdateProposal,
14};
15
16#[cfg(feature = "extensions-draft-08")]
17use super::FrankenAppEphemeralProposal;
18
/// Returns the number of bytes used by the variable-length integer that
/// prefixes a `VLBytes` vector whose content is `length` bytes long.
///
/// The prefix is the QUIC-style varint used by MLS (RFC 9420 §2.1.2):
/// 1 byte for lengths up to `0x3f`, 2 bytes up to `0x3fff`, 4 bytes up to
/// `0x3fff_ffff`, and the 8-byte form beyond that. All upper bounds are
/// inclusive, e.g. a length of exactly `0x3fff` still fits the 2-byte form.
fn vlbytes_len_len(length: usize) -> usize {
    match length {
        0..=0x3f => 1,
        0x40..=0x3fff => 2,
        0x4000..=0x3fff_ffff => 4,
        // NOTE(review): MLS rejects the 8-byte form; it is returned here only
        // so the function is total over `usize`.
        _ => 8,
    }
}
30
31impl Size for FrankenProposalType {
32    fn tls_serialized_len(&self) -> usize {
33        2
34    }
35}
36
37impl Deserialize for FrankenProposalType {
38    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
39    where
40        Self: Sized,
41    {
42        let mut proposal_type = [0u8; 2];
43        bytes.read_exact(&mut proposal_type)?;
44
45        Ok(FrankenProposalType::from(u16::from_be_bytes(proposal_type)))
46    }
47}
48
49impl Serialize for FrankenProposalType {
50    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
51        writer.write_all(&u16::from(*self).to_be_bytes())?;
52
53        Ok(2)
54    }
55}
56
57impl DeserializeBytes for FrankenProposalType {
58    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
59    where
60        Self: Sized,
61    {
62        let mut bytes_ref = bytes;
63        let proposal_type = FrankenProposalType::tls_deserialize(&mut bytes_ref)?;
64        let remainder = &bytes[proposal_type.tls_serialized_len()..];
65        Ok((proposal_type, remainder))
66    }
67}
68
69impl Size for FrankenProposal {
70    fn tls_serialized_len(&self) -> usize {
71        self.proposal_type().tls_serialized_len()
72            + match self {
73                FrankenProposal::Add(p) => p.tls_serialized_len(),
74                FrankenProposal::Update(p) => p.tls_serialized_len(),
75                FrankenProposal::Remove(p) => p.tls_serialized_len(),
76                FrankenProposal::PreSharedKey(p) => p.tls_serialized_len(),
77                FrankenProposal::ReInit(p) => p.tls_serialized_len(),
78                FrankenProposal::ExternalInit(p) => p.tls_serialized_len(),
79                FrankenProposal::GroupContextExtensions(p) => p.tls_serialized_len(),
80                #[cfg(feature = "extensions-draft-08")]
81                FrankenProposal::AppEphemeral(p) => p.tls_serialized_len(),
82                FrankenProposal::Custom(p) => p.tls_serialized_len(),
83            }
84    }
85}
86
87impl Serialize for FrankenProposal {
88    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
89        let written = self.proposal_type().tls_serialize(writer)?;
90        match self {
91            FrankenProposal::Add(p) => p.tls_serialize(writer),
92            FrankenProposal::Update(p) => p.tls_serialize(writer),
93            FrankenProposal::Remove(p) => p.tls_serialize(writer),
94            FrankenProposal::PreSharedKey(p) => p.tls_serialize(writer),
95            FrankenProposal::ReInit(p) => p.tls_serialize(writer),
96            FrankenProposal::ExternalInit(p) => p.tls_serialize(writer),
97            FrankenProposal::GroupContextExtensions(p) => p.tls_serialize(writer),
98            #[cfg(feature = "extensions-draft-08")]
99            FrankenProposal::AppEphemeral(p) => p.tls_serialize(writer),
100            FrankenProposal::Custom(p) => p.payload.tls_serialize(writer),
101        }
102        .map(|l| written + l)
103    }
104}
105
106impl Deserialize for FrankenProposal {
107    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
108    where
109        Self: Sized,
110    {
111        let proposal_type = FrankenProposalType::tls_deserialize(bytes)?;
112        let proposal = match proposal_type {
113            FrankenProposalType::Add => {
114                FrankenProposal::Add(FrankenAddProposal::tls_deserialize(bytes)?)
115            }
116            FrankenProposalType::Update => {
117                FrankenProposal::Update(FrankenUpdateProposal::tls_deserialize(bytes)?)
118            }
119            FrankenProposalType::Remove => {
120                FrankenProposal::Remove(FrankenRemoveProposal::tls_deserialize(bytes)?)
121            }
122            FrankenProposalType::PreSharedKey => {
123                FrankenProposal::PreSharedKey(FrankenPreSharedKeyProposal::tls_deserialize(bytes)?)
124            }
125            FrankenProposalType::Reinit => {
126                FrankenProposal::ReInit(FrankenReInitProposal::tls_deserialize(bytes)?)
127            }
128            FrankenProposalType::ExternalInit => {
129                FrankenProposal::ExternalInit(FrankenExternalInitProposal::tls_deserialize(bytes)?)
130            }
131            FrankenProposalType::GroupContextExtensions => FrankenProposal::GroupContextExtensions(
132                Vec::<FrankenExtension>::tls_deserialize(bytes)?,
133            ),
134            #[cfg(feature = "extensions-draft-08")]
135            FrankenProposalType::AppEphemeral => {
136                FrankenProposal::AppEphemeral(FrankenAppEphemeralProposal::tls_deserialize(bytes)?)
137            }
138            FrankenProposalType::Custom(_) => {
139                let payload = VLBytes::tls_deserialize(bytes)?;
140                let custom_proposal = FrankenCustomProposal {
141                    proposal_type: proposal_type.into(),
142                    payload,
143                };
144                FrankenProposal::Custom(custom_proposal)
145            }
146        };
147        Ok(proposal)
148    }
149}
150
151impl DeserializeBytes for FrankenProposal {
152    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error>
153    where
154        Self: Sized,
155    {
156        let mut bytes_ref = bytes;
157        let proposal = FrankenProposal::tls_deserialize(&mut bytes_ref)?;
158        let remainder = &bytes[proposal.tls_serialized_len()..];
159        Ok((proposal, remainder))
160    }
161}
162
163impl Size for FrankenExtensionType {
164    fn tls_serialized_len(&self) -> usize {
165        2
166    }
167}
168
169impl Deserialize for FrankenExtensionType {
170    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
171    where
172        Self: Sized,
173    {
174        let mut extension_type = [0u8; 2];
175        bytes.read_exact(&mut extension_type)?;
176
177        Ok(FrankenExtensionType::from(u16::from_be_bytes(
178            extension_type,
179        )))
180    }
181}
182
183impl DeserializeBytes for FrankenExtensionType {
184    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
185    where
186        Self: Sized,
187    {
188        let mut bytes_ref = bytes;
189        let extension_type = FrankenExtensionType::tls_deserialize(&mut bytes_ref)?;
190        let remainder = &bytes[extension_type.tls_serialized_len()..];
191        Ok((extension_type, remainder))
192    }
193}
194
195impl Serialize for FrankenExtensionType {
196    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
197        writer.write_all(&u16::from(*self).to_be_bytes())?;
198
199        Ok(2)
200    }
201}
202
203impl Size for FrankenExtension {
204    fn tls_serialized_len(&self) -> usize {
205        let extension_type_length = 2;
206        let extension_data_len = match self {
207            FrankenExtension::ApplicationId(e) => e.tls_serialized_len(),
208            FrankenExtension::RatchetTree(e) => e.tls_serialized_len(),
209            FrankenExtension::RequiredCapabilities(e) => e.tls_serialized_len(),
210            FrankenExtension::ExternalPub(e) => e.tls_serialized_len(),
211            FrankenExtension::ExternalSenders(e) => e.tls_serialized_len(),
212            FrankenExtension::LastResort => 0,
213            FrankenExtension::Unknown(_, e) => e.tls_serialized_len(),
214        };
215        let vlbytes_len_len = vlbytes_len_len(extension_data_len);
216        extension_type_length + vlbytes_len_len + extension_data_len
217    }
218}
219
220impl Serialize for FrankenExtension {
221    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
222        let written = self.extension_type().tls_serialize(writer)?;
223
224        // subtract the two bytes for the type header
225        let extension_data_len = self.tls_serialized_len() - 2;
226        let mut extension_data = Vec::with_capacity(extension_data_len);
227
228        let _ = match self {
229            FrankenExtension::ApplicationId(e) => e.tls_serialize(&mut extension_data),
230            FrankenExtension::RatchetTree(e) => e.tls_serialize(&mut extension_data),
231            FrankenExtension::RequiredCapabilities(e) => e.tls_serialize(&mut extension_data),
232            FrankenExtension::ExternalPub(e) => e.tls_serialize(&mut extension_data),
233            FrankenExtension::ExternalSenders(e) => e.tls_serialize(&mut extension_data),
234            FrankenExtension::LastResort => Ok(0),
235            FrankenExtension::Unknown(_, e) => extension_data
236                .write_all(e.as_slice())
237                .map(|_| e.tls_serialized_len())
238                .map_err(|_| tls_codec::Error::EndOfStream),
239        }?;
240
241        Serialize::tls_serialize(&extension_data, writer).map(|l| l + written)
242    }
243}
244
245impl Deserialize for FrankenExtension {
246    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, tls_codec::Error> {
247        // Read the extension type and extension data.
248        let extension_type = FrankenExtensionType::tls_deserialize(bytes)?;
249        let extension_data = VLBytes::tls_deserialize(bytes)?;
250
251        // Now deserialize the extension itself from the extension data.
252        let mut extension_data = extension_data.as_slice();
253        Ok(match extension_type {
254            FrankenExtensionType::ApplicationId => FrankenExtension::ApplicationId(
255                FrankenApplicationIdExtension::tls_deserialize(&mut extension_data)?,
256            ),
257            FrankenExtensionType::RatchetTree => FrankenExtension::RatchetTree(
258                FrankenRatchetTreeExtension::tls_deserialize(&mut extension_data)?,
259            ),
260            FrankenExtensionType::RequiredCapabilities => FrankenExtension::RequiredCapabilities(
261                FrankenRequiredCapabilitiesExtension::tls_deserialize(&mut extension_data)?,
262            ),
263            FrankenExtensionType::ExternalPub => FrankenExtension::ExternalPub(
264                FrankenExternalPubExtension::tls_deserialize(&mut extension_data)?,
265            ),
266            FrankenExtensionType::ExternalSenders => FrankenExtension::ExternalSenders(
267                FrankenExternalSendersExtension::tls_deserialize(&mut extension_data)?,
268            ),
269            FrankenExtensionType::LastResort => FrankenExtension::LastResort,
270            FrankenExtensionType::Unknown(unknown) => {
271                FrankenExtension::Unknown(unknown, extension_data.to_vec().into())
272            }
273        })
274    }
275}
276
277impl DeserializeBytes for FrankenExtension {
278    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error>
279    where
280        Self: Sized,
281    {
282        let mut bytes_ref = bytes;
283        let extension = FrankenExtension::tls_deserialize(&mut bytes_ref)?;
284        let remainder = &bytes[extension.tls_serialized_len()..];
285        Ok((extension, remainder))
286    }
287}