1use std::{
25 fmt::Debug,
26 io::{Read, Write},
27};
28
29use serde::{Deserialize, Serialize};
30
31mod application_id_extension;
33mod codec;
34mod external_pub_extension;
35mod external_sender_extension;
36mod last_resort;
37mod ratchet_tree_extension;
38mod required_capabilities;
39use errors::*;
40
41pub mod errors;
43
44pub use application_id_extension::ApplicationIdExtension;
46pub use external_pub_extension::ExternalPubExtension;
47pub use external_sender_extension::{
48 ExternalSender, ExternalSendersExtension, SenderExtensionIndex,
49};
50pub use last_resort::LastResortExtension;
51pub use ratchet_tree_extension::RatchetTreeExtension;
52pub use required_capabilities::RequiredCapabilitiesExtension;
53use tls_codec::{
54 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
55 Size, TlsSize,
56};
57
58#[cfg(test)]
59mod tests;
60
/// Identifier of an extension, carried on the wire as a big-endian `u16`.
///
/// The numeric code of each variant is defined by the `From<u16>` and
/// `From<ExtensionType> for u16` conversions in this module. The variant
/// order below also determines the derived `Ord`/`PartialOrd` ordering.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)]
pub enum ExtensionType {
    /// Application ID extension (wire value `1`).
    ApplicationId,

    /// Ratchet tree extension (wire value `2`).
    RatchetTree,

    /// Required capabilities extension (wire value `3`).
    RequiredCapabilities,

    /// External public key extension (wire value `4`).
    ExternalPub,

    /// External senders extension (wire value `5`).
    ExternalSenders,

    /// Last resort extension (wire value `10`).
    LastResort,

    /// A code for which `crate::grease::is_grease_value` returns `true`.
    /// The raw code is kept so that it can be written back unchanged.
    Grease(u16),

