1use std::io::{Read, Write};
6
7use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tls_codec::{
11 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
12 Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
13};
14
15use crate::{
16 binary_tree::array_representation::LeafNodeIndex,
17 ciphersuite::hash_ref::{make_proposal_ref, KeyPackageRef, ProposalRef},
18 error::LibraryError,
19 extensions::Extensions,
20 framing::{
21 mls_auth_content::AuthenticatedContent, mls_content::FramedContentBody, ContentType,
22 },
23 group::GroupId,
24 key_packages::*,
25 prelude::LeafNode,
26 schedule::psk::*,
27 versions::ProtocolVersion,
28};
29
/// The type of an MLS proposal.
///
/// On the wire a proposal type is a big-endian `u16`. The mapping between
/// variants and wire values is given by the `From<u16> for ProposalType` and
/// `From<ProposalType> for u16` conversions: `Add` = `0x0001` through
/// `GroupContextExtensions` = `0x0007`, `SelfRemove` = `0x000a`,
/// `AppAck` = `0x000b`. Every other value decodes to [`ProposalType::Custom`].
///
/// NOTE(review): the derived `PartialOrd`/`Ord` compare by variant declaration
/// order, not by wire value, and the derived serde representation uses the
/// variant names. Reordering or renaming variants changes both.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub enum ProposalType {
    /// `add` (wire value `0x0001`).
    Add,
    /// `update` (wire value `0x0002`).
    Update,
    /// `remove` (wire value `0x0003`).
    Remove,
    /// `psk` (wire value `0x0004`).
    PreSharedKey,
    /// `reinit` (wire value `0x0005`).
    Reinit,
    /// `external_init` (wire value `0x0006`).
    ExternalInit,
    /// `group_context_extensions` (wire value `0x0007`).
    GroupContextExtensions,
    /// `app_ack` (wire value `0x000b`).
    AppAck,
    /// `self_remove` (wire value `0x000a`).
    SelfRemove,
    /// Any wire value without a dedicated variant, kept verbatim.
    Custom(u16),
}
86
87impl ProposalType {
88 pub(crate) fn is_default(self) -> bool {
91 match self {
92 ProposalType::Add
93 | ProposalType::Update
94 | ProposalType::Remove
95 | ProposalType::PreSharedKey
96 | ProposalType::Reinit
97 | ProposalType::ExternalInit
98 | ProposalType::GroupContextExtensions => true,
99 ProposalType::SelfRemove | ProposalType::AppAck | ProposalType::Custom(_) => false,
100 }
101 }
102}
103
104impl Size for ProposalType {
105 fn tls_serialized_len(&self) -> usize {
106 2
107 }
108}
109
110impl TlsDeserializeTrait for ProposalType {
111 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
112 where
113 Self: Sized,
114 {
115 let mut proposal_type = [0u8; 2];
116 bytes.read_exact(&mut proposal_type)?;
117
118 Ok(ProposalType::from(u16::from_be_bytes(proposal_type)))
119 }
120}
121
122impl TlsSerializeTrait for ProposalType {
123 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
124 writer.write_all(&u16::from(*self).to_be_bytes())?;
125
126 Ok(2)
127 }
128}
129
130impl DeserializeBytes for ProposalType {
131 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
132 where
133 Self: Sized,
134 {
135 let mut bytes_ref = bytes;
136 let proposal_type = ProposalType::tls_deserialize(&mut bytes_ref)?;
137 let remainder = &bytes[proposal_type.tls_serialized_len()..];
138 Ok((proposal_type, remainder))
139 }
140}
141
142impl ProposalType {
143 pub fn is_path_required(&self) -> bool {
145 matches!(
146 self,
147 Self::Update
148 | Self::Remove
149 | Self::ExternalInit
150 | Self::GroupContextExtensions
151 | Self::SelfRemove
152 )
153 }
154}
155
156impl From<u16> for ProposalType {
157 fn from(value: u16) -> Self {
158 match value {
159 1 => ProposalType::Add,
160 2 => ProposalType::Update,
161 3 => ProposalType::Remove,
162 4 => ProposalType::PreSharedKey,
163 5 => ProposalType::Reinit,
164 6 => ProposalType::ExternalInit,
165 7 => ProposalType::GroupContextExtensions,
166 0x000a => ProposalType::SelfRemove,
167 0x000b => ProposalType::AppAck,
168 other => ProposalType::Custom(other),
169 }
170 }
171}
172
173impl From<ProposalType> for u16 {
174 fn from(value: ProposalType) -> Self {
175 match value {
176 ProposalType::Add => 1,
177 ProposalType::Update => 2,
178 ProposalType::Remove => 3,
179 ProposalType::PreSharedKey => 4,
180 ProposalType::Reinit => 5,
181 ProposalType::ExternalInit => 6,
182 ProposalType::GroupContextExtensions => 7,
183 ProposalType::SelfRemove => 0x000a,
184 ProposalType::AppAck => 0x000b,
185 ProposalType::Custom(id) => id,
186 }
187 }
188}
189
/// A single MLS proposal.
///
/// Every variant except `SelfRemove` carries its payload in a `Box`. This is
/// presumably to keep `Proposal` itself small; confirm before changing it.
///
/// NOTE(review): no TLS codec is derived here. The wire encoding is
/// presumably implemented elsewhere in the crate.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
#[repr(u16)]
pub enum Proposal {
    /// An `add` proposal.
    Add(Box<AddProposal>),
    /// An `update` proposal.
    Update(Box<UpdateProposal>),
    /// A `remove` proposal.
    Remove(Box<RemoveProposal>),
    /// A `psk` proposal.
    PreSharedKey(Box<PreSharedKeyProposal>),
    /// A `reinit` proposal.
    ReInit(Box<ReInitProposal>),
    /// An `external_init` proposal.
    ExternalInit(Box<ExternalInitProposal>),
    /// A `group_context_extensions` proposal.
    GroupContextExtensions(Box<GroupContextExtensionProposal>),
    /// An `app_ack` proposal.
    AppAck(Box<AppAckProposal>),
    /// A `self_remove` proposal; it carries no payload.
    SelfRemove,
    /// A proposal of an application-defined type, see [`CustomProposal`].
    Custom(Box<CustomProposal>),
}
228
229impl Proposal {
230 pub(crate) fn remove(r: RemoveProposal) -> Self {
232 Self::Remove(Box::new(r))
233 }
234
235 pub(crate) fn add(a: AddProposal) -> Self {
237 Self::Add(Box::new(a))
238 }
239
240 pub(crate) fn custom(c: CustomProposal) -> Self {
242 Self::Custom(Box::new(c))
243 }
244
245 pub(crate) fn psk(p: PreSharedKeyProposal) -> Self {
247 Self::PreSharedKey(Box::new(p))
248 }
249
250 pub(crate) fn update(p: UpdateProposal) -> Self {
252 Self::Update(Box::new(p))
253 }
254
255 pub(crate) fn group_context_extensions(p: GroupContextExtensionProposal) -> Self {
257 Self::GroupContextExtensions(Box::new(p))
258 }
259
260 pub(crate) fn external_init(p: ExternalInitProposal) -> Self {
262 Self::ExternalInit(Box::new(p))
263 }
264
265 #[cfg(test)]
266 pub(crate) fn re_init(p: ReInitProposal) -> Self {
268 Self::ReInit(Box::new(p))
269 }
270
271 pub fn proposal_type(&self) -> ProposalType {
273 match self {
274 Proposal::Add(_) => ProposalType::Add,
275 Proposal::Update(_) => ProposalType::Update,
276 Proposal::Remove(_) => ProposalType::Remove,
277 Proposal::PreSharedKey(_) => ProposalType::PreSharedKey,
278 Proposal::ReInit(_) => ProposalType::Reinit,
279 Proposal::ExternalInit(_) => ProposalType::ExternalInit,
280 Proposal::GroupContextExtensions(_) => ProposalType::GroupContextExtensions,
281 Proposal::AppAck(_) => ProposalType::AppAck,
282 Proposal::SelfRemove => ProposalType::SelfRemove,
283 Proposal::Custom(custom) => ProposalType::Custom(custom.proposal_type.to_owned()),
284 }
285 }
286
287 pub(crate) fn is_type(&self, proposal_type: ProposalType) -> bool {
288 self.proposal_type() == proposal_type
289 }
290
291 pub fn is_path_required(&self) -> bool {
293 self.proposal_type().is_path_required()
294 }
295
296 pub(crate) fn has_lower_priority_than(&self, new_proposal: &Proposal) -> bool {
297 match (self, new_proposal) {
298 (Proposal::Update(_), _) => true,
300 (Proposal::Remove(_), Proposal::Update(_)) => false,
302 (Proposal::Remove(_), Proposal::Remove(_)) => true,
304 (_, Proposal::SelfRemove) => true,
306 _ => {
308 debug_assert!(false);
309 false
310 }
311 }
312 }
313
314 pub(crate) fn as_remove(&self) -> Option<&RemoveProposal> {
316 if let Self::Remove(v) = self {
317 Some(v)
318 } else {
319 None
320 }
321 }
322
323 #[must_use]
327 pub fn is_remove(&self) -> bool {
328 matches!(self, Self::Remove(..))
329 }
330}
331
/// An Add proposal carrying a [`KeyPackage`].
///
/// NOTE(review): presumably the key package of the member to be added.
/// Only TLS serialization is derived here; deserialization is presumably
/// implemented elsewhere.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct AddProposal {
    /// The key package carried by this proposal.
    pub(crate) key_package: KeyPackage,
}
347
348impl AddProposal {
349 pub fn key_package(&self) -> &KeyPackage {
351 &self.key_package
352 }
353}
354
/// An Update proposal carrying a new [`LeafNode`].
///
/// Only TLS serialization is derived here. NOTE(review): deserialization is
/// presumably implemented elsewhere.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct UpdateProposal {
    /// The leaf node carried by this proposal.
    pub(crate) leaf_node: LeafNode,
}
371
372impl UpdateProposal {
373 pub fn leaf_node(&self) -> &LeafNode {
375 &self.leaf_node
376 }
377}
378
/// A Remove proposal naming the leaf index of the member to remove.
///
/// The derived TLS codec encodes fields in declaration order, and serde uses
/// the field names. Do not reorder or rename fields without checking
/// compatibility.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct RemoveProposal {
    /// Leaf index of the member to remove.
    pub(crate) removed: LeafNodeIndex,
}
405
406impl RemoveProposal {
407 pub fn removed(&self) -> LeafNodeIndex {
409 self.removed
410 }
411}
412
/// A PreSharedKey proposal that references a PSK by its [`PreSharedKeyId`].
///
/// The derived TLS codec encodes fields in declaration order.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct PreSharedKeyProposal {
    /// Identifier of the referenced pre-shared key.
    psk: PreSharedKeyId,
}
439
440impl PreSharedKeyProposal {
441 pub(crate) fn into_psk_id(self) -> PreSharedKeyId {
443 self.psk
444 }
445}
446
447impl PreSharedKeyProposal {
448 pub fn new(psk: PreSharedKeyId) -> Self {
450 Self { psk }
451 }
452}
453
/// A ReInit proposal carrying a group ID, protocol version, ciphersuite and
/// extensions.
///
/// NOTE(review): these are presumably the parameters of the group to be
/// re-initialized. Confirm this against the code that processes the proposal.
/// The derived TLS codec encodes fields in declaration order.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct ReInitProposal {
    /// Group identifier.
    pub(crate) group_id: GroupId,
    /// Protocol version.
    pub(crate) version: ProtocolVersion,
    /// Ciphersuite.
    pub(crate) ciphersuite: Ciphersuite,
    /// Extensions.
    pub(crate) extensions: Extensions,
}
488
/// An ExternalInit proposal carrying a KEM output as variable-length bytes.
///
/// NOTE(review): presumably the KEM output produced by a party joining via an
/// external commit. Confirm this against the callers of `kem_output()`.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct ExternalInitProposal {
    /// Raw KEM output bytes.
    kem_output: VLBytes,
}
515
516impl ExternalInitProposal {
517 pub(crate) fn kem_output(&self) -> &[u8] {
519 self.kem_output.as_slice()
520 }
521}
522
523impl From<Vec<u8>> for ExternalInitProposal {
524 fn from(kem_output: Vec<u8>) -> Self {
525 ExternalInitProposal {
526 kem_output: kem_output.into(),
527 }
528 }
529}
530
/// An AppAck proposal listing ranges of received messages.
///
/// Unlike most proposal payloads it does not derive `Eq`, because
/// [`MessageRange`] does not.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct AppAckProposal {
    /// The acknowledged message ranges.
    received_ranges: Vec<MessageRange>,
}
548
/// A GroupContextExtensions proposal carrying a set of [`Extensions`].
///
/// NOTE(review): presumably the proposed replacement for the group context
/// extensions. Confirm this against the code that applies the proposal.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct GroupContextExtensionProposal {
    /// The proposed extensions.
    extensions: Extensions,
}
575
576impl GroupContextExtensionProposal {
577 pub(crate) fn new(extensions: Extensions) -> Self {
579 Self { extensions }
580 }
581
582 pub fn extensions(&self) -> &Extensions {
584 &self.extensions
585 }
586}
587
/// Tag saying whether a proposal is carried inline or by reference.
///
/// The discriminants are `1` (inline) and `2` (reference). NOTE(review): these
/// presumably mirror the TLS discriminants of [`ProposalOrRef`], whose
/// `Proposal` variant is tagged `discriminant = 1`. Confirm this.
#[derive(
    PartialEq,
    Clone,
    Copy,
    Debug,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSize,
    Serialize,
    Deserialize,
)]
#[repr(u8)]
pub enum ProposalOrRefType {
    /// The proposal is carried inline.
    Proposal = 1,
    /// The proposal is carried as a [`ProposalRef`].
    Reference = 2,
}
629
/// Either an inline [`Proposal`] or a [`ProposalRef`] pointing to one.
///
/// Only TLS serialization is derived here. The inline variant is encoded with
/// a `u8` discriminant of `1`. NOTE(review): `Reference` presumably receives
/// the next value (`2`); confirm this against the `tls_codec` derive
/// semantics.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
#[repr(u8)]
#[allow(missing_docs)]
pub(crate) enum ProposalOrRef {
    /// A full proposal carried inline.
    #[tls_codec(discriminant = 1)]
    Proposal(Box<Proposal>),
    /// A reference to a proposal sent separately.
    Reference(Box<ProposalRef>),
}
639
640impl ProposalOrRef {
641 pub(crate) fn proposal(p: Proposal) -> Self {
643 Self::Proposal(Box::new(p))
644 }
645
646 pub(crate) fn reference(p: ProposalRef) -> Self {
648 Self::Reference(Box::new(p))
649 }
650
651 pub(crate) fn as_proposal(&self) -> Option<&Proposal> {
652 if let Self::Proposal(v) = self {
653 Some(v)
654 } else {
655 None
656 }
657 }
658
659 pub(crate) fn as_reference(&self) -> Option<&ProposalRef> {
660 if let Self::Reference(v) = self {
661 Some(v)
662 } else {
663 None
664 }
665 }
666}
667
668impl From<Proposal> for ProposalOrRef {
669 fn from(value: Proposal) -> Self {
670 Self::proposal(value)
671 }
672}
673
674impl From<ProposalRef> for ProposalOrRef {
675 fn from(value: ProposalRef) -> Self {
676 Self::reference(value)
677 }
678}
679
/// Errors returned by [`ProposalRef::from_authenticated_content_by_ref`].
#[derive(Error, Debug)]
pub(crate) enum ProposalRefError {
    /// The authenticated content is not a proposal. `wrong` is the content
    /// type that was found instead.
    #[error("Expected `Proposal`, got `{wrong:?}`.")]
    AuthenticatedContentHasWrongType { wrong: ContentType },
    /// An internal library error, e.g. an encoding or hashing failure.
    #[error(transparent)]
    Other(#[from] LibraryError),
}
687
688impl ProposalRef {
689 pub(crate) fn from_authenticated_content_by_ref(
690 crypto: &impl OpenMlsCrypto,
691 ciphersuite: Ciphersuite,
692 authenticated_content: &AuthenticatedContent,
693 ) -> Result<Self, ProposalRefError> {
694 if !matches!(
695 authenticated_content.content(),
696 FramedContentBody::Proposal(_)
697 ) {
698 return Err(ProposalRefError::AuthenticatedContentHasWrongType {
699 wrong: authenticated_content.content().content_type(),
700 });
701 };
702
703 let encoded = authenticated_content
704 .tls_serialize_detached()
705 .map_err(|error| ProposalRefError::Other(LibraryError::missing_bound_check(error)))?;
706
707 make_proposal_ref(&encoded, ciphersuite, crypto)
708 .map_err(|error| ProposalRefError::Other(LibraryError::unexpected_crypto_error(error)))
709 }
710
711 pub(crate) fn from_raw_proposal(
717 ciphersuite: Ciphersuite,
718 crypto: &impl OpenMlsCrypto,
719 proposal: &Proposal,
720 ) -> Result<Self, LibraryError> {
721 let mut data = b"Internal OpenMLS ProposalRef Label".to_vec();
723
724 let mut encoded = proposal
725 .tls_serialize_detached()
726 .map_err(LibraryError::missing_bound_check)?;
727
728 data.append(&mut encoded);
729
730 make_proposal_ref(&data, ciphersuite, crypto).map_err(LibraryError::unexpected_crypto_error)
731 }
732}
733
/// A range of message generations from one sender, as listed in an
/// [`AppAckProposal`].
///
/// NOTE(review): it is not visible here whether `last_generation` is
/// inclusive. Confirm this before relying on it. The derived TLS codec
/// encodes fields in declaration order.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub(crate) struct MessageRange {
    /// Key package reference of the sender.
    sender: KeyPackageRef,
    /// First generation in the range.
    first_generation: u32,
    /// Last generation in the range.
    last_generation: u32,
}
757
/// A proposal of an application-defined type.
///
/// It holds a raw 16-bit proposal type and an opaque payload. The derived TLS
/// codec encodes `proposal_type` followed by `payload`.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsSize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
)]
pub struct CustomProposal {
    /// Raw proposal type value.
    proposal_type: u16,
    /// Opaque, application-defined payload.
    payload: Vec<u8>,
}
774
775impl CustomProposal {
776 pub fn new(proposal_type: u16, payload: Vec<u8>) -> Self {
778 Self {
779 proposal_type,
780 payload,
781 }
782 }
783
784 pub fn proposal_type(&self) -> u16 {
786 self.proposal_type
787 }
788
789 pub fn payload(&self) -> &[u8] {
791 &self.payload
792 }
793}
794
#[cfg(test)]
mod tests {
    use tls_codec::{Deserialize, Serialize};

