1use std::{
25 fmt::Debug,
26 io::{Read, Write},
27};
28
29use serde::{Deserialize, Serialize};
30
31mod application_id_extension;
33mod codec;
34mod external_pub_extension;
35mod external_sender_extension;
36mod last_resort;
37mod ratchet_tree_extension;
38mod required_capabilities;
39use errors::*;
40
41pub mod errors;
43
44pub use application_id_extension::ApplicationIdExtension;
46pub use external_pub_extension::ExternalPubExtension;
47pub use external_sender_extension::{
48 ExternalSender, ExternalSendersExtension, SenderExtensionIndex,
49};
50pub use last_resort::LastResortExtension;
51pub use ratchet_tree_extension::RatchetTreeExtension;
52pub use required_capabilities::RequiredCapabilitiesExtension;
53use tls_codec::{
54 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
55 Size, TlsSize,
56};
57
58#[cfg(test)]
59mod tests;
60
/// The type of an extension, encoded on the wire as a 2-byte big-endian integer.
///
/// Wire values (see `From<u16> for ExtensionType`): `ApplicationId` = 1, `RatchetTree` = 2,
/// `RequiredCapabilities` = 3, `ExternalPub` = 4, `ExternalSenders` = 5, `LastResort` = 10.
/// Every other value is preserved verbatim in [`ExtensionType::Unknown`].
///
/// Note: the derived `Ord`/`PartialOrd` compare by variant *declaration order* (then by the
/// `Unknown` payload), not by wire value. Reordering the variants below changes ordering
/// behavior, and renaming them changes the derived serde representation.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)]
pub enum ExtensionType {
    /// Application id extension; wire value 1.
    ApplicationId,

    /// Ratchet tree extension; wire value 2.
    RatchetTree,

    /// Required capabilities extension; wire value 3.
    RequiredCapabilities,

    /// External public key extension; wire value 4.
    ExternalPub,

    /// External senders extension; wire value 5.
    ExternalSenders,

    /// Last-resort extension; wire value 10.
    LastResort,

