openmls/extensions/codec.rs

1use std::io::{Read, Write};
2
3use tls_codec::{Deserialize, DeserializeBytes, Serialize, Size, VLBytes};
4
5use crate::extensions::{
6    ApplicationIdExtension, Extension, ExtensionType, ExternalPubExtension,
7    ExternalSendersExtension, RatchetTreeExtension, RequiredCapabilitiesExtension,
8    UnknownExtension,
9};
10
11use super::last_resort::LastResortExtension;
12
/// Returns the number of bytes taken by the variable-length size prefix that
/// precedes a byte vector holding `length` bytes.
///
/// The prefix is the variable-length integer encoding used for MLS vectors
/// (RFC 9420, Section 2.1.2, derived from RFC 9000, Section 16). The top two
/// bits of the first byte select a 1-, 2-, 4- or 8-byte form, leaving 6, 14,
/// 30 or 62 bits for the value. Each width's upper bound is therefore
/// *inclusive*: `0x3f`, `0x3fff` and `0x3fff_ffff`.
///
/// Lengths beyond the 4-byte range report 8. Per the callers' comments, such
/// over-long values are expected to be rejected when (de)serializing rather
/// than here.
fn vlbytes_len_len(length: usize) -> usize {
    match length {
        0..=0x3f => 1,
        // 2-byte form: 14 value bits, so 0x3fff itself still fits.
        0x40..=0x3fff => 2,
        // 4-byte form: 30 value bits, so 0x3fff_ffff itself still fits.
        0x4000..=0x3fff_ffff => 4,
        _ => 8,
    }
}
24
25impl Size for Extension {
26    #[inline]
27    fn tls_serialized_len(&self) -> usize {
28        let extension_type_length = 2;
29
30        // We truncate here and don't catch errors for anything that's
31        // too long.
32        // This will be caught when (de)serializing.
33        let extension_data_len = match self {
34            Extension::ApplicationId(e) => e.tls_serialized_len(),
35            Extension::RatchetTree(e) => e.tls_serialized_len(),
36            Extension::RequiredCapabilities(e) => e.tls_serialized_len(),
37            Extension::ExternalPub(e) => e.tls_serialized_len(),
38            Extension::ExternalSenders(e) => e.tls_serialized_len(),
39            Extension::LastResort(e) => e.tls_serialized_len(),
40            Extension::Unknown(_, e) => e.0.len(),
41        };
42
43        let vlbytes_len_len = vlbytes_len_len(extension_data_len);
44
45        extension_type_length + vlbytes_len_len + extension_data_len
46    }
47}
48
49impl Size for &Extension {
50    #[inline]
51    fn tls_serialized_len(&self) -> usize {
52        Extension::tls_serialized_len(*self)
53    }
54}
55
56impl Serialize for Extension {
57    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
58        // First write the extension type.
59        let written = self.extension_type().tls_serialize(writer)?;
60
61        // Now serialize the extension into a separate byte vector.
62        let extension_data_len = self.tls_serialized_len();
63        let mut extension_data = Vec::with_capacity(extension_data_len);
64
65        let extension_data_written = match self {
66            Extension::ApplicationId(e) => e.tls_serialize(&mut extension_data),
67            Extension::RatchetTree(e) => e.tls_serialize(&mut extension_data),
68            Extension::RequiredCapabilities(e) => e.tls_serialize(&mut extension_data),
69            Extension::ExternalPub(e) => e.tls_serialize(&mut extension_data),
70            Extension::ExternalSenders(e) => e.tls_serialize(&mut extension_data),
71            Extension::LastResort(e) => e.tls_serialize(&mut extension_data),
72            Extension::Unknown(_, e) => extension_data
73                .write_all(e.0.as_slice())
74                .map(|_| e.0.len())
75                .map_err(|_| tls_codec::Error::EndOfStream),
76        }?;
77        debug_assert_eq!(
78            extension_data_written,
79            extension_data_len - 2 - vlbytes_len_len(extension_data_written)
80        );
81        debug_assert_eq!(extension_data_written, extension_data.len());
82
83        // Write the serialized extension out.
84        extension_data.tls_serialize(writer).map(|l| l + written)
85    }
86}
87
88impl Serialize for &Extension {
89    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
90        Extension::tls_serialize(*self, writer)
91    }
92}
93
94impl Deserialize for Extension {
95    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, tls_codec::Error> {
96        // Read the extension type and extension data.
97        let extension_type = ExtensionType::tls_deserialize(bytes)?;
98        let extension_data = VLBytes::tls_deserialize(bytes)?;
99
100        // Now deserialize the extension itself from the extension data.
101        let mut extension_data = extension_data.as_slice();
102        Ok(match extension_type {
103            ExtensionType::ApplicationId => Extension::ApplicationId(
104                ApplicationIdExtension::tls_deserialize(&mut extension_data)?,
105            ),
106            ExtensionType::RatchetTree => {
107                Extension::RatchetTree(RatchetTreeExtension::tls_deserialize(&mut extension_data)?)
108            }
109            ExtensionType::RequiredCapabilities => Extension::RequiredCapabilities(
110                RequiredCapabilitiesExtension::tls_deserialize(&mut extension_data)?,
111            ),
112            ExtensionType::ExternalPub => {
113                Extension::ExternalPub(ExternalPubExtension::tls_deserialize(&mut extension_data)?)
114            }
115            ExtensionType::ExternalSenders => Extension::ExternalSenders(
116                ExternalSendersExtension::tls_deserialize(&mut extension_data)?,
117            ),
118            ExtensionType::LastResort => {
119                Extension::LastResort(LastResortExtension::tls_deserialize(&mut extension_data)?)
120            }
121            ExtensionType::Unknown(unknown) => {
122                Extension::Unknown(unknown, UnknownExtension(extension_data.to_vec()))
123            }
124        })
125    }
126}
127
128impl DeserializeBytes for Extension {
129    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error>
130    where
131        Self: Sized,
132    {
133        let mut bytes_ref = bytes;
134        let extension = Extension::tls_deserialize(&mut bytes_ref)?;
135        let remainder = &bytes[extension.tls_serialized_len()..];
136        Ok((extension, remainder))
137    }
138}