1use std::io::{Read, Write};
6
7use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tls_codec::{
11 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
12 Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
13};
14
15use crate::{
16 binary_tree::array_representation::LeafNodeIndex,
17 ciphersuite::hash_ref::{make_proposal_ref, KeyPackageRef, ProposalRef},
18 error::LibraryError,
19 extensions::Extensions,
20 framing::{
21 mls_auth_content::AuthenticatedContent, mls_content::FramedContentBody, ContentType,
22 },
23 group::GroupId,
24 key_packages::*,
25 prelude::LeafNode,
26 schedule::psk::*,
27 versions::ProtocolVersion,
28};
29
30#[cfg(feature = "extensions-draft-08")]
31use crate::component::ComponentId;
32
/// The type of an MLS proposal.
///
/// On the wire this is a two-byte, big-endian `u16` (see the `Size`,
/// `tls_codec` serialize/deserialize impls and the `From` conversions in this
/// file for the exact value mapping). Wire values that are not one of the
/// named variants are kept as [`ProposalType::Grease`] or
/// [`ProposalType::Custom`], so values read from the wire re-encode unchanged.
///
/// NOTE: variant order determines the derived `PartialOrd`/`Ord` and the
/// serde variant indices; do not reorder.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub enum ProposalType {
    Add,
    Update,
    Remove,
    PreSharedKey,
    Reinit,
    ExternalInit,
    GroupContextExtensions,
    SelfRemove,
    /// Only present when the `extensions-draft-08` feature is enabled.
    #[cfg(feature = "extensions-draft-08")]
    AppEphemeral,
    /// A value classified as GREASE by `crate::grease::is_grease_value`.
    Grease(u16),
    /// Any other value; carries the raw wire value.
    Custom(u16),
}
91
92impl ProposalType {
93 pub(crate) fn is_default(self) -> bool {
96 match self {
97 ProposalType::Add
98 | ProposalType::Update
99 | ProposalType::Remove
100 | ProposalType::PreSharedKey
101 | ProposalType::Reinit
102 | ProposalType::ExternalInit
103 | ProposalType::GroupContextExtensions => true,
104 ProposalType::SelfRemove | ProposalType::Grease(_) | ProposalType::Custom(_) => false,
105 #[cfg(feature = "extensions-draft-08")]
106 ProposalType::AppEphemeral => false,
107 }
108 }
109
110 pub fn is_grease(&self) -> bool {
115 matches!(self, ProposalType::Grease(_))
116 }
117}
118
119impl Size for ProposalType {
120 fn tls_serialized_len(&self) -> usize {
121 2
122 }
123}
124
125impl TlsDeserializeTrait for ProposalType {
126 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
127 where
128 Self: Sized,
129 {
130 let mut proposal_type = [0u8; 2];
131 bytes.read_exact(&mut proposal_type)?;
132
133 Ok(ProposalType::from(u16::from_be_bytes(proposal_type)))
134 }
135}
136
137impl TlsSerializeTrait for ProposalType {
138 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
139 writer.write_all(&u16::from(*self).to_be_bytes())?;
140
141 Ok(2)
142 }
143}
144
145impl DeserializeBytes for ProposalType {
146 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
147 where
148 Self: Sized,
149 {
150 let mut bytes_ref = bytes;
151 let proposal_type = ProposalType::tls_deserialize(&mut bytes_ref)?;
152 let remainder = &bytes[proposal_type.tls_serialized_len()..];
153 Ok((proposal_type, remainder))
154 }
155}
156
157impl ProposalType {
158 pub fn is_path_required(&self) -> bool {
160 matches!(
161 self,
162 Self::Update
163 | Self::Remove
164 | Self::ExternalInit
165 | Self::GroupContextExtensions
166 | Self::SelfRemove
167 )
168 }
169}
170
171impl From<u16> for ProposalType {
172 fn from(value: u16) -> Self {
173 match value {
174 1 => ProposalType::Add,
175 2 => ProposalType::Update,
176 3 => ProposalType::Remove,
177 4 => ProposalType::PreSharedKey,
178 5 => ProposalType::Reinit,
179 6 => ProposalType::ExternalInit,
180 7 => ProposalType::GroupContextExtensions,
181 #[cfg(feature = "extensions-draft-08")]
182 0x0009 => ProposalType::AppEphemeral,
183 0x000a => ProposalType::SelfRemove,
184 other if crate::grease::is_grease_value(other) => ProposalType::Grease(other),
185 other => ProposalType::Custom(other),
186 }
187 }
188}
189
190impl From<ProposalType> for u16 {
191 fn from(value: ProposalType) -> Self {
192 match value {
193 ProposalType::Add => 1,
194 ProposalType::Update => 2,
195 ProposalType::Remove => 3,
196 ProposalType::PreSharedKey => 4,
197 ProposalType::Reinit => 5,
198 ProposalType::ExternalInit => 6,
199 ProposalType::GroupContextExtensions => 7,
200 #[cfg(feature = "extensions-draft-08")]
201 ProposalType::AppEphemeral => 0x0009,
202 ProposalType::SelfRemove => 0x000a,
203 ProposalType::Grease(id) => id,
204 ProposalType::Custom(id) => id,
205 }
206 }
207}
208
/// An MLS proposal.
///
/// Every variant with a payload stores it in a `Box`, presumably to keep
/// `Proposal` itself small whatever the largest payload is; `SelfRemove`
/// carries no payload. [`Proposal::proposal_type`] maps each variant to its
/// [`ProposalType`].
///
/// NOTE: variant order is significant for the derived serde implementation;
/// do not reorder.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
#[repr(u16)]
pub enum Proposal {
    Add(Box<AddProposal>),
    Update(Box<UpdateProposal>),
    Remove(Box<RemoveProposal>),
    PreSharedKey(Box<PreSharedKeyProposal>),
    ReInit(Box<ReInitProposal>),
    ExternalInit(Box<ExternalInitProposal>),
    GroupContextExtensions(Box<GroupContextExtensionProposal>),
    SelfRemove,
    /// Only present when the `extensions-draft-08` feature is enabled.
    #[cfg(feature = "extensions-draft-08")]
    AppEphemeral(Box<AppEphemeralProposal>),
    Custom(Box<CustomProposal>),
}
246
247impl Proposal {
248 pub(crate) fn remove(r: RemoveProposal) -> Self {
250 Self::Remove(Box::new(r))
251 }
252
253 pub(crate) fn add(a: AddProposal) -> Self {
255 Self::Add(Box::new(a))
256 }
257
258 pub(crate) fn custom(c: CustomProposal) -> Self {
260 Self::Custom(Box::new(c))
261 }
262
263 pub(crate) fn psk(p: PreSharedKeyProposal) -> Self {
265 Self::PreSharedKey(Box::new(p))
266 }
267
268 pub(crate) fn update(p: UpdateProposal) -> Self {
270 Self::Update(Box::new(p))
271 }
272
273 pub(crate) fn group_context_extensions(p: GroupContextExtensionProposal) -> Self {
275 Self::GroupContextExtensions(Box::new(p))
276 }
277
278 pub(crate) fn external_init(p: ExternalInitProposal) -> Self {
280 Self::ExternalInit(Box::new(p))
281 }
282
283 #[cfg(test)]
284 pub(crate) fn re_init(p: ReInitProposal) -> Self {
286 Self::ReInit(Box::new(p))
287 }
288
289 pub fn proposal_type(&self) -> ProposalType {
291 match self {
292 Proposal::Add(_) => ProposalType::Add,
293 Proposal::Update(_) => ProposalType::Update,
294 Proposal::Remove(_) => ProposalType::Remove,
295 Proposal::PreSharedKey(_) => ProposalType::PreSharedKey,
296 Proposal::ReInit(_) => ProposalType::Reinit,
297 Proposal::ExternalInit(_) => ProposalType::ExternalInit,
298 Proposal::GroupContextExtensions(_) => ProposalType::GroupContextExtensions,
299 Proposal::SelfRemove => ProposalType::SelfRemove,
300 #[cfg(feature = "extensions-draft-08")]
301 Proposal::AppEphemeral(_) => ProposalType::AppEphemeral,
302 Proposal::Custom(custom) => ProposalType::Custom(custom.proposal_type.to_owned()),
303 }
304 }
305
306 pub(crate) fn is_type(&self, proposal_type: ProposalType) -> bool {
307 self.proposal_type() == proposal_type
308 }
309
310 pub fn is_path_required(&self) -> bool {
312 self.proposal_type().is_path_required()
313 }
314
315 pub(crate) fn has_lower_priority_than(&self, new_proposal: &Proposal) -> bool {
316 match (self, new_proposal) {
317 (Proposal::Update(_), _) => true,
319 (Proposal::Remove(_), Proposal::Update(_)) => false,
321 (Proposal::Remove(_), Proposal::Remove(_)) => true,
323 (_, Proposal::SelfRemove) => true,
325 _ => {
327 debug_assert!(false);
328 false
329 }
330 }
331 }
332
333 pub(crate) fn as_remove(&self) -> Option<&RemoveProposal> {
335 if let Self::Remove(v) = self {
336 Some(v)
337 } else {
338 None
339 }
340 }
341
342 #[must_use]
346 pub fn is_remove(&self) -> bool {
347 matches!(self, Self::Remove(..))
348 }
349}
350
/// An Add proposal, carrying the [`KeyPackage`] of the member to be added.
///
/// Only the serialization side of `tls_codec` (`TlsSerialize`, `TlsSize`) is
/// derived here. NOTE(review): deserialization is presumably implemented
/// elsewhere; confirm.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct AddProposal {
    /// The key package of the member to add.
    pub(crate) key_package: KeyPackage,
}
366
367impl AddProposal {
368 pub fn key_package(&self) -> &KeyPackage {
370 &self.key_package
371 }
372}
373
/// An Update proposal, carrying the sender's replacement [`LeafNode`].
///
/// Only the serialization side of `tls_codec` (`TlsSerialize`, `TlsSize`) is
/// derived here. NOTE(review): deserialization is presumably implemented
/// elsewhere; confirm.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct UpdateProposal {
    /// The new leaf node proposed by this Update.
    pub(crate) leaf_node: LeafNode,
}
390
391impl UpdateProposal {
392 pub fn leaf_node(&self) -> &LeafNode {
394 &self.leaf_node
395 }
396}
397
/// A Remove proposal, identifying the member to remove by leaf index.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct RemoveProposal {
    /// Leaf index of the member to be removed.
    pub(crate) removed: LeafNodeIndex,
}
424
425impl RemoveProposal {
426 pub fn removed(&self) -> LeafNodeIndex {
428 self.removed
429 }
430}
431
/// A PreSharedKey proposal, referencing a PSK by its [`PreSharedKeyId`].
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct PreSharedKeyProposal {
    /// Identifier of the pre-shared key to inject.
    psk: PreSharedKeyId,
}
458
459impl PreSharedKeyProposal {
460 pub(crate) fn into_psk_id(self) -> PreSharedKeyId {
462 self.psk
463 }
464}
465
466impl PreSharedKeyProposal {
467 pub fn new(psk: PreSharedKeyId) -> Self {
469 Self { psk }
470 }
471}
472
/// A ReInit proposal, describing the parameters of the group that should
/// replace the current one.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct ReInitProposal {
    /// Group ID of the new group.
    pub(crate) group_id: GroupId,
    /// Protocol version of the new group.
    pub(crate) version: ProtocolVersion,
    /// Ciphersuite of the new group.
    pub(crate) ciphersuite: Ciphersuite,
    /// Extensions of the new group.
    pub(crate) extensions: Extensions,
}
507
/// An ExternalInit proposal, carrying the KEM output produced by a party
/// joining via an external commit.
///
/// The KEM output is stored as variable-length bytes (`VLBytes`).
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct ExternalInitProposal {
    /// Raw KEM output bytes.
    kem_output: VLBytes,
}
534
535impl ExternalInitProposal {
536 pub(crate) fn kem_output(&self) -> &[u8] {
538 self.kem_output.as_slice()
539 }
540}
541
542impl From<Vec<u8>> for ExternalInitProposal {
543 fn from(kem_output: Vec<u8>) -> Self {
544 ExternalInitProposal {
545 kem_output: kem_output.into(),
546 }
547 }
548}
549
/// An application acknowledgement, listing ranges of received messages.
///
/// Only compiled with the `extensions-draft-08` feature.
#[cfg(feature = "extensions-draft-08")]
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct AppAck {
    /// The acknowledged message ranges.
    received_ranges: Vec<MessageRange>,
}
568
/// An AppEphemeral proposal, carrying opaque application data addressed to a
/// component.
///
/// Only compiled with the `extensions-draft-08` feature.
#[cfg(feature = "extensions-draft-08")]
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct AppEphemeralProposal {
    /// The component the data is intended for. NOTE(review): semantics are
    /// defined by the extensions draft; confirm.
    component_id: ComponentId,
    /// Opaque application payload, stored as variable-length bytes.
    data: VLBytes,
}
#[cfg(feature = "extensions-draft-08")]
impl AppEphemeralProposal {
    /// Creates a proposal carrying `data` for the component `component_id`.
    pub fn new(component_id: ComponentId, data: Vec<u8>) -> Self {
        let data = VLBytes::from(data);
        Self { component_id, data }
    }

