openmls/extensions/
codec.rs1use std::io::{Read, Write};
2
3use tls_codec::{Deserialize, DeserializeBytes, Serialize, Size, VLBytes};
4
5use crate::extensions::{
6 ApplicationIdExtension, Extension, ExtensionType, ExternalPubExtension,
7 ExternalSendersExtension, RatchetTreeExtension, RequiredCapabilitiesExtension,
8 UnknownExtension,
9};
10
11#[cfg(feature = "extensions-draft-08")]
12use crate::extensions::AppDataDictionaryExtension;
13
14use super::last_resort::LastResortExtension;
15
/// Returns how many bytes the variable-length length prefix occupies for a
/// payload of `length` bytes.
///
/// The prefix uses the variable-length integer layout in which the two most
/// significant bits of the first byte select the width, leaving 6, 14 or 30
/// bits for the value. The boundaries are therefore *inclusive*: `0x3f`,
/// `0x3fff` and `0x3fff_ffff` are the largest values that still fit in 1, 2
/// and 4 bytes respectively.
///
/// Anything larger is reported as needing 8 bytes.
fn vlbytes_len_len(length: usize) -> usize {
    match length {
        // 6 usable bits.
        0..=0x3f => 1,
        // 14 usable bits.
        0x40..=0x3fff => 2,
        // 30 usable bits.
        0x4000..=0x3fff_ffff => 4,
        _ => 8,
    }
}
27
28impl Size for Extension {
29 #[inline]
30 fn tls_serialized_len(&self) -> usize {
31 let extension_type_length = 2;
32
33 let extension_data_len = match self {
37 Extension::ApplicationId(e) => e.tls_serialized_len(),
38 Extension::RatchetTree(e) => e.tls_serialized_len(),
39 Extension::RequiredCapabilities(e) => e.tls_serialized_len(),
40 Extension::ExternalPub(e) => e.tls_serialized_len(),
41 Extension::ExternalSenders(e) => e.tls_serialized_len(),
42 Extension::LastResort(e) => e.tls_serialized_len(),
43 #[cfg(feature = "extensions-draft-08")]
44 Extension::AppDataDictionary(e) => e.tls_serialized_len(),
45 Extension::Unknown(_, e) => e.0.len(),
46 };
47
48 let vlbytes_len_len = vlbytes_len_len(extension_data_len);
49
50 extension_type_length + vlbytes_len_len + extension_data_len
51 }
52}
53
54impl Size for &Extension {
55 #[inline]
56 fn tls_serialized_len(&self) -> usize {
57 Extension::tls_serialized_len(*self)
58 }
59}
60
61impl Serialize for Extension {
62 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
63 let written = self.extension_type().tls_serialize(writer)?;
65
66 let extension_data_len = self.tls_serialized_len();
68 let mut extension_data = Vec::with_capacity(extension_data_len);
69
70 let extension_data_written = match self {
71 Extension::ApplicationId(e) => e.tls_serialize(&mut extension_data),
72 Extension::RatchetTree(e) => e.tls_serialize(&mut extension_data),
73 Extension::RequiredCapabilities(e) => e.tls_serialize(&mut extension_data),
74 Extension::ExternalPub(e) => e.tls_serialize(&mut extension_data),
75 Extension::ExternalSenders(e) => e.tls_serialize(&mut extension_data),
76 #[cfg(feature = "extensions-draft-08")]
77 Extension::AppDataDictionary(e) => e.tls_serialize(&mut extension_data),
78 Extension::LastResort(e) => e.tls_serialize(&mut extension_data),
79 Extension::Unknown(_, e) => extension_data
80 .write_all(e.0.as_slice())
81 .map(|_| e.0.len())
82 .map_err(|_| tls_codec::Error::EndOfStream),
83 }?;
84 debug_assert_eq!(
85 extension_data_written,
86 extension_data_len - 2 - vlbytes_len_len(extension_data_written)
87 );
88 debug_assert_eq!(extension_data_written, extension_data.len());
89
90 extension_data.tls_serialize(writer).map(|l| l + written)
92 }
93}
94
95impl Serialize for &Extension {
96 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
97 Extension::tls_serialize(*self, writer)
98 }
99}
100
101impl Deserialize for Extension {
102 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, tls_codec::Error> {
103 let extension_type = ExtensionType::tls_deserialize(bytes)?;
105 let extension_data = VLBytes::tls_deserialize(bytes)?;
106
107 let mut extension_data = extension_data.as_slice();
109 Ok(match extension_type {
110 ExtensionType::ApplicationId => Extension::ApplicationId(
111 ApplicationIdExtension::tls_deserialize(&mut extension_data)?,
112 ),
113 ExtensionType::RatchetTree => {
114 Extension::RatchetTree(RatchetTreeExtension::tls_deserialize(&mut extension_data)?)
115 }
116 ExtensionType::RequiredCapabilities => Extension::RequiredCapabilities(
117 RequiredCapabilitiesExtension::tls_deserialize(&mut extension_data)?,
118 ),
119 ExtensionType::ExternalPub => {
120 Extension::ExternalPub(ExternalPubExtension::tls_deserialize(&mut extension_data)?)
121 }
122 ExtensionType::ExternalSenders => Extension::ExternalSenders(
123 ExternalSendersExtension::tls_deserialize(&mut extension_data)?,
124 ),
125 #[cfg(feature = "extensions-draft-08")]
126 ExtensionType::AppDataDictionary => Extension::AppDataDictionary(
127 AppDataDictionaryExtension::tls_deserialize(&mut extension_data)?,
128 ),
129 ExtensionType::LastResort => {
130 Extension::LastResort(LastResortExtension::tls_deserialize(&mut extension_data)?)
131 }
132 ExtensionType::Grease(grease) | ExtensionType::Unknown(grease) => {
133 Extension::Unknown(grease, UnknownExtension(extension_data.to_vec()))
134 }
135 })
136 }
137}
138
139impl DeserializeBytes for Extension {
140 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error>
141 where
142 Self: Sized,
143 {
144 let mut bytes_ref = bytes;
145 let extension = Extension::tls_deserialize(&mut bytes_ref)?;
146 let remainder = &bytes[extension.tls_serialized_len()..];
147 Ok((extension, remainder))
148 }
149}