openmls/extensions/
codec.rs1use std::io::{Read, Write};
2
3use tls_codec::{Deserialize, DeserializeBytes, Serialize, Size, VLBytes};
4
5use crate::extensions::{
6 ApplicationIdExtension, Extension, ExtensionType, ExternalPubExtension,
7 ExternalSendersExtension, RatchetTreeExtension, RequiredCapabilitiesExtension,
8 UnknownExtension,
9};
10
11use super::last_resort::LastResortExtension;
12
/// Returns the number of bytes taken by the variable-length length prefix that
/// precedes a vector whose content is `length` bytes long.
///
/// The prefix uses its two most significant bits to select a 1-, 2-, 4- or
/// 8-byte header, leaving 6, 14, 30 or 62 bits for the length itself. The upper
/// bounds are therefore inclusive: `0x3f`, `0x3fff` and `0x3fff_ffff` are the
/// largest lengths that still fit in 1, 2 and 4 bytes respectively.
fn vlbytes_len_len(length: usize) -> usize {
    match length {
        0..=0x3f => 1,
        0x40..=0x3fff => 2,
        0x4000..=0x3fff_ffff => 4,
        _ => 8,
    }
}
24
25impl Size for Extension {
26 #[inline]
27 fn tls_serialized_len(&self) -> usize {
28 let extension_type_length = 2;
29
30 let extension_data_len = match self {
34 Extension::ApplicationId(e) => e.tls_serialized_len(),
35 Extension::RatchetTree(e) => e.tls_serialized_len(),
36 Extension::RequiredCapabilities(e) => e.tls_serialized_len(),
37 Extension::ExternalPub(e) => e.tls_serialized_len(),
38 Extension::ExternalSenders(e) => e.tls_serialized_len(),
39 Extension::LastResort(e) => e.tls_serialized_len(),
40 Extension::Unknown(_, e) => e.0.len(),
41 };
42
43 let vlbytes_len_len = vlbytes_len_len(extension_data_len);
44
45 extension_type_length + vlbytes_len_len + extension_data_len
46 }
47}
48
49impl Size for &Extension {
50 #[inline]
51 fn tls_serialized_len(&self) -> usize {
52 Extension::tls_serialized_len(*self)
53 }
54}
55
56impl Serialize for Extension {
57 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
58 let written = self.extension_type().tls_serialize(writer)?;
60
61 let extension_data_len = self.tls_serialized_len();
63 let mut extension_data = Vec::with_capacity(extension_data_len);
64
65 let extension_data_written = match self {
66 Extension::ApplicationId(e) => e.tls_serialize(&mut extension_data),
67 Extension::RatchetTree(e) => e.tls_serialize(&mut extension_data),
68 Extension::RequiredCapabilities(e) => e.tls_serialize(&mut extension_data),
69 Extension::ExternalPub(e) => e.tls_serialize(&mut extension_data),
70 Extension::ExternalSenders(e) => e.tls_serialize(&mut extension_data),
71 Extension::LastResort(e) => e.tls_serialize(&mut extension_data),
72 Extension::Unknown(_, e) => extension_data
73 .write_all(e.0.as_slice())
74 .map(|_| e.0.len())
75 .map_err(|_| tls_codec::Error::EndOfStream),
76 }?;
77 debug_assert_eq!(
78 extension_data_written,
79 extension_data_len - 2 - vlbytes_len_len(extension_data_written)
80 );
81 debug_assert_eq!(extension_data_written, extension_data.len());
82
83 extension_data.tls_serialize(writer).map(|l| l + written)
85 }
86}
87
88impl Serialize for &Extension {
89 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
90 Extension::tls_serialize(*self, writer)
91 }
92}
93
94impl Deserialize for Extension {
95 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, tls_codec::Error> {
96 let extension_type = ExtensionType::tls_deserialize(bytes)?;
98 let extension_data = VLBytes::tls_deserialize(bytes)?;
99
100 let mut extension_data = extension_data.as_slice();
102 Ok(match extension_type {
103 ExtensionType::ApplicationId => Extension::ApplicationId(
104 ApplicationIdExtension::tls_deserialize(&mut extension_data)?,
105 ),
106 ExtensionType::RatchetTree => {
107 Extension::RatchetTree(RatchetTreeExtension::tls_deserialize(&mut extension_data)?)
108 }
109 ExtensionType::RequiredCapabilities => Extension::RequiredCapabilities(
110 RequiredCapabilitiesExtension::tls_deserialize(&mut extension_data)?,
111 ),
112 ExtensionType::ExternalPub => {
113 Extension::ExternalPub(ExternalPubExtension::tls_deserialize(&mut extension_data)?)
114 }
115 ExtensionType::ExternalSenders => Extension::ExternalSenders(
116 ExternalSendersExtension::tls_deserialize(&mut extension_data)?,
117 ),
118 ExtensionType::LastResort => {
119 Extension::LastResort(LastResortExtension::tls_deserialize(&mut extension_data)?)
120 }
121 ExtensionType::Unknown(unknown) => {
122 Extension::Unknown(unknown, UnknownExtension(extension_data.to_vec()))
123 }
124 })
125 }
126}
127
128impl DeserializeBytes for Extension {
129 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error>
130 where
131 Self: Sized,
132 {
133 let mut bytes_ref = bytes;
134 let extension = Extension::tls_deserialize(&mut bytes_ref)?;
135 let remainder = &bytes[extension.tls_serialized_len()..];
136 Ok((extension, remainder))
137 }
138}