    /// Returns the component ID this proposal is addressed to.
    pub fn component_id(&self) -> ComponentId {
        let Self { component_id, .. } = self;
        *component_id
    }

    /// Returns the raw payload bytes.
    pub fn data(&self) -> &[u8] {
        let Self { data, .. } = self;
        data.as_slice()
    }
}
607
/// A GroupContextExtensions proposal, carrying the extensions proposed for
/// the group context.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct GroupContextExtensionProposal {
    /// The proposed group context extensions.
    extensions: Extensions,
}
634
635impl GroupContextExtensionProposal {
636 pub(crate) fn new(extensions: Extensions) -> Self {
638 Self { extensions }
639 }
640
641 pub fn extensions(&self) -> &Extensions {
643 &self.extensions
644 }
645}
646
/// Discriminant that tells whether a proposal is included inline or by
/// reference.
///
/// Uses `#[repr(u8)]` with explicit discriminants `1` and `2`.
#[derive(
    PartialEq,
    Clone,
    Copy,
    Debug,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSize,
    Serialize,
    Deserialize,
)]
#[repr(u8)]
pub enum ProposalOrRefType {
    /// The proposal is included inline.
    Proposal = 1,
    /// The proposal is included by reference.
    Reference = 2,
}
688
/// A proposal given either inline ([`ProposalOrRef::Proposal`]) or by
/// reference ([`ProposalOrRef::Reference`]).
///
/// The TLS discriminant of `Proposal` is pinned to `1`. NOTE(review):
/// `Reference` presumably follows as `2`, matching [`ProposalOrRefType`];
/// confirm against the `tls_codec` derive's discriminant rules.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
#[repr(u8)]
#[allow(missing_docs)]
pub(crate) enum ProposalOrRef {
    #[tls_codec(discriminant = 1)]
    Proposal(Box<Proposal>),
    Reference(Box<ProposalRef>),
}
698
699impl ProposalOrRef {
700 pub(crate) fn proposal(p: Proposal) -> Self {
702 Self::Proposal(Box::new(p))
703 }
704
705 pub(crate) fn reference(p: ProposalRef) -> Self {
707 Self::Reference(Box::new(p))
708 }
709
710 pub(crate) fn as_proposal(&self) -> Option<&Proposal> {
711 if let Self::Proposal(v) = self {
712 Some(v)
713 } else {
714 None
715 }
716 }
717
718 pub(crate) fn as_reference(&self) -> Option<&ProposalRef> {
719 if let Self::Reference(v) = self {
720 Some(v)
721 } else {
722 None
723 }
724 }
725}
726
727impl From<Proposal> for ProposalOrRef {
728 fn from(value: Proposal) -> Self {
729 Self::proposal(value)
730 }
731}
732
733impl From<ProposalRef> for ProposalOrRef {
734 fn from(value: ProposalRef) -> Self {
735 Self::reference(value)
736 }
737}
738
/// Errors returned when computing a [`ProposalRef`] from authenticated content.
#[derive(Error, Debug)]
pub(crate) enum ProposalRefError {
    /// The authenticated content was not a proposal.
    #[error("Expected `Proposal`, got `{wrong:?}`.")]
    AuthenticatedContentHasWrongType {
        /// The content type that was found instead.
        wrong: ContentType,
    },
    /// An internal library error (serialization or crypto failure).
    #[error(transparent)]
    Other(#[from] LibraryError),
}
746
747impl ProposalRef {
748 pub(crate) fn from_authenticated_content_by_ref(
749 crypto: &impl OpenMlsCrypto,
750 ciphersuite: Ciphersuite,
751 authenticated_content: &AuthenticatedContent,
752 ) -> Result<Self, ProposalRefError> {
753 if !matches!(
754 authenticated_content.content(),
755 FramedContentBody::Proposal(_)
756 ) {
757 return Err(ProposalRefError::AuthenticatedContentHasWrongType {
758 wrong: authenticated_content.content().content_type(),
759 });
760 };
761
762 let encoded = authenticated_content
763 .tls_serialize_detached()
764 .map_err(|error| ProposalRefError::Other(LibraryError::missing_bound_check(error)))?;
765
766 make_proposal_ref(&encoded, ciphersuite, crypto)
767 .map_err(|error| ProposalRefError::Other(LibraryError::unexpected_crypto_error(error)))
768 }
769
770 pub(crate) fn from_raw_proposal(
776 ciphersuite: Ciphersuite,
777 crypto: &impl OpenMlsCrypto,
778 proposal: &Proposal,
779 ) -> Result<Self, LibraryError> {
780 let mut data = b"Internal OpenMLS ProposalRef Label".to_vec();
782
783 let mut encoded = proposal
784 .tls_serialize_detached()
785 .map_err(LibraryError::missing_bound_check)?;
786
787 data.append(&mut encoded);
788
789 make_proposal_ref(&data, ciphersuite, crypto).map_err(LibraryError::unexpected_crypto_error)
790 }
791}
792
/// A range of message generations from one sender, used in [`AppAck`].
///
/// NOTE(review): presumably an inclusive range
/// `first_generation..=last_generation`; confirm.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub(crate) struct MessageRange {
    /// The sender, identified by key package reference.
    sender: KeyPackageRef,
    /// First generation in the range.
    first_generation: u32,
    /// Last generation in the range.
    last_generation: u32,
}
816
/// A proposal of an application-defined type.
///
/// Holds the raw `u16` type code and an opaque payload. [`Proposal::proposal_type`]
/// reports it as [`ProposalType::Custom`].
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsSize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
)]
pub struct CustomProposal {
    /// Raw proposal type code.
    proposal_type: u16,
    /// Opaque, application-defined payload.
    payload: Vec<u8>,
}
833
834impl CustomProposal {
835 pub fn new(proposal_type: u16, payload: Vec<u8>) -> Self {
837 Self {
838 proposal_type,
839 payload,
840 }
841 }
842
843 pub fn proposal_type(&self) -> u16 {
845 self.proposal_type
846 }
847
848 pub fn payload(&self) -> &[u8] {
850 &self.payload
851 }
852}
853
#[cfg(test)]
mod tests {
    use tls_codec::{Deserialize, Serialize};

