1use std::io::{Read, Write};
6
7use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tls_codec::{
11 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
12 Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
13};
14
15use crate::{
16 binary_tree::array_representation::LeafNodeIndex,
17 ciphersuite::hash_ref::{make_proposal_ref, KeyPackageRef, ProposalRef},
18 error::LibraryError,
19 extensions::Extensions,
20 framing::{
21 mls_auth_content::AuthenticatedContent, mls_content::FramedContentBody, ContentType,
22 },
23 group::GroupId,
24 key_packages::*,
25 prelude::LeafNode,
26 schedule::psk::*,
27 versions::ProtocolVersion,
28};
29
30#[cfg(feature = "extensions-draft-08")]
31use crate::extensions::ComponentId;
32
/// The type of a proposal, identified on the wire by a two-byte type code.
///
/// The code mapping is defined by the `From<u16>` / `From<ProposalType> for u16`
/// conversions in this module:
/// - codes 1–7 are the base proposal types;
/// - `0x000a` is [`ProposalType::SelfRemove`];
/// - `0x0009` is [`ProposalType::AppEphemeral`], only with the `extensions-draft-08` feature;
/// - every other code is kept verbatim in [`ProposalType::Custom`].
///
/// NOTE: the variant declaration order defines the derived `Ord`/`PartialOrd` and the
/// serde variant indices, so variants must not be reordered.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub enum ProposalType {
    /// Wire code `1`.
    Add,
    /// Wire code `2`.
    Update,
    /// Wire code `3`.
    Remove,
    /// Wire code `4`.
    PreSharedKey,
    /// Wire code `5`.
    Reinit,
    /// Wire code `6`.
    ExternalInit,
    /// Wire code `7`.
    GroupContextExtensions,
    /// Wire code `0x000a`.
    SelfRemove,
    /// Wire code `0x0009` (only with the `extensions-draft-08` feature).
    #[cfg(feature = "extensions-draft-08")]
    AppEphemeral,
    /// Any code without a dedicated variant, stored verbatim.
    Custom(u16),
}
90
91impl ProposalType {
92 pub(crate) fn is_default(self) -> bool {
95 match self {
96 ProposalType::Add
97 | ProposalType::Update
98 | ProposalType::Remove
99 | ProposalType::PreSharedKey
100 | ProposalType::Reinit
101 | ProposalType::ExternalInit
102 | ProposalType::GroupContextExtensions => true,
103 ProposalType::SelfRemove | ProposalType::Custom(_) => false,
104 #[cfg(feature = "extensions-draft-08")]
105 ProposalType::AppEphemeral => false,
106 }
107 }
108}
109
110impl Size for ProposalType {
111 fn tls_serialized_len(&self) -> usize {
112 2
113 }
114}
115
116impl TlsDeserializeTrait for ProposalType {
117 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
118 where
119 Self: Sized,
120 {
121 let mut proposal_type = [0u8; 2];
122 bytes.read_exact(&mut proposal_type)?;
123
124 Ok(ProposalType::from(u16::from_be_bytes(proposal_type)))
125 }
126}
127
128impl TlsSerializeTrait for ProposalType {
129 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
130 writer.write_all(&u16::from(*self).to_be_bytes())?;
131
132 Ok(2)
133 }
134}
135
136impl DeserializeBytes for ProposalType {
137 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
138 where
139 Self: Sized,
140 {
141 let mut bytes_ref = bytes;
142 let proposal_type = ProposalType::tls_deserialize(&mut bytes_ref)?;
143 let remainder = &bytes[proposal_type.tls_serialized_len()..];
144 Ok((proposal_type, remainder))
145 }
146}
147
148impl ProposalType {
149 pub fn is_path_required(&self) -> bool {
151 matches!(
152 self,
153 Self::Update
154 | Self::Remove
155 | Self::ExternalInit
156 | Self::GroupContextExtensions
157 | Self::SelfRemove
158 )
159 }
160}
161
162impl From<u16> for ProposalType {
163 fn from(value: u16) -> Self {
164 match value {
165 1 => ProposalType::Add,
166 2 => ProposalType::Update,
167 3 => ProposalType::Remove,
168 4 => ProposalType::PreSharedKey,
169 5 => ProposalType::Reinit,
170 6 => ProposalType::ExternalInit,
171 7 => ProposalType::GroupContextExtensions,
172 #[cfg(feature = "extensions-draft-08")]
173 0x0009 => ProposalType::AppEphemeral,
174 0x000a => ProposalType::SelfRemove,
175 other => ProposalType::Custom(other),
176 }
177 }
178}
179
180impl From<ProposalType> for u16 {
181 fn from(value: ProposalType) -> Self {
182 match value {
183 ProposalType::Add => 1,
184 ProposalType::Update => 2,
185 ProposalType::Remove => 3,
186 ProposalType::PreSharedKey => 4,
187 ProposalType::Reinit => 5,
188 ProposalType::ExternalInit => 6,
189 ProposalType::GroupContextExtensions => 7,
190 #[cfg(feature = "extensions-draft-08")]
191 ProposalType::AppEphemeral => 0x0009,
192 ProposalType::SelfRemove => 0x000a,
193 ProposalType::Custom(id) => id,
194 }
195 }
196}
197
/// A single proposal, one variant per [`ProposalType`].
///
/// Payloads are boxed. [`Proposal::proposal_type`] maps each variant back to its
/// [`ProposalType`].
///
/// NOTE: variant order is significant for the serde derive; do not reorder.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
#[repr(u16)]
pub enum Proposal {
    Add(Box<AddProposal>),
    Update(Box<UpdateProposal>),
    Remove(Box<RemoveProposal>),
    PreSharedKey(Box<PreSharedKeyProposal>),
    ReInit(Box<ReInitProposal>),
    ExternalInit(Box<ExternalInitProposal>),
    GroupContextExtensions(Box<GroupContextExtensionProposal>),
    // Carries no payload.
    SelfRemove,
    #[cfg(feature = "extensions-draft-08")]
    AppEphemeral(Box<AppEphemeralProposal>),
    Custom(Box<CustomProposal>),
}
235
236impl Proposal {
237 pub(crate) fn remove(r: RemoveProposal) -> Self {
239 Self::Remove(Box::new(r))
240 }
241
242 pub(crate) fn add(a: AddProposal) -> Self {
244 Self::Add(Box::new(a))
245 }
246
247 pub(crate) fn custom(c: CustomProposal) -> Self {
249 Self::Custom(Box::new(c))
250 }
251
252 pub(crate) fn psk(p: PreSharedKeyProposal) -> Self {
254 Self::PreSharedKey(Box::new(p))
255 }
256
257 pub(crate) fn update(p: UpdateProposal) -> Self {
259 Self::Update(Box::new(p))
260 }
261
262 pub(crate) fn group_context_extensions(p: GroupContextExtensionProposal) -> Self {
264 Self::GroupContextExtensions(Box::new(p))
265 }
266
267 pub(crate) fn external_init(p: ExternalInitProposal) -> Self {
269 Self::ExternalInit(Box::new(p))
270 }
271
272 #[cfg(test)]
273 pub(crate) fn re_init(p: ReInitProposal) -> Self {
275 Self::ReInit(Box::new(p))
276 }
277
278 pub fn proposal_type(&self) -> ProposalType {
280 match self {
281 Proposal::Add(_) => ProposalType::Add,
282 Proposal::Update(_) => ProposalType::Update,
283 Proposal::Remove(_) => ProposalType::Remove,
284 Proposal::PreSharedKey(_) => ProposalType::PreSharedKey,
285 Proposal::ReInit(_) => ProposalType::Reinit,
286 Proposal::ExternalInit(_) => ProposalType::ExternalInit,
287 Proposal::GroupContextExtensions(_) => ProposalType::GroupContextExtensions,
288 Proposal::SelfRemove => ProposalType::SelfRemove,
289 #[cfg(feature = "extensions-draft-08")]
290 Proposal::AppEphemeral(_) => ProposalType::AppEphemeral,
291 Proposal::Custom(custom) => ProposalType::Custom(custom.proposal_type.to_owned()),
292 }
293 }
294
295 pub(crate) fn is_type(&self, proposal_type: ProposalType) -> bool {
296 self.proposal_type() == proposal_type
297 }
298
299 pub fn is_path_required(&self) -> bool {
301 self.proposal_type().is_path_required()
302 }
303
304 pub(crate) fn has_lower_priority_than(&self, new_proposal: &Proposal) -> bool {
305 match (self, new_proposal) {
306 (Proposal::Update(_), _) => true,
308 (Proposal::Remove(_), Proposal::Update(_)) => false,
310 (Proposal::Remove(_), Proposal::Remove(_)) => true,
312 (_, Proposal::SelfRemove) => true,
314 _ => {
316 debug_assert!(false);
317 false
318 }
319 }
320 }
321
322 pub(crate) fn as_remove(&self) -> Option<&RemoveProposal> {
324 if let Self::Remove(v) = self {
325 Some(v)
326 } else {
327 None
328 }
329 }
330
331 #[must_use]
335 pub fn is_remove(&self) -> bool {
336 matches!(self, Self::Remove(..))
337 }
338}
339
/// An `Add` proposal: proposes adding the member described by `key_package`.
///
/// Only TLS serialization is derived here; decoding is presumably implemented
/// elsewhere (not visible in this file).
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct AddProposal {
    pub(crate) key_package: KeyPackage,
}
355
356impl AddProposal {
357 pub fn key_package(&self) -> &KeyPackage {
359 &self.key_package
360 }
361}
362
/// An `Update` proposal carrying the sender's new [`LeafNode`].
///
/// Only TLS serialization is derived here; decoding is presumably implemented
/// elsewhere (not visible in this file).
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct UpdateProposal {
    pub(crate) leaf_node: LeafNode,
}
379
380impl UpdateProposal {
381 pub fn leaf_node(&self) -> &LeafNode {
383 &self.leaf_node
384 }
385}
386
/// A `Remove` proposal naming the leaf to be removed from the tree.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct RemoveProposal {
    // Index of the leaf whose member is removed.
    pub(crate) removed: LeafNodeIndex,
}
413
414impl RemoveProposal {
415 pub fn removed(&self) -> LeafNodeIndex {
417 self.removed
418 }
419}
420
/// A `PreSharedKey` proposal referencing a PSK by its [`PreSharedKeyId`].
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct PreSharedKeyProposal {
    psk: PreSharedKeyId,
}
447
448impl PreSharedKeyProposal {
449 pub(crate) fn into_psk_id(self) -> PreSharedKeyId {
451 self.psk
452 }
453}
454
455impl PreSharedKeyProposal {
456 pub fn new(psk: PreSharedKeyId) -> Self {
458 Self { psk }
459 }
460}
461
/// A `ReInit` proposal, carrying the parameters of the group to reinitialize into.
///
/// NOTE: field order is the TLS wire order; do not reorder.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct ReInitProposal {
    pub(crate) group_id: GroupId,
    pub(crate) version: ProtocolVersion,
    pub(crate) ciphersuite: Ciphersuite,
    pub(crate) extensions: Extensions,
}
496
/// An `ExternalInit` proposal carrying an opaque KEM output.
///
/// The KEM output is stored as variable-length bytes ([`VLBytes`]).
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct ExternalInitProposal {
    kem_output: VLBytes,
}
523
524impl ExternalInitProposal {
525 pub(crate) fn kem_output(&self) -> &[u8] {
527 self.kem_output.as_slice()
528 }
529}
530
531impl From<Vec<u8>> for ExternalInitProposal {
532 fn from(kem_output: Vec<u8>) -> Self {
533 ExternalInitProposal {
534 kem_output: kem_output.into(),
535 }
536 }
537}
538
/// A list of [`MessageRange`]s reported as received (only with the
/// `extensions-draft-08` feature).
#[cfg(feature = "extensions-draft-08")]
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct AppAck {
    received_ranges: Vec<MessageRange>,
}
557
/// An `AppEphemeral` proposal (only with the `extensions-draft-08` feature).
///
/// Pairs a [`ComponentId`] with opaque application data.
///
/// NOTE: field order is the TLS wire order; do not reorder.
#[cfg(feature = "extensions-draft-08")]
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct AppEphemeralProposal {
    // Identifies the component the data belongs to.
    component_id: ComponentId,
    // Opaque payload, encoded as variable-length bytes.
    data: VLBytes,
}
#[cfg(feature = "extensions-draft-08")]
impl AppEphemeralProposal {
    /// Creates a proposal carrying `data` for the component `component_id`.
    pub fn new(component_id: ComponentId, data: Vec<u8>) -> Self {
        let data: VLBytes = data.into();
        Self { component_id, data }
    }