    /// Any type id without a dedicated variant; holds the raw 16-bit id.
    Unknown(u16),
}
105
106impl ExtensionType {
107 pub(crate) fn is_default(self) -> bool {
109 match self {
110 ExtensionType::ApplicationId
111 | ExtensionType::RatchetTree
112 | ExtensionType::RequiredCapabilities
113 | ExtensionType::ExternalPub
114 | ExtensionType::ExternalSenders => true,
115 ExtensionType::LastResort | ExtensionType::Unknown(_) => false,
116 }
117 }
118
119 pub(crate) fn is_valid_in_leaf_node(self) -> Option<bool> {
124 match self {
125 ExtensionType::LastResort
126 | ExtensionType::RatchetTree
127 | ExtensionType::RequiredCapabilities
128 | ExtensionType::ExternalPub
129 | ExtensionType::ExternalSenders => Some(false),
130 ExtensionType::ApplicationId => Some(true),
131 ExtensionType::Unknown(_) => None,
132 }
133 }
134}
135
136impl Size for ExtensionType {
137 fn tls_serialized_len(&self) -> usize {
138 2
139 }
140}
141
142impl TlsDeserializeTrait for ExtensionType {
143 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
144 where
145 Self: Sized,
146 {
147 let mut extension_type = [0u8; 2];
148 bytes.read_exact(&mut extension_type)?;
149
150 Ok(ExtensionType::from(u16::from_be_bytes(extension_type)))
151 }
152}
153
154impl DeserializeBytes for ExtensionType {
155 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
156 where
157 Self: Sized,
158 {
159 let mut bytes_ref = bytes;
160 let extension_type = ExtensionType::tls_deserialize(&mut bytes_ref)?;
161 let remainder = &bytes[extension_type.tls_serialized_len()..];
162 Ok((extension_type, remainder))
163 }
164}
165
166impl TlsSerializeTrait for ExtensionType {
167 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
168 writer.write_all(&u16::from(*self).to_be_bytes())?;
169
170 Ok(2)
171 }
172}
173
174impl From<u16> for ExtensionType {
175 fn from(a: u16) -> Self {
176 match a {
177 1 => ExtensionType::ApplicationId,
178 2 => ExtensionType::RatchetTree,
179 3 => ExtensionType::RequiredCapabilities,
180 4 => ExtensionType::ExternalPub,
181 5 => ExtensionType::ExternalSenders,
182 10 => ExtensionType::LastResort,
183 unknown => ExtensionType::Unknown(unknown),
184 }
185 }
186}
187
188impl From<ExtensionType> for u16 {
189 fn from(value: ExtensionType) -> Self {
190 match value {
191 ExtensionType::ApplicationId => 1,
192 ExtensionType::RatchetTree => 2,
193 ExtensionType::RequiredCapabilities => 3,
194 ExtensionType::ExternalPub => 4,
195 ExtensionType::ExternalSenders => 5,
196 ExtensionType::LastResort => 10,
197 ExtensionType::Unknown(unknown) => unknown,
198 }
199 }
200}
201
/// A single extension: one variant per [`ExtensionType`], each wrapping that extension's payload.
///
/// `Extension::extension_type` maps every variant back to its [`ExtensionType`]. The TLS codec for
/// this enum is not implemented in this file — presumably in the `codec` submodule (TODO
/// confirm). The unknown-extension round-trip test shows the wire form of an unknown extension:
/// its 2-byte type id followed by a variable-length byte payload.
///
/// Note: variant names feed the derived serde representation; do not rename them casually.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Extension {
    /// Payload of an `ExtensionType::ApplicationId` extension.
    ApplicationId(ApplicationIdExtension),

    /// Payload of an `ExtensionType::RatchetTree` extension.
    RatchetTree(RatchetTreeExtension),

    /// Payload of an `ExtensionType::RequiredCapabilities` extension.
    RequiredCapabilities(RequiredCapabilitiesExtension),

    /// Payload of an `ExtensionType::ExternalPub` extension.
    ExternalPub(ExternalPubExtension),

    /// Payload of an `ExtensionType::ExternalSenders` extension.
    ExternalSenders(ExternalSendersExtension),

    /// Payload of an `ExtensionType::LastResort` extension.
    LastResort(LastResortExtension),

    /// An extension with no dedicated variant: the raw 16-bit type id and its uninterpreted
    /// payload bytes.
    Unknown(u16, UnknownExtension),
}
239
/// The raw, uninterpreted payload bytes of an extension whose type has no dedicated variant.
///
/// Kept verbatim so that re-serializing an `Extension::Unknown` reproduces the original bytes.
#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize)]
pub struct UnknownExtension(pub Vec<u8>);
243
/// An ordered list of [`Extension`]s in which no two entries share an [`ExtensionType`].
///
/// Uniqueness is enforced by `Extensions::add` and by `TryFrom<Vec<Extension>>`. The latter backs
/// `Extensions::from_vec` and TLS deserialization. `Extensions::add_or_replace` first removes any
/// entry of the same type before appending.
///
/// NOTE(review): the derived serde `Deserialize` builds the struct directly from its `unique`
/// field and does NOT run the uniqueness check — confirm that serde input is trusted.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TlsSize)]
pub struct Extensions {
    // Extensions in insertion / wire order; TLS-serialized as a single vector.
    unique: Vec<Extension>,
}
249
250impl TlsSerializeTrait for Extensions {
251 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
252 self.unique.tls_serialize(writer)
253 }
254}
255
256impl TlsDeserializeTrait for Extensions {
257 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
258 where
259 Self: Sized,
260 {
261 let candidate: Vec<Extension> = Vec::tls_deserialize(bytes)?;
262 Extensions::try_from(candidate)
263 .map_err(|_| Error::DecodingError("Found duplicate extensions".into()))
264 }
265}
266
267impl DeserializeBytes for Extensions {
268 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
269 where
270 Self: Sized,
271 {
272 let mut bytes_ref = bytes;
273 let extensions = Extensions::tls_deserialize(&mut bytes_ref)?;
274 let remainder = &bytes[extensions.tls_serialized_len()..];
275 Ok((extensions, remainder))
276 }
277}
278
279impl Extensions {
280 pub fn empty() -> Self {
282 Self { unique: vec![] }
283 }
284
285 pub fn single(extension: Extension) -> Self {
287 Self {
288 unique: vec![extension],
289 }
290 }
291
292 pub fn from_vec(extensions: Vec<Extension>) -> Result<Self, InvalidExtensionError> {
297 extensions.try_into()
298 }
299
300 pub fn iter(&self) -> impl Iterator<Item = &Extension> {
302 self.unique.iter()
303 }
304
305 pub fn add(&mut self, extension: Extension) -> Result<(), InvalidExtensionError> {
310 if self.contains(extension.extension_type()) {
311 return Err(InvalidExtensionError::Duplicate);
312 }
313
314 self.unique.push(extension);
315
316 Ok(())
317 }
318
319 pub fn add_or_replace(&mut self, extension: Extension) -> Option<Extension> {
323 let replaced = self.remove(extension.extension_type());
324 self.unique.push(extension);
325 replaced
326 }
327
328 pub fn remove(&mut self, extension_type: ExtensionType) -> Option<Extension> {
333 if let Some(pos) = self
334 .unique
335 .iter()
336 .position(|ext| ext.extension_type() == extension_type)
337 {
338 Some(self.unique.remove(pos))
339 } else {
340 None
341 }
342 }
343
344 pub fn contains(&self, extension_type: ExtensionType) -> bool {
347 self.unique
348 .iter()
349 .any(|ext| ext.extension_type() == extension_type)
350 }
351
352 pub(crate) fn validate_extension_types_for_leaf_node(
355 &self,
356 ) -> Result<(), InvalidExtensionError> {
357 for extension_type in self.unique.iter().map(Extension::extension_type) {
358 if extension_type.is_valid_in_leaf_node() == Some(false) {
360 return Err(InvalidExtensionError::IllegalInLeafNodes);
361 }
362 }
363 Ok(())
364 }
365}
366
367impl TryFrom<Vec<Extension>> for Extensions {
368 type Error = InvalidExtensionError;
369
370 fn try_from(candidate: Vec<Extension>) -> Result<Self, Self::Error> {
371 let mut unique: Vec<Extension> = Vec::new();
372
373 for extension in candidate.into_iter() {
374 if unique
375 .iter()
376 .any(|ext| ext.extension_type() == extension.extension_type())
377 {
378 return Err(InvalidExtensionError::Duplicate);
379 } else {
380 unique.push(extension);
381 }
382 }
383
384 Ok(Self { unique })
385 }
386}
387
388impl Extensions {
389 fn find_by_type(&self, extension_type: ExtensionType) -> Option<&Extension> {
390 self.unique
391 .iter()
392 .find(|ext| ext.extension_type() == extension_type)
393 }
394
395 pub fn application_id(&self) -> Option<&ApplicationIdExtension> {
397 self.find_by_type(ExtensionType::ApplicationId)
398 .and_then(|e| match e {
399 Extension::ApplicationId(e) => Some(e),
400 _ => None,
401 })
402 }
403
404 pub fn ratchet_tree(&self) -> Option<&RatchetTreeExtension> {
406 self.find_by_type(ExtensionType::RatchetTree)
407 .and_then(|e| match e {
408 Extension::RatchetTree(e) => Some(e),
409 _ => None,
410 })
411 }
412
413 pub fn required_capabilities(&self) -> Option<&RequiredCapabilitiesExtension> {
416 self.find_by_type(ExtensionType::RequiredCapabilities)
417 .and_then(|e| match e {
418 Extension::RequiredCapabilities(e) => Some(e),
419 _ => None,
420 })
421 }
422
423 pub fn external_pub(&self) -> Option<&ExternalPubExtension> {
425 self.find_by_type(ExtensionType::ExternalPub)
426 .and_then(|e| match e {
427 Extension::ExternalPub(e) => Some(e),
428 _ => None,
429 })
430 }
431
432 pub fn external_senders(&self) -> Option<&ExternalSendersExtension> {
434 self.find_by_type(ExtensionType::ExternalSenders)
435 .and_then(|e| match e {
436 Extension::ExternalSenders(e) => Some(e),
437 _ => None,
438 })
439 }
440
441 pub fn unknown(&self, extension_type_id: u16) -> Option<&UnknownExtension> {
443 let extension_type: ExtensionType = extension_type_id.into();
444
445 match extension_type {
446 ExtensionType::Unknown(_) => self.find_by_type(extension_type).and_then(|e| match e {
447 Extension::Unknown(_, e) => Some(e),
448 _ => None,
449 }),
450 _ => None,
451 }
452 }
453}
454
455impl Extension {
456 pub fn as_application_id_extension(&self) -> Result<&ApplicationIdExtension, ExtensionError> {
460 match self {
461 Self::ApplicationId(e) => Ok(e),
462 _ => Err(ExtensionError::InvalidExtensionType(
463 "This is not an ApplicationIdExtension".into(),
464 )),
465 }
466 }
467
468 pub fn as_ratchet_tree_extension(&self) -> Result<&RatchetTreeExtension, ExtensionError> {
472 match self {
473 Self::RatchetTree(rte) => Ok(rte),
474 _ => Err(ExtensionError::InvalidExtensionType(
475 "This is not a RatchetTreeExtension".into(),
476 )),
477 }
478 }
479
480 pub fn as_required_capabilities_extension(
484 &self,
485 ) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
486 match self {
487 Self::RequiredCapabilities(e) => Ok(e),
488 _ => Err(ExtensionError::InvalidExtensionType(
489 "This is not a RequiredCapabilitiesExtension".into(),
490 )),
491 }
492 }
493
494 pub fn as_external_pub_extension(&self) -> Result<&ExternalPubExtension, ExtensionError> {
498 match self {
499 Self::ExternalPub(e) => Ok(e),
500 _ => Err(ExtensionError::InvalidExtensionType(
501 "This is not an ExternalPubExtension".into(),
502 )),
503 }
504 }
505
506 pub fn as_external_senders_extension(
510 &self,
511 ) -> Result<&ExternalSendersExtension, ExtensionError> {
512 match self {
513 Self::ExternalSenders(e) => Ok(e),
514 _ => Err(ExtensionError::InvalidExtensionType(
515 "This is not an ExternalSendersExtension".into(),
516 )),
517 }
518 }
519
520 #[inline]
522 pub const fn extension_type(&self) -> ExtensionType {
523 match self {
524 Extension::ApplicationId(_) => ExtensionType::ApplicationId,
525 Extension::RatchetTree(_) => ExtensionType::RatchetTree,
526 Extension::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
527 Extension::ExternalPub(_) => ExtensionType::ExternalPub,
528 Extension::ExternalSenders(_) => ExtensionType::ExternalSenders,
529 Extension::LastResort(_) => ExtensionType::LastResort,
530 Extension::Unknown(kind, _) => ExtensionType::Unknown(*kind),
531 }
532 }
533}
534
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use tls_codec::{Deserialize, DeserializeBytes, Serialize, VLBytes};

    use crate::{ciphersuite::HpkePublicKey, extensions::*};

    /// `add` accepts the first extension of a type and rejects a second one.
    #[test]
    fn add() {
        let mut extensions = Extensions::default();
        extensions
            .add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default(),
            ))
            .unwrap();
        assert!(extensions
            .add(Extension::RequiredCapabilities(
                RequiredCapabilitiesExtension::default()
            ))
            .is_err());
    }

    /// Repeated `add` and `TryFrom<Vec<Extension>>` agree on which inputs contain duplicates.
    #[test]
    fn add_try_from() {
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        let tests = [
            (vec![], true),
            (vec![ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_y.clone()], true),
            (vec![ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_x.clone(), ext_y.clone()], true),
            (vec![ext_y.clone(), ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_y.clone(), ext_x, ext_y], false),
        ];

        for (test, should_work) in tests {
            // Building the list one `add` at a time must flag exactly the duplicate inputs.
            let mut extensions = Extensions::default();
            let mut works = true;
            for ext in &test {
                match extensions.add(ext.clone()) {
                    Ok(()) => {}
                    Err(InvalidExtensionError::Duplicate) => works = false,
                    _ => panic!("This should have never happened."),
                }
            }

            println!("{test:?}, {should_work:?}");
            assert_eq!(works, should_work);

            // Bulk construction must agree with repeated `add`.
            assert_eq!(Extensions::try_from(test).is_ok(), should_work);
        }
    }

    /// TLS round-tripping preserves the order of extensions for every permutation.
    #[test]
    fn ensure_ordering() {
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::ExternalPub(ExternalPubExtension::new(HpkePublicKey::new(vec![])));
        let ext_z = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        for candidate in [ext_x, ext_y, ext_z].into_iter().permutations(3) {
            let candidate: Extensions = Extensions::try_from(candidate).unwrap();
            let bytes = candidate.tls_serialize_detached().unwrap();
            let got = Extensions::tls_deserialize(&mut bytes.as_slice()).unwrap();
            assert_eq!(candidate, got);
        }
    }

    /// Unknown extension types decode into `Extension::Unknown` and re-encode byte-for-byte.
    #[test]
    fn that_unknown_extensions_are_de_serialized_correctly() {
        let extension_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF100, 0xFFFF];
        let extension_datas = [vec![], vec![0], vec![1, 2, 3]];

        for extension_type in extension_types.into_iter() {
            for extension_data in extension_datas.iter() {
                // Wire form: 2-byte big-endian type followed by a variable-length payload.
                let test = {
                    let mut buf = extension_type.to_be_bytes().to_vec();
                    buf.append(
                        &mut VLBytes::new(extension_data.clone())
                            .tls_serialize_detached()
                            .unwrap(),
                    );
                    buf
                };

                let got = Extension::tls_deserialize_exact(&test).unwrap();

                match got {
                    Extension::Unknown(got_extension_type, ref got_extension_data) => {
                        assert_eq!(extension_type, got_extension_type);
                        assert_eq!(extension_data, &got_extension_data.0);
                    }
                    other => panic!("Expected `Extension::Unknown`, got {other:?}"),
                }

                let got_serialized = got.tls_serialize_detached().unwrap();
                assert_eq!(test, got_serialized);
            }
        }
    }

    /// `tls_deserialize_bytes` must hand back exactly the bytes it did not consume.
    #[test]
    fn deserialize_bytes_returns_unconsumed_remainder() {
        let trailing = [0xAAu8, 0xBB, 0xCC];

        let input = [0x00u8, 0x0A, 0xFF];
        let (extension_type, rest) = ExtensionType::tls_deserialize_bytes(&input[..]).unwrap();
        assert_eq!(extension_type, ExtensionType::LastResort);
        assert_eq!(rest, &input[2..]);

        let extensions = Extensions::single(Extension::ApplicationId(
            ApplicationIdExtension::new(b"Test"),
        ));
        let mut bytes = extensions.tls_serialize_detached().unwrap();
        bytes.extend_from_slice(&trailing);

        let (got, remainder) = Extensions::tls_deserialize_bytes(bytes.as_slice()).unwrap();
        assert_eq!(got, extensions);
        assert_eq!(remainder, &trailing[..]);
    }
}