Skip to main content

openmls/extensions/
codec.rs

1use std::io::{Read, Write};
2
3use tls_codec::{Deserialize, DeserializeBytes, Serialize, Size, VLBytes};
4
5use crate::extensions::{
6    ApplicationIdExtension, Extension, ExtensionType, ExternalPubExtension,
7    ExternalSendersExtension, RatchetTreeExtension, RequiredCapabilitiesExtension,
8    UnknownExtension,
9};
10
11#[cfg(feature = "extensions-draft-08")]
12use crate::extensions::AppDataDictionaryExtension;
13
14use super::last_resort::LastResortExtension;
15
/// Returns the number of bytes taken by the variable-length header that
/// precedes a `VLBytes` payload of `length` bytes.
///
/// The header is a variable-length integer: payloads of up to `0x3f` bytes use
/// a 1-byte header, up to `0x3fff` bytes a 2-byte header, up to `0x3fff_ffff`
/// bytes a 4-byte header, and anything larger an 8-byte header.
///
/// The upper bounds are *inclusive*: `0x3fff` is the largest value that fits in
/// the 14 value bits of a 2-byte header, and `0x3fff_ffff` the largest that fits
/// in the 30 value bits of a 4-byte header. Using `<` here would over-report the
/// header size for payloads of exactly those lengths.
///
/// Lengths too large to be encoded at all are not rejected here; that is left
/// to the actual (de)serialization.
fn vlbytes_len_len(length: usize) -> usize {
    if length <= 0x3f {
        1
    } else if length <= 0x3fff {
        2
    } else if length <= 0x3fff_ffff {
        4
    } else {
        8
    }
}
27
28impl Size for Extension {
29    #[inline]
30    fn tls_serialized_len(&self) -> usize {
31        let extension_type_length = 2;
32
33        // We truncate here and don't catch errors for anything that's
34        // too long.
35        // This will be caught when (de)serializing.
36        let extension_data_len = match self {
37            Extension::ApplicationId(e) => e.tls_serialized_len(),
38            Extension::RatchetTree(e) => e.tls_serialized_len(),
39            Extension::RequiredCapabilities(e) => e.tls_serialized_len(),
40            Extension::ExternalPub(e) => e.tls_serialized_len(),
41            Extension::ExternalSenders(e) => e.tls_serialized_len(),
42            Extension::LastResort(e) => e.tls_serialized_len(),
43            #[cfg(feature = "extensions-draft-08")]
44            Extension::AppDataDictionary(e) => e.tls_serialized_len(),
45            Extension::Unknown(_, e) => e.0.len(),
46        };
47
48        let vlbytes_len_len = vlbytes_len_len(extension_data_len);
49
50        extension_type_length + vlbytes_len_len + extension_data_len
51    }
52}
53
54impl Size for &Extension {
55    #[inline]
56    fn tls_serialized_len(&self) -> usize {
57        Extension::tls_serialized_len(*self)
58    }
59}
60
61impl Serialize for Extension {
62    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
63        // First write the extension type.
64        let written = self.extension_type().tls_serialize(writer)?;
65
66        // Now serialize the extension into a separate byte vector.
67        let extension_data_len = self.tls_serialized_len();
68        let mut extension_data = Vec::with_capacity(extension_data_len);
69
70        let extension_data_written = match self {
71            Extension::ApplicationId(e) => e.tls_serialize(&mut extension_data),
72            Extension::RatchetTree(e) => e.tls_serialize(&mut extension_data),
73            Extension::RequiredCapabilities(e) => e.tls_serialize(&mut extension_data),
74            Extension::ExternalPub(e) => e.tls_serialize(&mut extension_data),
75            Extension::ExternalSenders(e) => e.tls_serialize(&mut extension_data),
76            #[cfg(feature = "extensions-draft-08")]
77            Extension::AppDataDictionary(e) => e.tls_serialize(&mut extension_data),
78            Extension::LastResort(e) => e.tls_serialize(&mut extension_data),
79            Extension::Unknown(_, e) => extension_data
80                .write_all(e.0.as_slice())
81                .map(|_| e.0.len())
82                .map_err(|_| tls_codec::Error::EndOfStream),
83        }?;
84        debug_assert_eq!(
85            extension_data_written,
86            extension_data_len - 2 - vlbytes_len_len(extension_data_written)
87        );
88        debug_assert_eq!(extension_data_written, extension_data.len());
89
90        // Write the serialized extension out.
91        extension_data.tls_serialize(writer).map(|l| l + written)
92    }
93}
94
95impl Serialize for &Extension {
96    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
97        Extension::tls_serialize(*self, writer)
98    }
99}
100
101impl Deserialize for Extension {
102    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, tls_codec::Error> {
103        // Read the extension type and extension data.
104        let extension_type = ExtensionType::tls_deserialize(bytes)?;
105        let extension_data = VLBytes::tls_deserialize(bytes)?;
106
107        // Now deserialize the extension itself from the extension data.
108        let mut extension_data = extension_data.as_slice();
109        Ok(match extension_type {
110            ExtensionType::ApplicationId => Extension::ApplicationId(
111                ApplicationIdExtension::tls_deserialize(&mut extension_data)?,
112            ),
113            ExtensionType::RatchetTree => {
114                Extension::RatchetTree(RatchetTreeExtension::tls_deserialize(&mut extension_data)?)
115            }
116            ExtensionType::RequiredCapabilities => Extension::RequiredCapabilities(
117                RequiredCapabilitiesExtension::tls_deserialize(&mut extension_data)?,
118            ),
119            ExtensionType::ExternalPub => {
120                Extension::ExternalPub(ExternalPubExtension::tls_deserialize(&mut extension_data)?)
121            }
122            ExtensionType::ExternalSenders => Extension::ExternalSenders(
123                ExternalSendersExtension::tls_deserialize(&mut extension_data)?,
124            ),
125            #[cfg(feature = "extensions-draft-08")]
126            ExtensionType::AppDataDictionary => Extension::AppDataDictionary(
127                AppDataDictionaryExtension::tls_deserialize(&mut extension_data)?,
128            ),
129            ExtensionType::LastResort => {
130                Extension::LastResort(LastResortExtension::tls_deserialize(&mut extension_data)?)
131            }
132            ExtensionType::Grease(grease) | ExtensionType::Unknown(grease) => {
133                Extension::Unknown(grease, UnknownExtension(extension_data.to_vec()))
134            }
135        })
136    }
137}
138
139impl DeserializeBytes for Extension {
140    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error>
141    where
142        Self: Sized,
143    {
144        let mut bytes_ref = bytes;
145        let extension = Extension::tls_deserialize(&mut bytes_ref)?;
146        let remainder = &bytes[extension.tls_serialized_len()..];
147        Ok((extension, remainder))
148    }
149}