1use std::io::{Read, Write};
6
7use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tls_codec::{
11 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
12 Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
13};
14
15use crate::{
16 binary_tree::array_representation::LeafNodeIndex,
17 ciphersuite::hash_ref::{make_proposal_ref, KeyPackageRef, ProposalRef},
18 error::LibraryError,
19 extensions::Extensions,
20 framing::{
21 mls_auth_content::AuthenticatedContent, mls_content::FramedContentBody, ContentType,
22 },
23 group::GroupId,
24 key_packages::*,
25 prelude::LeafNode,
26 schedule::psk::*,
27 versions::ProtocolVersion,
28};
29
30#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug, Serialize, Deserialize, Hash)]
72#[allow(missing_docs)]
73pub enum ProposalType {
74 Add,
75 Update,
76 Remove,
77 PreSharedKey,
78 Reinit,
79 ExternalInit,
80 GroupContextExtensions,
81 AppAck,
82 SelfRemove,
83 Custom(u16),
84}
85
86impl ProposalType {
87 pub(crate) fn is_default(self) -> bool {
90 match self {
91 ProposalType::Add
92 | ProposalType::Update
93 | ProposalType::Remove
94 | ProposalType::PreSharedKey
95 | ProposalType::Reinit
96 | ProposalType::ExternalInit
97 | ProposalType::GroupContextExtensions => true,
98 ProposalType::SelfRemove | ProposalType::AppAck | ProposalType::Custom(_) => false,
99 }
100 }
101}
102
103impl Size for ProposalType {
104 fn tls_serialized_len(&self) -> usize {
105 2
106 }
107}
108
109impl TlsDeserializeTrait for ProposalType {
110 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
111 where
112 Self: Sized,
113 {
114 let mut proposal_type = [0u8; 2];
115 bytes.read_exact(&mut proposal_type)?;
116
117 Ok(ProposalType::from(u16::from_be_bytes(proposal_type)))
118 }
119}
120
121impl TlsSerializeTrait for ProposalType {
122 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
123 writer.write_all(&u16::from(*self).to_be_bytes())?;
124
125 Ok(2)
126 }
127}
128
129impl DeserializeBytes for ProposalType {
130 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
131 where
132 Self: Sized,
133 {
134 let mut bytes_ref = bytes;
135 let proposal_type = ProposalType::tls_deserialize(&mut bytes_ref)?;
136 let remainder = &bytes[proposal_type.tls_serialized_len()..];
137 Ok((proposal_type, remainder))
138 }
139}
140
141impl ProposalType {
142 pub fn is_path_required(&self) -> bool {
144 matches!(
145 self,
146 Self::Update
147 | Self::Remove
148 | Self::ExternalInit
149 | Self::GroupContextExtensions
150 | Self::SelfRemove
151 )
152 }
153}
154
155impl From<u16> for ProposalType {
156 fn from(value: u16) -> Self {
157 match value {
158 1 => ProposalType::Add,
159 2 => ProposalType::Update,
160 3 => ProposalType::Remove,
161 4 => ProposalType::PreSharedKey,
162 5 => ProposalType::Reinit,
163 6 => ProposalType::ExternalInit,
164 7 => ProposalType::GroupContextExtensions,
165 8 => ProposalType::AppAck,
166 0x000c => ProposalType::SelfRemove,
167 other => ProposalType::Custom(other),
168 }
169 }
170}
171
172impl From<ProposalType> for u16 {
173 fn from(value: ProposalType) -> Self {
174 match value {
175 ProposalType::Add => 1,
176 ProposalType::Update => 2,
177 ProposalType::Remove => 3,
178 ProposalType::PreSharedKey => 4,
179 ProposalType::Reinit => 5,
180 ProposalType::ExternalInit => 6,
181 ProposalType::GroupContextExtensions => 7,
182 ProposalType::AppAck => 8,
183 ProposalType::SelfRemove => 0x000c,
184 ProposalType::Custom(id) => id,
185 }
186 }
187}
188
189#[allow(clippy::large_enum_variant)]
209#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
210#[allow(missing_docs)]
211#[repr(u16)]
212pub enum Proposal {
213 Add(AddProposal),
214 Update(UpdateProposal),
215 Remove(RemoveProposal),
216 PreSharedKey(PreSharedKeyProposal),
217 ReInit(ReInitProposal),
218 ExternalInit(ExternalInitProposal),
219 GroupContextExtensions(GroupContextExtensionProposal),
220 AppAck(AppAckProposal),
224 SelfRemove,
226 Custom(CustomProposal),
227}
228
229impl Proposal {
230 pub fn proposal_type(&self) -> ProposalType {
232 match self {
233 Proposal::Add(_) => ProposalType::Add,
234 Proposal::Update(_) => ProposalType::Update,
235 Proposal::Remove(_) => ProposalType::Remove,
236 Proposal::PreSharedKey(_) => ProposalType::PreSharedKey,
237 Proposal::ReInit(_) => ProposalType::Reinit,
238 Proposal::ExternalInit(_) => ProposalType::ExternalInit,
239 Proposal::GroupContextExtensions(_) => ProposalType::GroupContextExtensions,
240 Proposal::AppAck(_) => ProposalType::AppAck,
241 Proposal::SelfRemove => ProposalType::SelfRemove,
242 Proposal::Custom(CustomProposal {
243 proposal_type,
244 payload: _,
245 }) => ProposalType::Custom(proposal_type.to_owned()),
246 }
247 }
248
249 pub(crate) fn is_type(&self, proposal_type: ProposalType) -> bool {
250 self.proposal_type() == proposal_type
251 }
252
253 pub fn is_path_required(&self) -> bool {
255 self.proposal_type().is_path_required()
256 }
257
258 pub(crate) fn has_lower_priority_than(&self, new_proposal: &Proposal) -> bool {
259 match (self, new_proposal) {
260 (Proposal::Update(_), _) => true,
262 (Proposal::Remove(_), Proposal::Update(_)) => false,
264 (Proposal::Remove(_), Proposal::Remove(_)) => true,
266 (_, Proposal::SelfRemove) => true,
268 _ => {
270 debug_assert!(false);
271 false
272 }
273 }
274 }
275}
276
277#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
289pub struct AddProposal {
290 pub(crate) key_package: KeyPackage,
291}
292
293impl AddProposal {
294 pub fn key_package(&self) -> &KeyPackage {
296 &self.key_package
297 }
298}
299
300#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
313pub struct UpdateProposal {
314 pub(crate) leaf_node: LeafNode,
315}
316
317impl UpdateProposal {
318 pub fn leaf_node(&self) -> &LeafNode {
320 &self.leaf_node
321 }
322}
323
324#[derive(
336 Debug,
337 PartialEq,
338 Eq,
339 Clone,
340 Serialize,
341 Deserialize,
342 TlsDeserialize,
343 TlsDeserializeBytes,
344 TlsSerialize,
345 TlsSize,
346)]
347pub struct RemoveProposal {
348 pub(crate) removed: LeafNodeIndex,
349}
350
351impl RemoveProposal {
352 pub fn removed(&self) -> LeafNodeIndex {
354 self.removed
355 }
356}
357
358#[derive(
370 Debug,
371 PartialEq,
372 Eq,
373 Clone,
374 Serialize,
375 Deserialize,
376 TlsDeserialize,
377 TlsDeserializeBytes,
378 TlsSerialize,
379 TlsSize,
380)]
381pub struct PreSharedKeyProposal {
382 psk: PreSharedKeyId,
383}
384
385impl PreSharedKeyProposal {
386 pub(crate) fn into_psk_id(self) -> PreSharedKeyId {
388 self.psk
389 }
390}
391
392impl PreSharedKeyProposal {
393 pub(crate) fn new(psk: PreSharedKeyId) -> Self {
395 Self { psk }
396 }
397}
398
399#[derive(
416 Debug,
417 PartialEq,
418 Eq,
419 Clone,
420 Serialize,
421 Deserialize,
422 TlsDeserialize,
423 TlsDeserializeBytes,
424 TlsSerialize,
425 TlsSize,
426)]
427pub struct ReInitProposal {
428 pub(crate) group_id: GroupId,
429 pub(crate) version: ProtocolVersion,
430 pub(crate) ciphersuite: Ciphersuite,
431 pub(crate) extensions: Extensions,
432}
433
434#[derive(
446 Debug,
447 PartialEq,
448 Eq,
449 Clone,
450 Serialize,
451 Deserialize,
452 TlsDeserialize,
453 TlsDeserializeBytes,
454 TlsSerialize,
455 TlsSize,
456)]
457pub struct ExternalInitProposal {
458 kem_output: VLBytes,
459}
460
461impl ExternalInitProposal {
462 pub(crate) fn kem_output(&self) -> &[u8] {
464 self.kem_output.as_slice()
465 }
466}
467
468impl From<Vec<u8>> for ExternalInitProposal {
469 fn from(kem_output: Vec<u8>) -> Self {
470 ExternalInitProposal {
471 kem_output: kem_output.into(),
472 }
473 }
474}
475
476#[derive(
480 Debug,
481 PartialEq,
482 Clone,
483 Serialize,
484 Deserialize,
485 TlsDeserialize,
486 TlsDeserializeBytes,
487 TlsSerialize,
488 TlsSize,
489)]
490pub struct AppAckProposal {
491 received_ranges: Vec<MessageRange>,
492}
493
494#[derive(
506 Debug,
507 PartialEq,
508 Eq,
509 Clone,
510 Serialize,
511 Deserialize,
512 TlsDeserialize,
513 TlsDeserializeBytes,
514 TlsSerialize,
515 TlsSize,
516)]
517pub struct GroupContextExtensionProposal {
518 extensions: Extensions,
519}
520
521impl GroupContextExtensionProposal {
522 pub(crate) fn new(extensions: Extensions) -> Self {
524 Self { extensions }
525 }
526
527 pub fn extensions(&self) -> &Extensions {
529 &self.extensions
530 }
531}
532
533#[derive(
556 PartialEq,
557 Clone,
558 Copy,
559 Debug,
560 TlsSerialize,
561 TlsDeserialize,
562 TlsDeserializeBytes,
563 TlsSize,
564 Serialize,
565 Deserialize,
566)]
567#[repr(u8)]
568pub enum ProposalOrRefType {
569 Proposal = 1,
571 Reference = 2,
573}
574
575#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
577#[repr(u8)]
578#[allow(missing_docs)]
579#[allow(clippy::large_enum_variant)]
580pub(crate) enum ProposalOrRef {
581 #[tls_codec(discriminant = 1)]
582 Proposal(Proposal),
583 Reference(ProposalRef),
584}
585
586#[derive(Error, Debug)]
587pub(crate) enum ProposalRefError {
588 #[error("Expected `Proposal`, got `{wrong:?}`.")]
589 AuthenticatedContentHasWrongType { wrong: ContentType },
590 #[error(transparent)]
591 Other(#[from] LibraryError),
592}
593
594impl ProposalRef {
595 pub(crate) fn from_authenticated_content_by_ref(
596 crypto: &impl OpenMlsCrypto,
597 ciphersuite: Ciphersuite,
598 authenticated_content: &AuthenticatedContent,
599 ) -> Result<Self, ProposalRefError> {
600 if !matches!(
601 authenticated_content.content(),
602 FramedContentBody::Proposal(_)
603 ) {
604 return Err(ProposalRefError::AuthenticatedContentHasWrongType {
605 wrong: authenticated_content.content().content_type(),
606 });
607 };
608
609 let encoded = authenticated_content
610 .tls_serialize_detached()
611 .map_err(|error| ProposalRefError::Other(LibraryError::missing_bound_check(error)))?;
612
613 make_proposal_ref(&encoded, ciphersuite, crypto)
614 .map_err(|error| ProposalRefError::Other(LibraryError::unexpected_crypto_error(error)))
615 }
616
617 pub(crate) fn from_raw_proposal(
623 ciphersuite: Ciphersuite,
624 crypto: &impl OpenMlsCrypto,
625 proposal: &Proposal,
626 ) -> Result<Self, LibraryError> {
627 let mut data = b"Internal OpenMLS ProposalRef Label".to_vec();
629
630 let mut encoded = proposal
631 .tls_serialize_detached()
632 .map_err(LibraryError::missing_bound_check)?;
633
634 data.append(&mut encoded);
635
636 make_proposal_ref(&data, ciphersuite, crypto).map_err(LibraryError::unexpected_crypto_error)
637 }
638}
639
640#[derive(
648 Debug,
649 PartialEq,
650 Clone,
651 Serialize,
652 Deserialize,
653 TlsDeserialize,
654 TlsDeserializeBytes,
655 TlsSerialize,
656 TlsSize,
657)]
658pub(crate) struct MessageRange {
659 sender: KeyPackageRef,
660 first_generation: u32,
661 last_generation: u32,
662}
663
664#[derive(
666 Debug,
667 PartialEq,
668 Clone,
669 Serialize,
670 Deserialize,
671 TlsSize,
672 TlsSerialize,
673 TlsDeserialize,
674 TlsDeserializeBytes,
675)]
676pub struct CustomProposal {
677 proposal_type: u16,
678 payload: Vec<u8>,
679}
680
681impl CustomProposal {
682 pub fn new(proposal_type: u16, payload: Vec<u8>) -> Self {
684 Self {
685 proposal_type,
686 payload,
687 }
688 }
689
690 pub fn proposal_type(&self) -> u16 {
692 self.proposal_type
693 }
694
695 pub fn payload(&self) -> &[u8] {
697 &self.payload
698 }
699}
700
#[cfg(test)]
mod tests {
    use tls_codec::{Deserialize, Serialize};

    use super::ProposalType;

    /// Identifiers without a dedicated variant must deserialize to
    /// `ProposalType::Custom` holding the same value, and must serialize back
    /// to the exact bytes they were read from.
    #[test]
    fn that_unknown_proposal_types_are_de_serialized_correctly() {
        let proposal_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF000, 0xFFFF];

        for proposal_type in proposal_types {
            // Construct an unknown proposal type.
            let test = proposal_type.to_be_bytes().to_vec();

            // Test deserialization.
            let got = ProposalType::tls_deserialize_exact(&test).unwrap();

            match got {
                ProposalType::Custom(got_proposal_type) => {
                    assert_eq!(proposal_type, got_proposal_type);
                }
                // The variant for unassigned identifiers is `Custom`; name it
                // correctly so a failure points at a variant that exists.
                other => panic!("Expected `ProposalType::Custom`, got `{:?}`.", other),
            }

            // Test serialization.
            let got_serialized = got.tls_serialize_detached().unwrap();
            assert_eq!(test, got_serialized);
        }
    }
}