// openmls/test_utils/frankenstein/codec.rs

1use std::io::{Read, Write};
2
3use tls_codec::*;
4
5use super::{
6    extensions::{
7        FrankenApplicationIdExtension, FrankenExtension, FrankenExtensionType,
8        FrankenExternalPubExtension, FrankenExternalSendersExtension, FrankenRatchetTreeExtension,
9        FrankenRequiredCapabilitiesExtension,
10    },
11    FrankenAddProposal, FrankenCustomProposal, FrankenExternalInitProposal,
12    FrankenPreSharedKeyProposal, FrankenProposal, FrankenProposalType, FrankenReInitProposal,
13    FrankenRemoveProposal, FrankenUpdateProposal,
14};
15
16#[cfg(feature = "extensions-draft-08")]
17use super::{FrankenAppDataUpdateProposal, FrankenAppEphemeralProposal};
18
/// Returns how many bytes the variable-length size prefix occupies when it
/// encodes a vector body of `length` bytes (the varint scheme used for
/// `VLBytes`, cf. RFC 9420 §2.1.2 / RFC 9000 §16).
///
/// The thresholds are *inclusive* upper bounds: a 1-byte prefix encodes up to
/// `0x3f`, a 2-byte prefix up to `0x3fff`, a 4-byte prefix up to
/// `0x3fff_ffff`; anything larger needs 8 bytes.
fn vlbytes_len_len(length: usize) -> usize {
    if length <= 0x3f {
        1
    } else if length <= 0x3fff {
        2
    } else if length <= 0x3fff_ffff {
        4
    } else {
        8
    }
}
30
31impl Size for FrankenProposalType {
32    fn tls_serialized_len(&self) -> usize {
33        2
34    }
35}
36
37impl Deserialize for FrankenProposalType {
38    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
39    where
40        Self: Sized,
41    {
42        let mut proposal_type = [0u8; 2];
43        bytes.read_exact(&mut proposal_type)?;
44
45        Ok(FrankenProposalType::from(u16::from_be_bytes(proposal_type)))
46    }
47}
48
49impl Serialize for FrankenProposalType {
50    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
51        writer.write_all(&u16::from(*self).to_be_bytes())?;
52
53        Ok(2)
54    }
55}
56
57impl DeserializeBytes for FrankenProposalType {
58    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
59    where
60        Self: Sized,
61    {
62        let mut bytes_ref = bytes;
63        let proposal_type = FrankenProposalType::tls_deserialize(&mut bytes_ref)?;
64        let remainder = &bytes[proposal_type.tls_serialized_len()..];
65        Ok((proposal_type, remainder))
66    }
67}
68
69impl Size for FrankenProposal {
70    fn tls_serialized_len(&self) -> usize {
71        self.proposal_type().tls_serialized_len()
72            + match self {
73                FrankenProposal::Add(p) => p.tls_serialized_len(),
74                FrankenProposal::Update(p) => p.tls_serialized_len(),
75                FrankenProposal::Remove(p) => p.tls_serialized_len(),
76                FrankenProposal::PreSharedKey(p) => p.tls_serialized_len(),
77                FrankenProposal::ReInit(p) => p.tls_serialized_len(),
78                FrankenProposal::ExternalInit(p) => p.tls_serialized_len(),
79                FrankenProposal::GroupContextExtensions(p) => p.tls_serialized_len(),
80                #[cfg(feature = "extensions-draft-08")]
81                FrankenProposal::AppEphemeral(p) => p.tls_serialized_len(),
82                #[cfg(feature = "extensions-draft-08")]
83                FrankenProposal::AppDataUpdate(p) => p.tls_serialized_len(),
84                FrankenProposal::Custom(p) => p.tls_serialized_len(),
85            }
86    }
87}
88
89impl Serialize for FrankenProposal {
90    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
91        let written = self.proposal_type().tls_serialize(writer)?;
92        match self {
93            FrankenProposal::Add(p) => p.tls_serialize(writer),
94            FrankenProposal::Update(p) => p.tls_serialize(writer),
95            FrankenProposal::Remove(p) => p.tls_serialize(writer),
96            FrankenProposal::PreSharedKey(p) => p.tls_serialize(writer),
97            FrankenProposal::ReInit(p) => p.tls_serialize(writer),
98            FrankenProposal::ExternalInit(p) => p.tls_serialize(writer),
99            FrankenProposal::GroupContextExtensions(p) => p.tls_serialize(writer),
100            #[cfg(feature = "extensions-draft-08")]
101            FrankenProposal::AppEphemeral(p) => p.tls_serialize(writer),
102            #[cfg(feature = "extensions-draft-08")]
103            FrankenProposal::AppDataUpdate(p) => p.tls_serialize(writer),
104            FrankenProposal::Custom(p) => p.payload.tls_serialize(writer),
105        }
106        .map(|l| written + l)
107    }
108}
109
110impl Deserialize for FrankenProposal {
111    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
112    where
113        Self: Sized,
114    {
115        let proposal_type = FrankenProposalType::tls_deserialize(bytes)?;
116        let proposal = match proposal_type {
117            FrankenProposalType::Add => {
118                FrankenProposal::Add(FrankenAddProposal::tls_deserialize(bytes)?)
119            }
120            FrankenProposalType::Update => {
121                FrankenProposal::Update(FrankenUpdateProposal::tls_deserialize(bytes)?)
122            }
123            FrankenProposalType::Remove => {
124                FrankenProposal::Remove(FrankenRemoveProposal::tls_deserialize(bytes)?)
125            }
126            FrankenProposalType::PreSharedKey => {
127                FrankenProposal::PreSharedKey(FrankenPreSharedKeyProposal::tls_deserialize(bytes)?)
128            }
129            FrankenProposalType::Reinit => {
130                FrankenProposal::ReInit(FrankenReInitProposal::tls_deserialize(bytes)?)
131            }
132            FrankenProposalType::ExternalInit => {
133                FrankenProposal::ExternalInit(FrankenExternalInitProposal::tls_deserialize(bytes)?)
134            }
135            FrankenProposalType::GroupContextExtensions => FrankenProposal::GroupContextExtensions(
136                Vec::<FrankenExtension>::tls_deserialize(bytes)?,
137            ),
138            #[cfg(feature = "extensions-draft-08")]
139            FrankenProposalType::AppEphemeral => {
140                FrankenProposal::AppEphemeral(FrankenAppEphemeralProposal::tls_deserialize(bytes)?)
141            }
142            #[cfg(feature = "extensions-draft-08")]
143            FrankenProposalType::AppDataUpdate => FrankenProposal::AppDataUpdate(
144                FrankenAppDataUpdateProposal::tls_deserialize(bytes)?,
145            ),
146            FrankenProposalType::Custom(_) => {
147                let payload = VLBytes::tls_deserialize(bytes)?;
148                let custom_proposal = FrankenCustomProposal {
149                    proposal_type: proposal_type.into(),
150                    payload,
151                };
152                FrankenProposal::Custom(custom_proposal)
153            }
154        };
155        Ok(proposal)
156    }
157}
158
159impl DeserializeBytes for FrankenProposal {
160    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error>
161    where
162        Self: Sized,
163    {
164        let mut bytes_ref = bytes;
165        let proposal = FrankenProposal::tls_deserialize(&mut bytes_ref)?;
166        let remainder = &bytes[proposal.tls_serialized_len()..];
167        Ok((proposal, remainder))
168    }
169}
170
171impl Size for FrankenExtensionType {
172    fn tls_serialized_len(&self) -> usize {
173        2
174    }
175}
176
177impl Deserialize for FrankenExtensionType {
178    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
179    where
180        Self: Sized,
181    {
182        let mut extension_type = [0u8; 2];
183        bytes.read_exact(&mut extension_type)?;
184
185        Ok(FrankenExtensionType::from(u16::from_be_bytes(
186            extension_type,
187        )))
188    }
189}
190
191impl DeserializeBytes for FrankenExtensionType {
192    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
193    where
194        Self: Sized,
195    {
196        let mut bytes_ref = bytes;
197        let extension_type = FrankenExtensionType::tls_deserialize(&mut bytes_ref)?;
198        let remainder = &bytes[extension_type.tls_serialized_len()..];
199        Ok((extension_type, remainder))
200    }
201}
202
203impl Serialize for FrankenExtensionType {
204    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
205        writer.write_all(&u16::from(*self).to_be_bytes())?;
206
207        Ok(2)
208    }
209}
210
211impl Size for FrankenExtension {
212    fn tls_serialized_len(&self) -> usize {
213        let extension_type_length = 2;
214        let extension_data_len = match self {
215            FrankenExtension::ApplicationId(e) => e.tls_serialized_len(),
216            FrankenExtension::RatchetTree(e) => e.tls_serialized_len(),
217            FrankenExtension::RequiredCapabilities(e) => e.tls_serialized_len(),
218            FrankenExtension::ExternalPub(e) => e.tls_serialized_len(),
219            FrankenExtension::ExternalSenders(e) => e.tls_serialized_len(),
220            FrankenExtension::LastResort => 0,
221            FrankenExtension::Unknown(_, e) => e.as_slice().len(),
222        };
223        let vlbytes_len_len = vlbytes_len_len(extension_data_len);
224        extension_type_length + vlbytes_len_len + extension_data_len
225    }
226}
227
228impl Serialize for FrankenExtension {
229    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
230        let written = self.extension_type().tls_serialize(writer)?;
231
232        // subtract the two bytes for the type header
233        let extension_data_len = self.tls_serialized_len() - 2;
234        let mut extension_data = Vec::with_capacity(extension_data_len);
235
236        let _ = match self {
237            FrankenExtension::ApplicationId(e) => e.tls_serialize(&mut extension_data),
238            FrankenExtension::RatchetTree(e) => e.tls_serialize(&mut extension_data),
239            FrankenExtension::RequiredCapabilities(e) => e.tls_serialize(&mut extension_data),
240            FrankenExtension::ExternalPub(e) => e.tls_serialize(&mut extension_data),
241            FrankenExtension::ExternalSenders(e) => e.tls_serialize(&mut extension_data),
242            FrankenExtension::LastResort => Ok(0),
243            FrankenExtension::Unknown(_, e) => extension_data
244                .write_all(e.as_slice())
245                .map(|_| e.as_slice().len())
246                .map_err(|_| tls_codec::Error::EndOfStream),
247        }?;
248
249        Serialize::tls_serialize(&extension_data, writer).map(|l| l + written)
250    }
251}
252
253impl Deserialize for FrankenExtension {
254    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, tls_codec::Error> {
255        // Read the extension type and extension data.
256        let extension_type = FrankenExtensionType::tls_deserialize(bytes)?;
257        let extension_data = VLBytes::tls_deserialize(bytes)?;
258
259        // Now deserialize the extension itself from the extension data.
260        let mut extension_data = extension_data.as_slice();
261        Ok(match extension_type {
262            FrankenExtensionType::ApplicationId => FrankenExtension::ApplicationId(
263                FrankenApplicationIdExtension::tls_deserialize(&mut extension_data)?,
264            ),
265            FrankenExtensionType::RatchetTree => FrankenExtension::RatchetTree(
266                FrankenRatchetTreeExtension::tls_deserialize(&mut extension_data)?,
267            ),
268            FrankenExtensionType::RequiredCapabilities => FrankenExtension::RequiredCapabilities(
269                FrankenRequiredCapabilitiesExtension::tls_deserialize(&mut extension_data)?,
270            ),
271            FrankenExtensionType::ExternalPub => FrankenExtension::ExternalPub(
272                FrankenExternalPubExtension::tls_deserialize(&mut extension_data)?,
273            ),
274            FrankenExtensionType::ExternalSenders => FrankenExtension::ExternalSenders(
275                FrankenExternalSendersExtension::tls_deserialize(&mut extension_data)?,
276            ),
277            FrankenExtensionType::LastResort => FrankenExtension::LastResort,
278            FrankenExtensionType::Unknown(unknown) => {
279                FrankenExtension::Unknown(unknown, extension_data.to_vec().into())
280            }
281        })
282    }
283}
284
285impl DeserializeBytes for FrankenExtension {
286    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error>
287    where
288        Self: Sized,
289    {
290        let mut bytes_ref = bytes;
291        let extension = FrankenExtension::tls_deserialize(&mut bytes_ref)?;
292        let remainder = &bytes[extension.tls_serialized_len()..];
293        Ok((extension, remainder))
294    }
295}