    /// Any other code that this module does not recognise.
    /// The raw code is kept so that it can be written back unchanged.
    Unknown(u16),
}
108
109impl ExtensionType {
110 pub(crate) fn is_default(self) -> bool {
112 match self {
113 ExtensionType::ApplicationId
114 | ExtensionType::RatchetTree
115 | ExtensionType::RequiredCapabilities
116 | ExtensionType::ExternalPub
117 | ExtensionType::ExternalSenders => true,
118 ExtensionType::LastResort | ExtensionType::Grease(_) | ExtensionType::Unknown(_) => {
119 false
120 }
121 }
122 }
123
124 pub(crate) fn is_valid_in_leaf_node(self) -> Option<bool> {
129 match self {
130 ExtensionType::LastResort
131 | ExtensionType::RatchetTree
132 | ExtensionType::RequiredCapabilities
133 | ExtensionType::ExternalPub
134 | ExtensionType::ExternalSenders => Some(false),
135 ExtensionType::ApplicationId => Some(true),
136 ExtensionType::Grease(_) | ExtensionType::Unknown(_) => None,
137 }
138 }
139
140 pub fn is_grease(&self) -> bool {
145 matches!(self, ExtensionType::Grease(_))
146 }
147}
148
149impl Size for ExtensionType {
150 fn tls_serialized_len(&self) -> usize {
151 2
152 }
153}
154
155impl TlsDeserializeTrait for ExtensionType {
156 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
157 where
158 Self: Sized,
159 {
160 let mut extension_type = [0u8; 2];
161 bytes.read_exact(&mut extension_type)?;
162
163 Ok(ExtensionType::from(u16::from_be_bytes(extension_type)))
164 }
165}
166
167impl DeserializeBytes for ExtensionType {
168 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
169 where
170 Self: Sized,
171 {
172 let mut bytes_ref = bytes;
173 let extension_type = ExtensionType::tls_deserialize(&mut bytes_ref)?;
174 let remainder = &bytes[extension_type.tls_serialized_len()..];
175 Ok((extension_type, remainder))
176 }
177}
178
179impl TlsSerializeTrait for ExtensionType {
180 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
181 writer.write_all(&u16::from(*self).to_be_bytes())?;
182
183 Ok(2)
184 }
185}
186
187impl From<u16> for ExtensionType {
188 fn from(a: u16) -> Self {
189 match a {
190 1 => ExtensionType::ApplicationId,
191 2 => ExtensionType::RatchetTree,
192 3 => ExtensionType::RequiredCapabilities,
193 4 => ExtensionType::ExternalPub,
194 5 => ExtensionType::ExternalSenders,
195 10 => ExtensionType::LastResort,
196 unknown if crate::grease::is_grease_value(unknown) => ExtensionType::Grease(unknown),
197 unknown => ExtensionType::Unknown(unknown),
198 }
199 }
200}
201
202impl From<ExtensionType> for u16 {
203 fn from(value: ExtensionType) -> Self {
204 match value {
205 ExtensionType::ApplicationId => 1,
206 ExtensionType::RatchetTree => 2,
207 ExtensionType::RequiredCapabilities => 3,
208 ExtensionType::ExternalPub => 4,
209 ExtensionType::ExternalSenders => 5,
210 ExtensionType::LastResort => 10,
211 ExtensionType::Grease(value) => value,
212 ExtensionType::Unknown(unknown) => unknown,
213 }
214 }
215}
216
/// A single extension: one variant per supported extension type, each
/// wrapping that extension's payload, plus a catch-all for everything else.
///
/// Use [`Extension::extension_type`] to obtain the matching [`ExtensionType`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Extension {
    /// An [`ApplicationIdExtension`].
    ApplicationId(ApplicationIdExtension),

    /// A [`RatchetTreeExtension`].
    RatchetTree(RatchetTreeExtension),

    /// A [`RequiredCapabilitiesExtension`].
    RequiredCapabilities(RequiredCapabilitiesExtension),

    /// An [`ExternalPubExtension`].
    ExternalPub(ExternalPubExtension),

    /// An [`ExternalSendersExtension`].
    ExternalSenders(ExternalSendersExtension),

    /// A [`LastResortExtension`].
    LastResort(LastResortExtension),

    /// An extension that is not interpreted: the raw `u16` type code together
    /// with its payload. `extension_type()` reports this variant as
    /// `ExtensionType::Unknown(code)`.
    Unknown(u16, UnknownExtension),
}
254
/// Byte payload of an [`Extension::Unknown`]. The bytes are stored as-is and
/// exposed through the public tuple field.
#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize)]
pub struct UnknownExtension(pub Vec<u8>);
258
/// A list of [`Extension`]s that holds at most one extension per
/// [`ExtensionType`].
///
/// `add`, `TryFrom<Vec<Extension>>` (and therefore `from_vec`) and the TLS
/// deserialization in this module all reject a second extension of a type
/// that is already present.
///
/// NOTE(review): the derived serde `Deserialize` does not visibly go through
/// that duplicate check — confirm whether that path needs the same guarantee.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TlsSize)]
pub struct Extensions {
    // Extensions in insertion order; no two entries share an extension type.
    unique: Vec<Extension>,
}
264
265impl TlsSerializeTrait for Extensions {
266 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
267 self.unique.tls_serialize(writer)
268 }
269}
270
271impl TlsDeserializeTrait for Extensions {
272 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
273 where
274 Self: Sized,
275 {
276 let candidate: Vec<Extension> = Vec::tls_deserialize(bytes)?;
277 Extensions::try_from(candidate)
278 .map_err(|_| Error::DecodingError("Found duplicate extensions".into()))
279 }
280}
281
282impl DeserializeBytes for Extensions {
283 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
284 where
285 Self: Sized,
286 {
287 let mut bytes_ref = bytes;
288 let extensions = Extensions::tls_deserialize(&mut bytes_ref)?;
289 let remainder = &bytes[extensions.tls_serialized_len()..];
290 Ok((extensions, remainder))
291 }
292}
293
294impl Extensions {
295 pub fn empty() -> Self {
297 Self { unique: vec![] }
298 }
299
300 pub fn single(extension: Extension) -> Self {
302 Self {
303 unique: vec![extension],
304 }
305 }
306
307 pub fn from_vec(extensions: Vec<Extension>) -> Result<Self, InvalidExtensionError> {
312 extensions.try_into()
313 }
314
315 pub fn iter(&self) -> impl Iterator<Item = &Extension> {
317 self.unique.iter()
318 }
319
320 pub fn add(&mut self, extension: Extension) -> Result<(), InvalidExtensionError> {
325 if self.contains(extension.extension_type()) {
326 return Err(InvalidExtensionError::Duplicate);
327 }
328
329 self.unique.push(extension);
330
331 Ok(())
332 }
333
334 pub fn add_or_replace(&mut self, extension: Extension) -> Option<Extension> {
338 let replaced = self.remove(extension.extension_type());
339 self.unique.push(extension);
340 replaced
341 }
342
343 pub fn remove(&mut self, extension_type: ExtensionType) -> Option<Extension> {
348 if let Some(pos) = self
349 .unique
350 .iter()
351 .position(|ext| ext.extension_type() == extension_type)
352 {
353 Some(self.unique.remove(pos))
354 } else {
355 None
356 }
357 }
358
359 pub fn contains(&self, extension_type: ExtensionType) -> bool {
362 self.unique
363 .iter()
364 .any(|ext| ext.extension_type() == extension_type)
365 }
366
367 pub(crate) fn validate_extension_types_for_leaf_node(
370 &self,
371 ) -> Result<(), InvalidExtensionError> {
372 for extension_type in self.unique.iter().map(Extension::extension_type) {
373 if extension_type.is_valid_in_leaf_node() == Some(false) {
375 return Err(InvalidExtensionError::IllegalInLeafNodes);
376 }
377 }
378 Ok(())
379 }
380}
381
382impl TryFrom<Vec<Extension>> for Extensions {
383 type Error = InvalidExtensionError;
384
385 fn try_from(candidate: Vec<Extension>) -> Result<Self, Self::Error> {
386 let mut unique: Vec<Extension> = Vec::new();
387
388 for extension in candidate.into_iter() {
389 if unique
390 .iter()
391 .any(|ext| ext.extension_type() == extension.extension_type())
392 {
393 return Err(InvalidExtensionError::Duplicate);
394 } else {
395 unique.push(extension);
396 }
397 }
398
399 Ok(Self { unique })
400 }
401}
402
403impl Extensions {
404 fn find_by_type(&self, extension_type: ExtensionType) -> Option<&Extension> {
405 self.unique
406 .iter()
407 .find(|ext| ext.extension_type() == extension_type)
408 }
409
410 pub fn application_id(&self) -> Option<&ApplicationIdExtension> {
412 self.find_by_type(ExtensionType::ApplicationId)
413 .and_then(|e| match e {
414 Extension::ApplicationId(e) => Some(e),
415 _ => None,
416 })
417 }
418
419 pub fn ratchet_tree(&self) -> Option<&RatchetTreeExtension> {
421 self.find_by_type(ExtensionType::RatchetTree)
422 .and_then(|e| match e {
423 Extension::RatchetTree(e) => Some(e),
424 _ => None,
425 })
426 }
427
428 pub fn required_capabilities(&self) -> Option<&RequiredCapabilitiesExtension> {
431 self.find_by_type(ExtensionType::RequiredCapabilities)
432 .and_then(|e| match e {
433 Extension::RequiredCapabilities(e) => Some(e),
434 _ => None,
435 })
436 }
437
438 pub fn external_pub(&self) -> Option<&ExternalPubExtension> {
440 self.find_by_type(ExtensionType::ExternalPub)
441 .and_then(|e| match e {
442 Extension::ExternalPub(e) => Some(e),
443 _ => None,
444 })
445 }
446
447 pub fn external_senders(&self) -> Option<&ExternalSendersExtension> {
449 self.find_by_type(ExtensionType::ExternalSenders)
450 .and_then(|e| match e {
451 Extension::ExternalSenders(e) => Some(e),
452 _ => None,
453 })
454 }
455
456 pub fn unknown(&self, extension_type_id: u16) -> Option<&UnknownExtension> {
458 let extension_type: ExtensionType = extension_type_id.into();
459
460 match extension_type {
461 ExtensionType::Unknown(_) => self.find_by_type(extension_type).and_then(|e| match e {
462 Extension::Unknown(_, e) => Some(e),
463 _ => None,
464 }),
465 _ => None,
466 }
467 }
468}
469
470impl Extension {
471 pub fn as_application_id_extension(&self) -> Result<&ApplicationIdExtension, ExtensionError> {
475 match self {
476 Self::ApplicationId(e) => Ok(e),
477 _ => Err(ExtensionError::InvalidExtensionType(
478 "This is not an ApplicationIdExtension".into(),
479 )),
480 }
481 }
482
483 pub fn as_ratchet_tree_extension(&self) -> Result<&RatchetTreeExtension, ExtensionError> {
487 match self {
488 Self::RatchetTree(rte) => Ok(rte),
489 _ => Err(ExtensionError::InvalidExtensionType(
490 "This is not a RatchetTreeExtension".into(),
491 )),
492 }
493 }
494
495 pub fn as_required_capabilities_extension(
499 &self,
500 ) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
501 match self {
502 Self::RequiredCapabilities(e) => Ok(e),
503 _ => Err(ExtensionError::InvalidExtensionType(
504 "This is not a RequiredCapabilitiesExtension".into(),
505 )),
506 }
507 }
508
509 pub fn as_external_pub_extension(&self) -> Result<&ExternalPubExtension, ExtensionError> {
513 match self {
514 Self::ExternalPub(e) => Ok(e),
515 _ => Err(ExtensionError::InvalidExtensionType(
516 "This is not an ExternalPubExtension".into(),
517 )),
518 }
519 }
520
521 pub fn as_external_senders_extension(
525 &self,
526 ) -> Result<&ExternalSendersExtension, ExtensionError> {
527 match self {
528 Self::ExternalSenders(e) => Ok(e),
529 _ => Err(ExtensionError::InvalidExtensionType(
530 "This is not an ExternalSendersExtension".into(),
531 )),
532 }
533 }
534
535 #[inline]
537 pub const fn extension_type(&self) -> ExtensionType {
538 match self {
539 Extension::ApplicationId(_) => ExtensionType::ApplicationId,
540 Extension::RatchetTree(_) => ExtensionType::RatchetTree,
541 Extension::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
542 Extension::ExternalPub(_) => ExtensionType::ExternalPub,
543 Extension::ExternalSenders(_) => ExtensionType::ExternalSenders,
544 Extension::LastResort(_) => ExtensionType::LastResort,
545 Extension::Unknown(kind, _) => ExtensionType::Unknown(*kind),
546 }
547 }
548}
549
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use tls_codec::{Deserialize, Serialize, VLBytes};

    use crate::{ciphersuite::HpkePublicKey, extensions::*};

    /// Adding the same extension type twice must fail the second time.
    #[test]
    fn add() {
        let mut extensions = Extensions::default();

        let first = extensions.add(Extension::RequiredCapabilities(
            RequiredCapabilitiesExtension::default(),
        ));
        first.unwrap();

        let second = extensions.add(Extension::RequiredCapabilities(
            RequiredCapabilitiesExtension::default(),
        ));
        assert!(second.is_err());
    }

    /// `Extensions::add` and `Extensions::try_from` must agree on which
    /// combinations of extensions are accepted.
    #[test]
    fn add_try_from() {
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        // (input, whether the input is free of duplicate types)
        let tests = [
            (vec![], true),
            (vec![ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_y.clone()], true),
            (vec![ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_x.clone(), ext_y.clone()], true),
            (vec![ext_y.clone(), ext_x.clone()], true),
            (vec![ext_x.clone(), ext_x.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_y.clone()], false),
            (vec![ext_y.clone(), ext_x.clone(), ext_x.clone()], false),
            (vec![ext_x.clone(), ext_y.clone(), ext_x.clone()], false),
            (vec![ext_y.clone(), ext_x, ext_y], false),
        ];

        for (test, should_work) in tests.into_iter() {
            // Path 1: add the extensions one at a time.
            let mut extensions = Extensions::default();
            let mut works = true;
            for ext in &test {
                if let Err(error) = extensions.add(ext.clone()) {
                    match error {
                        InvalidExtensionError::Duplicate => works = false,
                        _ => panic!("This should have never happened."),
                    }
                }
            }

            println!("{:?}, {:?}", test, should_work);
            assert_eq!(works, should_work);

            // Path 2: convert the whole vector at once.
            assert_eq!(Extensions::try_from(test).is_ok(), should_work);
        }
    }

    /// Serializing and deserializing must keep the order of the extensions,
    /// whatever that order is.
    #[test]
    fn ensure_ordering() {
        let ext_x = Extension::ApplicationId(ApplicationIdExtension::new(b"Test"));
        let ext_y = Extension::ExternalPub(ExternalPubExtension::new(HpkePublicKey::new(vec![])));
        let ext_z = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());

        for permutation in [ext_x, ext_y, ext_z].into_iter().permutations(3) {
            let expected: Extensions = Extensions::try_from(permutation).unwrap();
            let bytes = expected.tls_serialize_detached().unwrap();
            let got = Extensions::tls_deserialize(&mut bytes.as_slice()).unwrap();
            assert_eq!(expected, got);
        }
    }

    /// Extensions with unsupported type codes must come back as
    /// `Extension::Unknown` and re-serialize to the original bytes.
    #[test]
    fn that_unknown_extensions_are_de_serialized_correctly() {
        let extension_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF100, 0xFFFF];
        let extension_datas = [vec![], vec![0], vec![1, 2, 3]];

        for extension_type in extension_types.into_iter() {
            for extension_data in extension_datas.iter() {
                // Wire format under test: type code followed by the
                // length-prefixed payload.
                let payload = VLBytes::new(extension_data.clone())
                    .tls_serialize_detached()
                    .unwrap();
                let mut test = extension_type.to_be_bytes().to_vec();
                test.extend_from_slice(&payload);

                let got = Extension::tls_deserialize_exact(&test).unwrap();

                match &got {
                    Extension::Unknown(got_extension_type, got_extension_data) => {
                        assert_eq!(extension_type, *got_extension_type);
                        assert_eq!(extension_data, &got_extension_data.0);
                    }
                    other => panic!("Expected `Extension::Unknown`, got {other:?}"),
                }

                let got_serialized = got.tls_serialize_detached().unwrap();
                assert_eq!(test, got_serialized);
            }
        }
    }
}