Skip to main content

openmls/framing/
safe_aad.rs

//! Safe Additional Authenticated Data (Safe AAD) framing.
//!
//! Implements the Safe AAD wire format and validation rules from
//! <https://datatracker.ietf.org/doc/html/draft-ietf-mls-extensions> Section 4.9.
//!
//! ```tls
//! struct {
//!   ComponentID component_id;
//!   opaque aad_item_data<V>;
//! } SafeAADItem;
//!
//! struct {
//!   SafeAADItem aad_items<V>;
//! } SafeAAD;
//! ```
//!
//! Items in a [`SafeAad`] are sorted in strictly increasing order of
//! `component_id`. Duplicates and misordering are rejected both by
//! [`SafeAad::from_items`] and when decoding the TLS wire format.
16//!
17//! Items in a [`SafeAad`] are sorted in strictly-increasing order of
18//! `component_id`. Duplicates and misordering are rejected on both construction
19//! and deserialization.
20
21use serde::{Deserialize, Serialize};
22use std::io::Read;
23use tls_codec::{
24    Deserialize as TlsDeserializeTrait, DeserializeBytes as TlsDeserializeBytesTrait,
25    Serialize as TlsSerializeTrait, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
26    VLBytes,
27};
28
29use crate::component::ComponentId;
30
31/// Errors that can occur when building or parsing a [`SafeAad`].
32#[derive(thiserror::Error, Debug, PartialEq, Eq, Clone)]
33pub enum SafeAadError {
34    /// Two items share the same [`ComponentId`].
35    #[error("duplicate component id in SafeAAD: {0}")]
36    DuplicateComponentId(ComponentId),
37    /// Items are not sorted in strictly-increasing order by [`ComponentId`].
38    #[error("SafeAAD items are not sorted by component id in increasing order")]
39    ItemsNotSortedAscending,
40    /// Encoding or decoding failure.
41    #[error("codec error: {0}")]
42    Codec(String),
43}
44
45/// A single Safe AAD entry tagged by [`ComponentId`].
46///
47/// ```tls
48/// struct {
49///   ComponentID component_id;
50///   opaque aad_item_data<V>;
51/// } SafeAADItem;
52/// ```
53#[derive(
54    Clone,
55    Debug,
56    PartialEq,
57    Eq,
58    Serialize,
59    Deserialize,
60    TlsSerialize,
61    TlsDeserialize,
62    TlsDeserializeBytes,
63    TlsSize,
64)]
65pub struct SafeAadItem {
66    component_id: ComponentId,
67    aad_item_data: VLBytes,
68}
69
70impl SafeAadItem {
71    /// Create a new [`SafeAadItem`].
72    pub fn new(component_id: ComponentId, data: Vec<u8>) -> Self {
73        Self {
74            component_id,
75            aad_item_data: data.into(),
76        }
77    }
78
79    /// The [`ComponentId`] this item is tagged with.
80    pub fn component_id(&self) -> ComponentId {
81        self.component_id
82    }
83
84    /// The bytes carried by this item.
85    pub fn data(&self) -> &[u8] {
86        self.aad_item_data.as_slice()
87    }
88}
89
90/// A Safe AAD struct as it appears at the beginning of an MLS message's
91/// `authenticated_data` field when negotiated for the group.
92#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, TlsSerialize, TlsSize)]
93pub struct SafeAad {
94    aad_items: Vec<SafeAadItem>,
95}
96
97impl SafeAad {
98    /// Build a [`SafeAad`] from a list of items.
99    ///
100    /// Returns an error if items are not sorted in strictly-increasing
101    /// [`ComponentId`] order or if any [`ComponentId`] appears more than once.
102    pub fn from_items(items: Vec<SafeAadItem>) -> Result<Self, SafeAadError> {
103        Self::validate(&items)?;
104        Ok(Self { aad_items: items })
105    }
106
107    /// Build an empty [`SafeAad`].
108    pub fn empty() -> Self {
109        Self {
110            aad_items: Vec::new(),
111        }
112    }
113
114    /// Returns all items.
115    pub fn items(&self) -> &[SafeAadItem] {
116        &self.aad_items
117    }
118
119    /// Look up the data carried for a given [`ComponentId`].
120    ///
121    /// Returns `None` if there is no item tagged with that id.
122    pub fn get(&self, component_id: ComponentId) -> Option<&[u8]> {
123        // The list is sorted by construction, so a binary search is correct
124        // and cheap.
125        self.aad_items
126            .binary_search_by_key(&component_id, SafeAadItem::component_id)
127            .ok()
128            .map(|index| self.aad_items[index].data())
129    }
130
131    /// Returns true if there are no items.
132    pub fn is_empty(&self) -> bool {
133        self.aad_items.is_empty()
134    }
135
136    /// Number of items.
137    pub fn len(&self) -> usize {
138        self.aad_items.len()
139    }
140
141    fn validate(items: &[SafeAadItem]) -> Result<(), SafeAadError> {
142        let mut previous: Option<ComponentId> = None;
143        for item in items {
144            if let Some(prev) = previous {
145                if item.component_id == prev {
146                    return Err(SafeAadError::DuplicateComponentId(item.component_id));
147                }
148                if item.component_id < prev {
149                    return Err(SafeAadError::ItemsNotSortedAscending);
150                }
151            }
152            previous = Some(item.component_id);
153        }
154        Ok(())
155    }
156}
157
158impl TlsDeserializeTrait for SafeAad {
159    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, tls_codec::Error> {
160        let aad_items = Vec::<SafeAadItem>::tls_deserialize(bytes)?;
161        SafeAad::from_items(aad_items)
162            .map_err(|err| tls_codec::Error::DecodingError(err.to_string()))
163    }
164}
165
166impl TlsDeserializeBytesTrait for SafeAad {
167    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error> {
168        let (aad_items, rest) = Vec::<SafeAadItem>::tls_deserialize_bytes(bytes)?;
169        let aad = SafeAad::from_items(aad_items)
170            .map_err(|err| tls_codec::Error::DecodingError(err.to_string()))?;
171        Ok((aad, rest))
172    }
173}
174
175/// Build the bytes that go into `authenticated_data` for an outgoing message
176/// when Safe AAD is required: the TLS-serialized [`SafeAad`] followed by the
177/// caller-supplied tail bytes.
178pub(crate) fn assemble_authenticated_data(
179    safe_aad: &SafeAad,
180    tail: &[u8],
181) -> Result<Vec<u8>, SafeAadError> {
182    let mut out = safe_aad
183        .tls_serialize_detached()
184        .map_err(|err| SafeAadError::Codec(err.to_string()))?;
185    out.extend_from_slice(tail);
186    Ok(out)
187}
188
189/// Parse the [`SafeAad`] prefix from `authenticated_data` bytes when Safe AAD
190/// is required for the group. Returns the parsed struct and the length of the
191/// consumed prefix.
192pub(crate) fn parse_authenticated_data_prefix(
193    bytes: &[u8],
194) -> Result<(SafeAad, usize), SafeAadError> {
195    let (parsed, remainder) = SafeAad::tls_deserialize_bytes(bytes)
196        .map_err(|err| SafeAadError::Codec(err.to_string()))?;
197    let prefix_len = bytes.len() - remainder.len();
198    Ok((parsed, prefix_len))
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use tls_codec::{Deserialize, DeserializeBytes, Serialize};

    /// Shorthand for building an item from a borrowed byte string.
    fn item(id: ComponentId, data: &[u8]) -> SafeAadItem {
        SafeAadItem::new(id, data.to_vec())
    }

    #[test]
    fn roundtrip_non_empty() {
        let safe_aad = SafeAad::from_items(vec![
            item(1, b"first"),
            item(7, b""),
            item(42, b"last item bytes"),
        ])
        .unwrap();

        let bytes = safe_aad.tls_serialize_detached().unwrap();
        let parsed = SafeAad::tls_deserialize_exact(&bytes).unwrap();

        assert_eq!(parsed, safe_aad);
        let reserialized = parsed.tls_serialize_detached().unwrap();
        assert_eq!(reserialized, bytes);
    }

    #[test]
    fn empty_is_length_prefix_only() {
        let safe_aad = SafeAad::empty();
        let bytes = safe_aad.tls_serialize_detached().unwrap();

        // The TLS encoding of a zero-length `<V>` vector is a single zero byte.
        assert_eq!(bytes, vec![0x00]);

        let parsed = SafeAad::tls_deserialize_exact(&bytes).unwrap();
        assert!(parsed.is_empty());
    }

    #[test]
    fn from_items_rejects_duplicates() {
        let err = SafeAad::from_items(vec![item(3, b"a"), item(3, b"b")]).unwrap_err();
        assert_eq!(err, SafeAadError::DuplicateComponentId(3));
    }

    #[test]
    fn from_items_rejects_misordered() {
        let err = SafeAad::from_items(vec![item(9, b""), item(2, b"")]).unwrap_err();
        assert_eq!(err, SafeAadError::ItemsNotSortedAscending);
    }

    #[test]
    fn deserialize_rejects_misordered() {
        // Hand-craft TLS bytes for two items that are out of order. The derived
        // serializer would normally refuse to emit these, so we build the bytes
        // directly from items wrapped in a plain `Vec`.
        let raw_items: Vec<SafeAadItem> = vec![item(5, b"x"), item(1, b"y")];
        let raw_bytes = raw_items.tls_serialize_detached().unwrap();

        let err = SafeAad::tls_deserialize_exact(&raw_bytes).unwrap_err();
        match err {
            tls_codec::Error::DecodingError(message) => {
                assert!(
                    message.contains("not sorted"),
                    "unexpected error message: {message}"
                );
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn deserialize_rejects_duplicates() {
        let raw_items: Vec<SafeAadItem> = vec![item(4, b""), item(4, b"")];
        let raw_bytes = raw_items.tls_serialize_detached().unwrap();

        let err = SafeAad::tls_deserialize_exact(&raw_bytes).unwrap_err();
        match err {
            tls_codec::Error::DecodingError(message) => {
                assert!(
                    message.contains("duplicate"),
                    "unexpected error message: {message}"
                );
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn deserialize_bytes_rejects_duplicates() {
        // The slice-based decoder is a separate `impl`, so check that it
        // enforces the same invariant as the reader-based decoder above.
        let raw_items: Vec<SafeAadItem> = vec![item(6, b"a"), item(6, b"b")];
        let raw_bytes = raw_items.tls_serialize_detached().unwrap();

        let err = SafeAad::tls_deserialize_bytes(&raw_bytes).unwrap_err();
        match err {
            tls_codec::Error::DecodingError(message) => {
                assert!(
                    message.contains("duplicate"),
                    "unexpected error message: {message}"
                );
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn boundary_component_ids() {
        let safe_aad = SafeAad::from_items(vec![item(0, b"min"), item(u16::MAX, b"max")]).unwrap();

        let bytes = safe_aad.tls_serialize_detached().unwrap();
        let parsed = SafeAad::tls_deserialize_exact(&bytes).unwrap();

        assert_eq!(parsed.get(0), Some(b"min".as_slice()));
        assert_eq!(parsed.get(u16::MAX), Some(b"max".as_slice()));
    }

    #[test]
    fn get_returns_none_for_missing() {
        let safe_aad = SafeAad::from_items(vec![item(1, b"a"), item(10, b"b")]).unwrap();
        assert_eq!(safe_aad.get(5), None);
        assert_eq!(safe_aad.get(1), Some(b"a".as_slice()));
        assert_eq!(safe_aad.get(10), Some(b"b".as_slice()));
    }

    #[test]
    fn assemble_and_parse_authenticated_data_roundtrip() {
        let safe_aad =
            SafeAad::from_items(vec![item(2, b"safe-aad-data"), item(8, b"more")]).unwrap();
        let tail = b"caller tail bytes";

        let combined = assemble_authenticated_data(&safe_aad, tail).unwrap();

        let (parsed, prefix_len) = parse_authenticated_data_prefix(&combined).unwrap();
        assert_eq!(parsed, safe_aad);
        assert_eq!(&combined[prefix_len..], tail);
    }

    #[test]
    fn assemble_with_empty_tail_is_just_the_safe_aad() {
        let safe_aad = SafeAad::from_items(vec![item(3, b"only")]).unwrap();

        let combined = assemble_authenticated_data(&safe_aad, &[]).unwrap();
        assert_eq!(combined, safe_aad.tls_serialize_detached().unwrap());

        // With no tail, the parsed prefix must cover the whole input.
        let (parsed, prefix_len) = parse_authenticated_data_prefix(&combined).unwrap();
        assert_eq!(parsed, safe_aad);
        assert_eq!(prefix_len, combined.len());
    }

    #[test]
    fn parse_prefix_rejects_truncated_input() {
        // No length prefix at all.
        assert!(matches!(
            parse_authenticated_data_prefix(&[]),
            Err(SafeAadError::Codec(_))
        ));
        // A length prefix announcing 5 bytes of items, none of which follow.
        assert!(matches!(
            parse_authenticated_data_prefix(&[0x05]),
            Err(SafeAadError::Codec(_))
        ));
    }

    #[test]
    fn parse_prefix_rejects_invalid_items() {
        // Misordered items followed by a tail: validation in the bytes decoder
        // must surface as a codec error from the prefix parser.
        let raw_items: Vec<SafeAadItem> = vec![item(5, b"x"), item(1, b"y")];
        let mut raw_bytes = raw_items.tls_serialize_detached().unwrap();
        raw_bytes.extend_from_slice(b"tail");

        let err = parse_authenticated_data_prefix(&raw_bytes).unwrap_err();
        assert!(
            matches!(err, SafeAadError::Codec(_)),
            "unexpected error: {err:?}"
        );
    }
}