1use serde::{Deserialize, Serialize};
22use std::io::Read;
23use tls_codec::{
24 Deserialize as TlsDeserializeTrait, DeserializeBytes as TlsDeserializeBytesTrait,
25 Serialize as TlsSerializeTrait, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
26 VLBytes,
27};
28
29use crate::component::ComponentId;
30
31#[derive(thiserror::Error, Debug, PartialEq, Eq, Clone)]
33pub enum SafeAadError {
34 #[error("duplicate component id in SafeAAD: {0}")]
36 DuplicateComponentId(ComponentId),
37 #[error("SafeAAD items are not sorted by component id in increasing order")]
39 ItemsNotSortedAscending,
40 #[error("codec error: {0}")]
42 Codec(String),
43}
44
45#[derive(
54 Clone,
55 Debug,
56 PartialEq,
57 Eq,
58 Serialize,
59 Deserialize,
60 TlsSerialize,
61 TlsDeserialize,
62 TlsDeserializeBytes,
63 TlsSize,
64)]
65pub struct SafeAadItem {
66 component_id: ComponentId,
67 aad_item_data: VLBytes,
68}
69
70impl SafeAadItem {
71 pub fn new(component_id: ComponentId, data: Vec<u8>) -> Self {
73 Self {
74 component_id,
75 aad_item_data: data.into(),
76 }
77 }
78
79 pub fn component_id(&self) -> ComponentId {
81 self.component_id
82 }
83
84 pub fn data(&self) -> &[u8] {
86 self.aad_item_data.as_slice()
87 }
88}
89
90#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, TlsSerialize, TlsSize)]
93pub struct SafeAad {
94 aad_items: Vec<SafeAadItem>,
95}
96
97impl SafeAad {
98 pub fn from_items(items: Vec<SafeAadItem>) -> Result<Self, SafeAadError> {
103 Self::validate(&items)?;
104 Ok(Self { aad_items: items })
105 }
106
107 pub fn empty() -> Self {
109 Self {
110 aad_items: Vec::new(),
111 }
112 }
113
114 pub fn items(&self) -> &[SafeAadItem] {
116 &self.aad_items
117 }
118
119 pub fn get(&self, component_id: ComponentId) -> Option<&[u8]> {
123 self.aad_items
126 .binary_search_by_key(&component_id, SafeAadItem::component_id)
127 .ok()
128 .map(|index| self.aad_items[index].data())
129 }
130
131 pub fn is_empty(&self) -> bool {
133 self.aad_items.is_empty()
134 }
135
136 pub fn len(&self) -> usize {
138 self.aad_items.len()
139 }
140
141 fn validate(items: &[SafeAadItem]) -> Result<(), SafeAadError> {
142 let mut previous: Option<ComponentId> = None;
143 for item in items {
144 if let Some(prev) = previous {
145 if item.component_id == prev {
146 return Err(SafeAadError::DuplicateComponentId(item.component_id));
147 }
148 if item.component_id < prev {
149 return Err(SafeAadError::ItemsNotSortedAscending);
150 }
151 }
152 previous = Some(item.component_id);
153 }
154 Ok(())
155 }
156}
157
158impl TlsDeserializeTrait for SafeAad {
159 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, tls_codec::Error> {
160 let aad_items = Vec::<SafeAadItem>::tls_deserialize(bytes)?;
161 SafeAad::from_items(aad_items)
162 .map_err(|err| tls_codec::Error::DecodingError(err.to_string()))
163 }
164}
165
166impl TlsDeserializeBytesTrait for SafeAad {
167 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error> {
168 let (aad_items, rest) = Vec::<SafeAadItem>::tls_deserialize_bytes(bytes)?;
169 let aad = SafeAad::from_items(aad_items)
170 .map_err(|err| tls_codec::Error::DecodingError(err.to_string()))?;
171 Ok((aad, rest))
172 }
173}
174
175pub(crate) fn assemble_authenticated_data(
179 safe_aad: &SafeAad,
180 tail: &[u8],
181) -> Result<Vec<u8>, SafeAadError> {
182 let mut out = safe_aad
183 .tls_serialize_detached()
184 .map_err(|err| SafeAadError::Codec(err.to_string()))?;
185 out.extend_from_slice(tail);
186 Ok(out)
187}
188
189pub(crate) fn parse_authenticated_data_prefix(
193 bytes: &[u8],
194) -> Result<(SafeAad, usize), SafeAadError> {
195 let (parsed, remainder) = SafeAad::tls_deserialize_bytes(bytes)
196 .map_err(|err| SafeAadError::Codec(err.to_string()))?;
197 let prefix_len = bytes.len() - remainder.len();
198 Ok((parsed, prefix_len))
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use tls_codec::{Deserialize, Serialize};

    /// Shorthand constructor for test items.
    fn item(id: ComponentId, data: &[u8]) -> SafeAadItem {
        SafeAadItem::new(id, data.to_vec())
    }

    #[test]
    fn roundtrip_non_empty() {
        let safe_aad = SafeAad::from_items(vec![
            item(1, b"first"),
            item(7, b""),
            item(42, b"last item bytes"),
        ])
        .unwrap();

        let bytes = safe_aad.tls_serialize_detached().unwrap();
        let parsed = SafeAad::tls_deserialize_exact(&bytes).unwrap();

        assert_eq!(parsed, safe_aad);
        let reserialized = parsed.tls_serialize_detached().unwrap();
        assert_eq!(reserialized, bytes);
    }

    #[test]
    fn empty_is_length_prefix_only() {
        let safe_aad = SafeAad::empty();
        let bytes = safe_aad.tls_serialize_detached().unwrap();

        assert_eq!(bytes, vec![0x00]);

        let parsed = SafeAad::tls_deserialize_exact(&bytes).unwrap();
        assert!(parsed.is_empty());
    }

    #[test]
    fn from_items_rejects_duplicates() {
        let err = SafeAad::from_items(vec![item(3, b"a"), item(3, b"b")]).unwrap_err();
        assert_eq!(err, SafeAadError::DuplicateComponentId(3));
    }

    #[test]
    fn from_items_rejects_misordered() {
        let err = SafeAad::from_items(vec![item(9, b""), item(2, b"")]).unwrap_err();
        assert_eq!(err, SafeAadError::ItemsNotSortedAscending);
    }

    #[test]
    fn deserialize_rejects_misordered() {
        let raw_items: Vec<SafeAadItem> = vec![item(5, b"x"), item(1, b"y")];
        let raw_bytes = raw_items.tls_serialize_detached().unwrap();

        let err = SafeAad::tls_deserialize_exact(&raw_bytes).unwrap_err();
        match err {
            tls_codec::Error::DecodingError(message) => {
                assert!(
                    message.contains("not sorted"),
                    "unexpected error message: {message}"
                );
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn deserialize_rejects_duplicates() {
        let raw_items: Vec<SafeAadItem> = vec![item(4, b""), item(4, b"")];
        let raw_bytes = raw_items.tls_serialize_detached().unwrap();

        let err = SafeAad::tls_deserialize_exact(&raw_bytes).unwrap_err();
        match err {
            tls_codec::Error::DecodingError(message) => {
                assert!(
                    message.contains("duplicate"),
                    "unexpected error message: {message}"
                );
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn boundary_component_ids() {
        let safe_aad = SafeAad::from_items(vec![item(0, b"min"), item(u16::MAX, b"max")]).unwrap();

        let bytes = safe_aad.tls_serialize_detached().unwrap();
        let parsed = SafeAad::tls_deserialize_exact(&bytes).unwrap();

        assert_eq!(parsed.get(0), Some(b"min".as_slice()));
        assert_eq!(parsed.get(u16::MAX), Some(b"max".as_slice()));
    }

    #[test]
    fn get_returns_none_for_missing() {
        let safe_aad = SafeAad::from_items(vec![item(1, b"a"), item(10, b"b")]).unwrap();
        assert_eq!(safe_aad.get(5), None);
        assert_eq!(safe_aad.get(1), Some(b"a".as_slice()));
        assert_eq!(safe_aad.get(10), Some(b"b".as_slice()));
    }

    #[test]
    fn get_on_empty_returns_none() {
        assert_eq!(SafeAad::empty().get(0), None);
        assert_eq!(SafeAad::empty().len(), 0);
    }

    #[test]
    fn assemble_and_parse_authenticated_data_roundtrip() {
        let safe_aad =
            SafeAad::from_items(vec![item(2, b"safe-aad-data"), item(8, b"more")]).unwrap();
        let tail = b"caller tail bytes";

        let combined = assemble_authenticated_data(&safe_aad, tail).unwrap();

        let (parsed, prefix_len) = parse_authenticated_data_prefix(&combined).unwrap();
        assert_eq!(parsed, safe_aad);
        assert_eq!(&combined[prefix_len..], tail);
    }

    #[test]
    fn assemble_and_parse_empty_safe_aad() {
        let combined = assemble_authenticated_data(&SafeAad::empty(), b"tail").unwrap();

        let (parsed, prefix_len) = parse_authenticated_data_prefix(&combined).unwrap();
        assert!(parsed.is_empty());
        // An empty SafeAad encodes to a single zero length byte.
        assert_eq!(prefix_len, 1);
        assert_eq!(&combined[prefix_len..], b"tail");
    }

    // The byte-slice decoder (used by `parse_authenticated_data_prefix`) must
    // enforce the same ordering rules as the `Read`-based decoder tested above.
    #[test]
    fn parse_prefix_rejects_misordered_items() {
        let raw_items: Vec<SafeAadItem> = vec![item(5, b"x"), item(1, b"y")];
        let raw_bytes = raw_items.tls_serialize_detached().unwrap();

        let err = parse_authenticated_data_prefix(&raw_bytes).unwrap_err();
        assert!(
            matches!(err, SafeAadError::Codec(_)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn parse_prefix_rejects_duplicate_items() {
        let raw_items: Vec<SafeAadItem> = vec![item(4, b""), item(4, b"")];
        let raw_bytes = raw_items.tls_serialize_detached().unwrap();

        let err = parse_authenticated_data_prefix(&raw_bytes).unwrap_err();
        assert!(
            matches!(err, SafeAadError::Codec(_)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn parse_prefix_rejects_truncated_input() {
        let safe_aad = SafeAad::from_items(vec![item(3, b"payload")]).unwrap();
        let bytes = safe_aad.tls_serialize_detached().unwrap();

        // Drop the final byte so the declared vector length exceeds what remains.
        let err = parse_authenticated_data_prefix(&bytes[..bytes.len() - 1]).unwrap_err();
        assert!(
            matches!(err, SafeAadError::Codec(_)),
            "unexpected error: {err:?}"
        );
    }
}