    use super::ProposalType;

    /// Values that are neither known proposal types nor GREASE must decode to
    /// `ProposalType::Custom` and re-encode to the same bytes.
    #[test]
    fn that_unknown_proposal_types_are_de_serialized_correctly() {
        let proposal_types = [0x0000u16, 0x0B0B, 0x7C7C, 0xF000, 0xFFFF];

        for proposal_type in proposal_types {
            let test = proposal_type.to_be_bytes().to_vec();

            let got = ProposalType::tls_deserialize_exact(&test).unwrap();

            match got {
                ProposalType::Custom(got_proposal_type) => {
                    assert_eq!(proposal_type, got_proposal_type);
                }
                other => panic!("Expected `ProposalType::Custom`, got `{other:?}`."),
            }

            let got_serialized = got.tls_serialize_detached().unwrap();
            assert_eq!(test, got_serialized);
        }
    }

    /// Known wire values must decode to their named variants, re-encode
    /// unchanged, and be classified by `is_default` as expected.
    #[test]
    fn that_known_proposal_types_round_trip() {
        let known = [
            (0x0001u16, ProposalType::Add),
            (0x0002, ProposalType::Update),
            (0x0003, ProposalType::Remove),
            (0x0004, ProposalType::PreSharedKey),
            (0x0005, ProposalType::Reinit),
            (0x0006, ProposalType::ExternalInit),
            (0x0007, ProposalType::GroupContextExtensions),
            (0x000a, ProposalType::SelfRemove),
        ];

        for (value, expected) in known {
            let encoded = value.to_be_bytes().to_vec();

            let got = ProposalType::tls_deserialize_exact(&encoded).unwrap();
            assert_eq!(got, expected);
            assert_eq!(got.is_default(), value != 0x000a);

            let reencoded = got.tls_serialize_detached().unwrap();
            assert_eq!(reencoded, encoded);
        }
    }
}