// openmls/extensions/app_data_dict_extension.rs

1use crate::component::{ComponentData, ComponentId};
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes};
5
6#[derive(thiserror::Error, Debug)]
7enum BuildAppDataDictionaryError {
8    #[error("entries not in order")]
9    EntriesNotInOrder,
10    #[error("duplicate entries")]
11    DuplicateEntries,
12}
13
14/// Serializable app data dictionary in the [`AppDataDictionaryExtension`].
15///
16/// This struct contains a list of [`ComponentData`] entries.
17/// Entries are in order, and there is at most one entry per [`ComponentId`].
18/// These properties are checked upon creation, as well as upon deserialization.
19#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)]
20pub struct AppDataDictionary {
21    // NOTE: A BTreeMap is used here to ensure that the data is ordered by keys,
22    // and unique. Since this also goes into the actual MLS Extension message, you could argue that
23    // this should be a Vec<ComponentData>. However, inserting in the middle is much easier (and
24    // cheaper) with the BTreeMap. The one thing that is a bit slower now is the `len` method,
25    // which iterates over all keys.
26    component_data: BTreeMap<ComponentId, ComponentData>,
27}
28
29impl AppDataDictionary {
30    /// Initialize a new, empty [`AppDataDictionary`].
31    pub fn new() -> Self {
32        Self {
33            component_data: BTreeMap::new(),
34        }
35    }
36    /// Returns an iterator over the [`ComponentData`] entries,
37    /// ordered by [`ComponentId`].
38    pub fn entries(&self) -> impl Iterator<Item = &ComponentData> {
39        self.component_data.values()
40    }
41
42    /// Returns the data that is currently stored in the [`AppDataDictionary`].
43    pub fn to_entries(self) -> Vec<ComponentData> {
44        self.entries().cloned().collect()
45    }
46
47    /// Returns the number of entries in the dictionary.
48    pub fn len(&self) -> usize {
49        self.component_data.len()
50    }
51
52    /// Returns whether the dictionary has entries.
53    pub fn is_empty(&self) -> bool {
54        self.component_data.is_empty()
55    }
56
57    /// Get a reference to an entry in the dictionary.
58    pub fn get(&self, component_id: &ComponentId) -> Option<&[u8]> {
59        self.component_data
60            .get(component_id)
61            .map(|component_data| component_data.data())
62    }
63
64    /// Insert an entry into the dictionary. If an entry for this [`ComponentId`] already exists,
65    /// replace the old entry and return it.
66    pub fn insert(&mut self, component_id: ComponentId, data: Vec<u8>) -> Option<VLBytes> {
67        self.component_data
68            .insert(
69                component_id,
70                ComponentData::from_parts(component_id, data.into()),
71            )
72            .map(|component_data| component_data.into_data())
73    }
74
75    /// Returns `true` if the dictionary contains an entry for the specified [`ComponentId`].
76    pub fn contains(&self, component_id: &ComponentId) -> bool {
77        self.component_data.contains_key(component_id)
78    }
79
80    /// Remove an entry from the dictionary by [`ComponentId`]. If this entry exists,
81    /// return it.
82    pub fn remove(&mut self, component_id: &ComponentId) -> Option<VLBytes> {
83        self.component_data
84            .remove(component_id)
85            .map(|component_data| component_data.into_data())
86    }
87
88    /// Creates an [`AppDataDictionary`] from an `impl IntoIterator<Item = ComponentData>`.
89    ///
90    /// Ensures that the list is ordered by [`ComponentId`], and that there is at most one entry per [`ComponentId`].
91    /// <https://datatracker.ietf.org/doc/html/draft-ietf-mls-extensions#section-4.6-5>
92    fn try_from_data(
93        data: impl IntoIterator<Item = ComponentData>,
94    ) -> Result<Self, BuildAppDataDictionaryError> {
95        let mut map = BTreeMap::<ComponentId, ComponentData>::new();
96
97        for component_data in data {
98            let (component_id, data) = component_data.into_parts();
99            // Check for duplicates
100            if map.contains_key(&component_id) {
101                return Err(BuildAppDataDictionaryError::DuplicateEntries);
102            }
103
104            // Check the ordering
105            // The component id must be greater than all previous component ids
106            if let Some((max, _)) = map.last_key_value() {
107                if *max > component_id {
108                    return Err(BuildAppDataDictionaryError::EntriesNotInOrder);
109                }
110            }
111            // Update the last component id
112            let _ = map.insert(component_id, ComponentData::from_parts(component_id, data));
113        }
114
115        Ok(Self {
116            component_data: map,
117        })
118    }
119}
120
121impl tls_codec::Size for AppDataDictionary {
122    fn tls_serialized_len(&self) -> usize {
123        // get length without copying
124        let data: Vec<&ComponentData> = self.entries().collect();
125        data.tls_serialized_len()
126    }
127}
128
129impl tls_codec::Serialize for AppDataDictionary {
130    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
131        // serialize without copying
132        let data: Vec<&ComponentData> = self.entries().collect();
133        data.tls_serialize(writer)
134    }
135}
136
137impl tls_codec::Deserialize for AppDataDictionary {
138    /// Deserialize from bytes.
139    ///
140    /// This function also ensures that the [`ComponentData`] entries are in order by
141    /// [`ComponentId`], and there is at most one entry per [`ComponentId`].
142    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error> {
143        // First deserialize as vector of ComponentData
144        let data = Vec::<ComponentData>::tls_deserialize(bytes)?;
145
146        // Convert to an AppDataDictionary
147        AppDataDictionary::try_from_data(data)
148            .map_err(|e| tls_codec::Error::DecodingError(e.to_string()))
149    }
150}
151
152impl tls_codec::DeserializeBytes for AppDataDictionary {
153    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error> {
154        use tls_codec::{Deserialize, Size};
155        let mut bytes_ref = bytes;
156        let dictionary = Self::tls_deserialize(&mut bytes_ref)?;
157
158        let remainder = &bytes[dictionary.tls_serialized_len()..];
159
160        Ok((dictionary, remainder))
161    }
162}
163
164/// App Data Dictionary Extension.
165///
166/// <https://datatracker.ietf.org/doc/html/draft-ietf-mls-extensions#section-4.6-3>
167#[derive(
168    PartialEq,
169    Eq,
170    Clone,
171    Debug,
172    Default,
173    Serialize,
174    Deserialize,
175    TlsSerialize,
176    TlsDeserialize,
177    TlsDeserializeBytes,
178    TlsSize,
179)]
180pub struct AppDataDictionaryExtension {
181    dictionary: AppDataDictionary,
182}
183
184impl AppDataDictionaryExtension {
185    /// Return the [`AppDataDictionary`] from this extension.
186    pub fn dictionary(&self) -> &AppDataDictionary {
187        &self.dictionary
188    }
189    /// Build a new extension from an [`AppDataDictionary`].
190    pub fn new(dictionary: AppDataDictionary) -> Self {
191        Self { dictionary }
192    }
193}
194
#[cfg(test)]
mod test {
    use super::*;
    use tls_codec::{Deserialize, Serialize};

