openmls/extensions/
app_data_dict_extension.rs1use crate::component::{ComponentData, ComponentId};
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes};
5
6#[derive(thiserror::Error, Debug)]
7enum BuildAppDataDictionaryError {
8 #[error("entries not in order")]
9 EntriesNotInOrder,
10 #[error("duplicate entries")]
11 DuplicateEntries,
12}
13
14#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)]
20pub struct AppDataDictionary {
21 component_data: BTreeMap<ComponentId, ComponentData>,
27}
28
29impl AppDataDictionary {
30 pub fn new() -> Self {
32 Self {
33 component_data: BTreeMap::new(),
34 }
35 }
36 pub fn entries(&self) -> impl Iterator<Item = &ComponentData> {
39 self.component_data.values()
40 }
41
42 pub fn to_entries(self) -> Vec<ComponentData> {
44 self.entries().cloned().collect()
45 }
46
47 pub fn len(&self) -> usize {
49 self.component_data.len()
50 }
51
52 pub fn is_empty(&self) -> bool {
54 self.component_data.is_empty()
55 }
56
57 pub fn get(&self, component_id: &ComponentId) -> Option<&[u8]> {
59 self.component_data
60 .get(component_id)
61 .map(|component_data| component_data.data())
62 }
63
64 pub fn insert(&mut self, component_id: ComponentId, data: Vec<u8>) -> Option<VLBytes> {
67 self.component_data
68 .insert(
69 component_id,
70 ComponentData::from_parts(component_id, data.into()),
71 )
72 .map(|component_data| component_data.into_data())
73 }
74
75 pub fn contains(&self, component_id: &ComponentId) -> bool {
77 self.component_data.contains_key(component_id)
78 }
79
80 pub fn remove(&mut self, component_id: &ComponentId) -> Option<VLBytes> {
83 self.component_data
84 .remove(component_id)
85 .map(|component_data| component_data.into_data())
86 }
87
88 fn try_from_data(
93 data: impl IntoIterator<Item = ComponentData>,
94 ) -> Result<Self, BuildAppDataDictionaryError> {
95 let mut map = BTreeMap::<ComponentId, ComponentData>::new();
96
97 for component_data in data {
98 let (component_id, data) = component_data.into_parts();
99 if map.contains_key(&component_id) {
101 return Err(BuildAppDataDictionaryError::DuplicateEntries);
102 }
103
104 if let Some((max, _)) = map.last_key_value() {
107 if *max > component_id {
108 return Err(BuildAppDataDictionaryError::EntriesNotInOrder);
109 }
110 }
111 let _ = map.insert(component_id, ComponentData::from_parts(component_id, data));
113 }
114
115 Ok(Self {
116 component_data: map,
117 })
118 }
119}
120
121impl tls_codec::Size for AppDataDictionary {
122 fn tls_serialized_len(&self) -> usize {
123 let data: Vec<&ComponentData> = self.entries().collect();
125 data.tls_serialized_len()
126 }
127}
128
129impl tls_codec::Serialize for AppDataDictionary {
130 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
131 let data: Vec<&ComponentData> = self.entries().collect();
133 data.tls_serialize(writer)
134 }
135}
136
137impl tls_codec::Deserialize for AppDataDictionary {
138 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error> {
143 let data = Vec::<ComponentData>::tls_deserialize(bytes)?;
145
146 AppDataDictionary::try_from_data(data)
148 .map_err(|e| tls_codec::Error::DecodingError(e.to_string()))
149 }
150}
151
152impl tls_codec::DeserializeBytes for AppDataDictionary {
153 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error> {
154 use tls_codec::{Deserialize, Size};
155 let mut bytes_ref = bytes;
156 let dictionary = Self::tls_deserialize(&mut bytes_ref)?;
157
158 let remainder = &bytes[dictionary.tls_serialized_len()..];
159
160 Ok((dictionary, remainder))
161 }
162}
163
164#[derive(
168 PartialEq,
169 Eq,
170 Clone,
171 Debug,
172 Default,
173 Serialize,
174 Deserialize,
175 TlsSerialize,
176 TlsDeserialize,
177 TlsDeserializeBytes,
178 TlsSize,
179)]
180pub struct AppDataDictionaryExtension {
181 dictionary: AppDataDictionary,
182}
183
184impl AppDataDictionaryExtension {
185 pub fn dictionary(&self) -> &AppDataDictionary {
187 &self.dictionary
188 }
189 pub fn new(dictionary: AppDataDictionary) -> Self {
191 Self { dictionary }
192 }
193}
194
#[cfg(test)]
mod test {
    use super::*;
    use tls_codec::{Deserialize, Serialize};

    /// Wraps a copy of `dictionary` in an extension, encodes it, decodes the
    /// bytes again and asserts that the decoded extension equals the original.
    fn assert_extension_round_trips(dictionary: &AppDataDictionary) {
        let extension_orig = AppDataDictionaryExtension::new(dictionary.clone());

        let bytes = extension_orig.tls_serialize_detached().unwrap();
        let extension_deserialized =
            AppDataDictionaryExtension::tls_deserialize(&mut bytes.as_slice()).unwrap();

        assert_eq!(extension_orig, extension_deserialized);
    }

    /// Encodes `entries` as a plain vector and asserts that decoding the bytes
    /// as a dictionary fails with a decoding error carrying `expected`'s message.
    fn assert_rejected_with(entries: Vec<ComponentData>, expected: BuildAppDataDictionaryError) {
        let serialized = entries.tls_serialize_detached().unwrap();
        let err = AppDataDictionary::tls_deserialize_exact(serialized).unwrap_err();

        assert_eq!(err, tls_codec::Error::DecodingError(expected.to_string()));
    }

    #[openmls_test::openmls_test]
    fn test_serialize_deserialize() {
        // Inserting the same id twice must leave a single entry.
        let mut dictionary = AppDataDictionary::new();
        let _ = dictionary.insert(0, vec![]);
        let _ = dictionary.insert(0, vec![1, 2, 3]);
        assert_eq!(dictionary.len(), 1);

        // Ids inserted out of order still produce a dictionary that round-trips.
        let mut dictionary_orig = AppDataDictionary::new();
        let _ = dictionary_orig.insert(5, vec![]);
        let _ = dictionary_orig.insert(0, vec![1, 2, 3]);
        assert_eq!(dictionary_orig.len(), 2);

        assert_extension_round_trips(&dictionary_orig);
    }

    #[openmls_test::openmls_test]
    fn test_serialization_empty() {
        let dictionary_orig = AppDataDictionary::new();
        assert_eq!(dictionary_orig.len(), 0);

        assert_extension_round_trips(&dictionary_orig);
    }

    #[openmls_test::openmls_test]
    fn test_serialization_invalid() {
        // The id 5 appears twice.
        assert_rejected_with(
            vec![
                ComponentData::from_parts(5, vec![].into()),
                ComponentData::from_parts(5, vec![1, 2, 3].into()),
                ComponentData::from_parts(9, vec![].into()),
            ],
            BuildAppDataDictionaryError::DuplicateEntries,
        );

        // The id 4 comes after the larger id 9.
        assert_rejected_with(
            vec![
                ComponentData::from_parts(5, vec![].into()),
                ComponentData::from_parts(9, vec![].into()),
                ComponentData::from_parts(4, vec![1, 2, 3].into()),
            ],
            BuildAppDataDictionaryError::EntriesNotInOrder,
        );
    }
}