    /// Returns the component identifier.
    pub fn component_id(&self) -> ComponentId {
        let Self { component_id, .. } = self;
        *component_id
    }

    /// Borrows the opaque application data.
    pub fn data(&self) -> &[u8] {
        let Self { data, .. } = self;
        data.as_slice()
    }
}
596
/// A `GroupContextExtensions` proposal carrying a replacement set of
/// group-context [`Extensions`].
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct GroupContextExtensionProposal {
    extensions: Extensions,
}
623
624impl GroupContextExtensionProposal {
625 pub(crate) fn new(extensions: Extensions) -> Self {
627 Self { extensions }
628 }
629
630 pub fn extensions(&self) -> &Extensions {
632 &self.extensions
633 }
634}
635
/// Tag saying whether a commit entry carries a proposal inline or by reference.
///
/// Encoded as a single byte (`#[repr(u8)]`): `1` for an inline proposal, `2` for a
/// reference.
#[derive(
    PartialEq,
    Clone,
    Copy,
    Debug,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSize,
    Serialize,
    Deserialize,
)]
#[repr(u8)]
pub enum ProposalOrRefType {
    /// The proposal is included inline.
    Proposal = 1,
    /// The proposal is referenced by its [`ProposalRef`].
    Reference = 2,
}
677
/// Either an inline [`Proposal`] or a [`ProposalRef`] to one.
///
/// The TLS discriminant of `Proposal` is `1`. `Reference` has no explicit
/// discriminant and is presumably encoded as `2`, matching `ProposalOrRefType`
/// (confirm against tls_codec's implicit numbering). Only TLS serialization is
/// derived here; decoding is presumably implemented elsewhere.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
#[repr(u8)]
#[allow(missing_docs)]
pub(crate) enum ProposalOrRef {
    #[tls_codec(discriminant = 1)]
    Proposal(Box<Proposal>),
    Reference(Box<ProposalRef>),
}
687
688impl ProposalOrRef {
689 pub(crate) fn proposal(p: Proposal) -> Self {
691 Self::Proposal(Box::new(p))
692 }
693
694 pub(crate) fn reference(p: ProposalRef) -> Self {
696 Self::Reference(Box::new(p))
697 }
698
699 pub(crate) fn as_proposal(&self) -> Option<&Proposal> {
700 if let Self::Proposal(v) = self {
701 Some(v)
702 } else {
703 None
704 }
705 }
706
707 pub(crate) fn as_reference(&self) -> Option<&ProposalRef> {
708 if let Self::Reference(v) = self {
709 Some(v)
710 } else {
711 None
712 }
713 }
714}
715
716impl From<Proposal> for ProposalOrRef {
717 fn from(value: Proposal) -> Self {
718 Self::proposal(value)
719 }
720}
721
722impl From<ProposalRef> for ProposalOrRef {
723 fn from(value: ProposalRef) -> Self {
724 Self::reference(value)
725 }
726}
727
/// Errors raised while computing a [`ProposalRef`] from an [`AuthenticatedContent`].
#[derive(Error, Debug)]
pub(crate) enum ProposalRefError {
    /// The content body was not a proposal; `wrong` is the actual content type.
    #[error("Expected `Proposal`, got `{wrong:?}`.")]
    AuthenticatedContentHasWrongType { wrong: ContentType },
    /// Internal failure (serialization or crypto), wrapped as a [`LibraryError`].
    #[error(transparent)]
    Other(#[from] LibraryError),
}
735
736impl ProposalRef {
737 pub(crate) fn from_authenticated_content_by_ref(
738 crypto: &impl OpenMlsCrypto,
739 ciphersuite: Ciphersuite,
740 authenticated_content: &AuthenticatedContent,
741 ) -> Result<Self, ProposalRefError> {
742 if !matches!(
743 authenticated_content.content(),
744 FramedContentBody::Proposal(_)
745 ) {
746 return Err(ProposalRefError::AuthenticatedContentHasWrongType {
747 wrong: authenticated_content.content().content_type(),
748 });
749 };
750
751 let encoded = authenticated_content
752 .tls_serialize_detached()
753 .map_err(|error| ProposalRefError::Other(LibraryError::missing_bound_check(error)))?;
754
755 make_proposal_ref(&encoded, ciphersuite, crypto)
756 .map_err(|error| ProposalRefError::Other(LibraryError::unexpected_crypto_error(error)))
757 }
758
759 pub(crate) fn from_raw_proposal(
765 ciphersuite: Ciphersuite,
766 crypto: &impl OpenMlsCrypto,
767 proposal: &Proposal,
768 ) -> Result<Self, LibraryError> {
769 let mut data = b"Internal OpenMLS ProposalRef Label".to_vec();
771
772 let mut encoded = proposal
773 .tls_serialize_detached()
774 .map_err(LibraryError::missing_bound_check)?;
775
776 data.append(&mut encoded);
777
778 make_proposal_ref(&data, ciphersuite, crypto).map_err(LibraryError::unexpected_crypto_error)
779 }
780}
781
/// A range of message generations from one sender, identified by a
/// [`KeyPackageRef`].
///
/// Whether `last_generation` is inclusive is not established in this file.
/// NOTE: field order is the TLS wire order; do not reorder.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub(crate) struct MessageRange {
    sender: KeyPackageRef,
    first_generation: u32,
    last_generation: u32,
}
805
/// A proposal of an application-defined type: a raw `u16` type code plus an opaque
/// payload.
///
/// NOTE: field order is the TLS wire order; do not reorder.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsSize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
)]
pub struct CustomProposal {
    proposal_type: u16,
    payload: Vec<u8>,
}
822
823impl CustomProposal {
824 pub fn new(proposal_type: u16, payload: Vec<u8>) -> Self {
826 Self {
827 proposal_type,
828 payload,
829 }
830 }
831
832 pub fn proposal_type(&self) -> u16 {
834 self.proposal_type
835 }
836
837 pub fn payload(&self) -> &[u8] {
839 &self.payload
840 }
841}
842
#[cfg(test)]
mod tests {
    use tls_codec::{Deserialize, Serialize};

    use super::ProposalType;

    /// Type codes without a dedicated variant must decode to `ProposalType::Custom`
    /// and re-encode to the exact same two bytes.
    #[test]
    fn that_unknown_proposal_types_are_de_serialized_correctly() {
        let proposal_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF000, 0xFFFF];

        for proposal_type in proposal_types {
            let test = proposal_type.to_be_bytes().to_vec();

            let got = ProposalType::tls_deserialize_exact(&test).unwrap();

            match got {
                ProposalType::Custom(got_proposal_type) => {
                    assert_eq!(proposal_type, got_proposal_type);
                }
                other => panic!("Expected `ProposalType::Custom`, got `{other:?}`."),
            }

            let got_serialized = got.tls_serialize_detached().unwrap();
            assert_eq!(test, got_serialized);
        }
    }

    /// Every `u16` code must survive `u16 -> ProposalType -> u16` unchanged, for known
    /// and custom codes alike.
    #[test]
    fn that_proposal_type_u16_conversion_round_trips() {
        for code in 0..=u16::MAX {
            assert_eq!(code, u16::from(ProposalType::from(code)));
        }
    }
}