    use super::ProposalType;

    /// Values without a dedicated variant must decode to
    /// `ProposalType::Custom` and re-encode to the exact same bytes.
    #[test]
    fn that_unknown_proposal_types_are_de_serialized_correctly() {
        let proposal_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF000, 0xFFFF];

        for proposal_type in proposal_types {
            let test = proposal_type.to_be_bytes().to_vec();

            let got = ProposalType::tls_deserialize_exact(&test).unwrap();

            match got {
                ProposalType::Custom(got_proposal_type) => {
                    assert_eq!(proposal_type, got_proposal_type);
                }
                other => panic!("Expected `ProposalType::Custom`, got `{other:?}`."),
            }

            let got_serialized = got.tls_serialize_detached().unwrap();
            assert_eq!(test, got_serialized);
        }
    }

    /// Every value with a dedicated variant must decode to that variant and
    /// re-encode to the same bytes.
    #[test]
    fn that_known_proposal_types_are_de_serialized_correctly() {
        let known = [
            (0x0001u16, ProposalType::Add),
            (0x0002, ProposalType::Update),
            (0x0003, ProposalType::Remove),
            (0x0004, ProposalType::PreSharedKey),
            (0x0005, ProposalType::Reinit),
            (0x0006, ProposalType::ExternalInit),
            (0x0007, ProposalType::GroupContextExtensions),
            (0x000a, ProposalType::SelfRemove),
            (0x000b, ProposalType::AppAck),
        ];

        for (raw, expected) in known {
            let encoded = raw.to_be_bytes().to_vec();

            let got = ProposalType::tls_deserialize_exact(&encoded).unwrap();
            assert_eq!(got, expected);

            let got_serialized = got.tls_serialize_detached().unwrap();
            assert_eq!(encoded, got_serialized);
        }
    }

    /// `tls_deserialize_bytes` must hand back the bytes after the two-byte
    /// proposal type untouched, and must fail on truncated input.
    #[test]
    fn that_deserialize_bytes_returns_the_remainder() {
        let (got, rest) =
            <ProposalType as tls_codec::DeserializeBytes>::tls_deserialize_bytes(&[
                0x00, 0x01, 0xAB, 0xCD,
            ])
            .unwrap();
        assert_eq!(got, ProposalType::Add);
        assert_eq!(rest, &[0xABu8, 0xCD]);

        let (got, rest) =
            <ProposalType as tls_codec::DeserializeBytes>::tls_deserialize_bytes(&[0xF0, 0x00])
                .unwrap();
        assert_eq!(got, ProposalType::Custom(0xF000));
        assert!(rest.is_empty());

        assert!(
            <ProposalType as tls_codec::DeserializeBytes>::tls_deserialize_bytes(&[0x00]).is_err()
        );
        assert!(<ProposalType as tls_codec::DeserializeBytes>::tls_deserialize_bytes(&[]).is_err());
    }
}