    #[openmls_test::openmls_test]
    fn test_serialize_deserialize() {
        // build a dictionary with one entry
        let mut dictionary = AppDataDictionary::new();
        assert_eq!(dictionary.insert(0, vec![]), None);
        // inserting the same id again replaces the entry and returns the old data
        assert_eq!(
            dictionary.insert(0, vec![1, 2, 3]),
            Some(VLBytes::new(Vec::new()))
        );

        assert_eq!(dictionary.len(), 1);
        assert_eq!(dictionary.get(&0), Some(&[1u8, 2, 3][..]));

        // build a dictionary with multiple entries
        let mut dictionary_orig = AppDataDictionary::new();
        let _ = dictionary_orig.insert(5, vec![]);
        let _ = dictionary_orig.insert(0, vec![1, 2, 3]);

        assert_eq!(dictionary_orig.len(), 2);

        // create an extension from the dictionary
        let extension_orig = AppDataDictionaryExtension::new(dictionary_orig.clone());

        // test serialization and deserialization of constructed dictionary
        let bytes = extension_orig.tls_serialize_detached().unwrap();
        let extension_deserialized =
            AppDataDictionaryExtension::tls_deserialize(&mut bytes.as_slice()).unwrap();
        assert_eq!(extension_orig, extension_deserialized);
    }

    #[openmls_test::openmls_test]
    fn test_serialization_empty() {
        // build a dictionary with no entries
        let dictionary_orig = AppDataDictionary::new();

        assert_eq!(dictionary_orig.len(), 0);

        // create an extension from the dictionary
        let extension_orig = AppDataDictionaryExtension::new(dictionary_orig.clone());

        // test serialization and deserialization of constructed dictionary
        let bytes = extension_orig.tls_serialize_detached().unwrap();
        let extension_deserialized =
            AppDataDictionaryExtension::tls_deserialize(&mut bytes.as_slice()).unwrap();
        assert_eq!(extension_orig, extension_deserialized);
    }

    #[openmls_test::openmls_test]
    fn test_deserialize_bytes_returns_remainder() {
        // build a dictionary and append trailing bytes after its encoding
        let mut dictionary = AppDataDictionary::new();
        let _ = dictionary.insert(1, vec![4, 5]);
        let _ = dictionary.insert(3, vec![6]);

        let mut bytes = dictionary.tls_serialize_detached().unwrap();
        let trailing = [0xAA, 0xBB];
        bytes.extend_from_slice(&trailing);

        // the dictionary is decoded and exactly the trailing bytes are left over
        let (deserialized, remainder) =
            <AppDataDictionary as tls_codec::DeserializeBytes>::tls_deserialize_bytes(&bytes)
                .unwrap();
        assert_eq!(deserialized, dictionary);
        assert_eq!(remainder, &trailing[..]);
    }

    // TODO: replace with FrankenAppDataDictionary
    #[openmls_test::openmls_test]
    fn test_serialization_invalid() {
        // incorrect dictionary with repeat entries
        // serialize the raw content
        let component_data = vec![
            ComponentData::from_parts(5, vec![].into()),
            ComponentData::from_parts(5, vec![1, 2, 3].into()),
            ComponentData::from_parts(9, vec![].into()),
        ];

        let serialized = component_data.tls_serialize_detached().unwrap();
        let err = AppDataDictionary::tls_deserialize_exact(serialized).unwrap_err();
        assert_eq!(
            err,
            tls_codec::Error::DecodingError(
                BuildAppDataDictionaryError::DuplicateEntries.to_string()
            )
        );

        // incorrect dictionary with out-of-order entries
        // serialize the raw content
        let component_data = vec![
            ComponentData::from_parts(5, vec![].into()),
            ComponentData::from_parts(9, vec![].into()),
            ComponentData::from_parts(4, vec![1, 2, 3].into()),
        ];

        let serialized = component_data.tls_serialize_detached().unwrap();
        let err = AppDataDictionary::tls_deserialize_exact(serialized).unwrap_err();
        assert_eq!(
            err,
            tls_codec::Error::DecodingError(
                BuildAppDataDictionaryError::EntriesNotInOrder.to_string()
            )
        );
    }
}