1use std::io::{Read, Write};
6
7use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tls_codec::{
11 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
12 Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
13};
14
15use crate::{
16 binary_tree::array_representation::LeafNodeIndex,
17 ciphersuite::hash_ref::{make_proposal_ref, KeyPackageRef, ProposalRef},
18 error::LibraryError,
19 extensions::Extensions,
20 framing::{
21 mls_auth_content::AuthenticatedContent, mls_content::FramedContentBody, ContentType,
22 },
23 group::GroupId,
24 key_packages::*,
25 schedule::psk::*,
26 treesync::LeafNode,
27 versions::ProtocolVersion,
28};
29
30#[cfg(feature = "extensions-draft-08")]
31use crate::component::ComponentId;
32
33#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug, Serialize, Deserialize, Hash)]
76#[allow(missing_docs)]
77pub enum ProposalType {
78 Add,
79 Update,
80 Remove,
81 PreSharedKey,
82 Reinit,
83 ExternalInit,
84 GroupContextExtensions,
85 SelfRemove,
86 #[cfg(feature = "extensions-draft-08")]
87 AppEphemeral,
88 #[cfg(feature = "extensions-draft-08")]
89 AppDataUpdate,
90 Grease(u16),
91 Custom(u16),
92}
93
94impl ProposalType {
95 pub(crate) fn is_default(self) -> bool {
98 match self {
99 ProposalType::Add
100 | ProposalType::Update
101 | ProposalType::Remove
102 | ProposalType::PreSharedKey
103 | ProposalType::Reinit
104 | ProposalType::ExternalInit
105 | ProposalType::GroupContextExtensions => true,
106 ProposalType::SelfRemove | ProposalType::Grease(_) | ProposalType::Custom(_) => false,
107 #[cfg(feature = "extensions-draft-08")]
108 ProposalType::AppEphemeral | ProposalType::AppDataUpdate => false,
109 }
110 }
111
112 pub fn is_grease(&self) -> bool {
117 matches!(self, ProposalType::Grease(_))
118 }
119}
120
121impl Size for ProposalType {
122 fn tls_serialized_len(&self) -> usize {
123 2
124 }
125}
126
127impl TlsDeserializeTrait for ProposalType {
128 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
129 where
130 Self: Sized,
131 {
132 let mut proposal_type = [0u8; 2];
133 bytes.read_exact(&mut proposal_type)?;
134
135 Ok(ProposalType::from(u16::from_be_bytes(proposal_type)))
136 }
137}
138
139impl TlsSerializeTrait for ProposalType {
140 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
141 writer.write_all(&u16::from(*self).to_be_bytes())?;
142
143 Ok(2)
144 }
145}
146
147impl DeserializeBytes for ProposalType {
148 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
149 where
150 Self: Sized,
151 {
152 let mut bytes_ref = bytes;
153 let proposal_type = ProposalType::tls_deserialize(&mut bytes_ref)?;
154 let remainder = &bytes[proposal_type.tls_serialized_len()..];
155 Ok((proposal_type, remainder))
156 }
157}
158
159impl ProposalType {
160 pub fn is_path_required(&self) -> bool {
162 matches!(
163 self,
164 Self::Update
165 | Self::Remove
166 | Self::ExternalInit
167 | Self::GroupContextExtensions
168 | Self::SelfRemove
169 )
170 }
171}
172
173impl From<u16> for ProposalType {
174 fn from(value: u16) -> Self {
175 match value {
176 1 => ProposalType::Add,
177 2 => ProposalType::Update,
178 3 => ProposalType::Remove,
179 4 => ProposalType::PreSharedKey,
180 5 => ProposalType::Reinit,
181 6 => ProposalType::ExternalInit,
182 7 => ProposalType::GroupContextExtensions,
183 #[cfg(feature = "extensions-draft-08")]
184 8 => ProposalType::AppDataUpdate,
185 #[cfg(feature = "extensions-draft-08")]
186 0x0009 => ProposalType::AppEphemeral,
187 0x000a => ProposalType::SelfRemove,
188 other if crate::grease::is_grease_value(other) => ProposalType::Grease(other),
189 other => ProposalType::Custom(other),
190 }
191 }
192}
193
194impl From<ProposalType> for u16 {
195 fn from(value: ProposalType) -> Self {
196 match value {
197 ProposalType::Add => 1,
198 ProposalType::Update => 2,
199 ProposalType::Remove => 3,
200 ProposalType::PreSharedKey => 4,
201 ProposalType::Reinit => 5,
202 ProposalType::ExternalInit => 6,
203 ProposalType::GroupContextExtensions => 7,
204 #[cfg(feature = "extensions-draft-08")]
205 ProposalType::AppDataUpdate => 8,
206 #[cfg(feature = "extensions-draft-08")]
207 ProposalType::AppEphemeral => 0x0009,
208 ProposalType::SelfRemove => 0x000a,
209 ProposalType::Grease(id) => id,
210 ProposalType::Custom(id) => id,
211 }
212 }
213}
214
215#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
235#[allow(missing_docs)]
236#[repr(u16)]
237pub enum Proposal {
238 Add(Box<AddProposal>),
239 Update(Box<UpdateProposal>),
240 Remove(Box<RemoveProposal>),
241 PreSharedKey(Box<PreSharedKeyProposal>),
242 ReInit(Box<ReInitProposal>),
243 ExternalInit(Box<ExternalInitProposal>),
244 GroupContextExtensions(Box<GroupContextExtensionProposal>),
245 #[cfg(feature = "extensions-draft-08")]
247 AppDataUpdate(Box<AppDataUpdateProposal>),
248 SelfRemove,
250 #[cfg(feature = "extensions-draft-08")]
251 AppEphemeral(Box<AppEphemeralProposal>),
252 Custom(Box<CustomProposal>),
253}
254
255impl Proposal {
256 pub(crate) fn remove(r: RemoveProposal) -> Self {
258 Self::Remove(Box::new(r))
259 }
260
261 pub(crate) fn add(a: AddProposal) -> Self {
263 Self::Add(Box::new(a))
264 }
265
266 pub(crate) fn custom(c: CustomProposal) -> Self {
268 Self::Custom(Box::new(c))
269 }
270
271 pub(crate) fn psk(p: PreSharedKeyProposal) -> Self {
273 Self::PreSharedKey(Box::new(p))
274 }
275
276 pub(crate) fn update(p: UpdateProposal) -> Self {
278 Self::Update(Box::new(p))
279 }
280
281 pub(crate) fn group_context_extensions(p: GroupContextExtensionProposal) -> Self {
283 Self::GroupContextExtensions(Box::new(p))
284 }
285
286 pub(crate) fn external_init(p: ExternalInitProposal) -> Self {
288 Self::ExternalInit(Box::new(p))
289 }
290
291 #[cfg(test)]
292 pub(crate) fn re_init(p: ReInitProposal) -> Self {
294 Self::ReInit(Box::new(p))
295 }
296
297 pub fn proposal_type(&self) -> ProposalType {
299 match self {
300 Proposal::Add(_) => ProposalType::Add,
301 Proposal::Update(_) => ProposalType::Update,
302 Proposal::Remove(_) => ProposalType::Remove,
303 Proposal::PreSharedKey(_) => ProposalType::PreSharedKey,
304 Proposal::ReInit(_) => ProposalType::Reinit,
305 Proposal::ExternalInit(_) => ProposalType::ExternalInit,
306 Proposal::GroupContextExtensions(_) => ProposalType::GroupContextExtensions,
307 #[cfg(feature = "extensions-draft-08")]
308 Proposal::AppDataUpdate(_) => ProposalType::AppDataUpdate,
309 Proposal::SelfRemove => ProposalType::SelfRemove,
310 #[cfg(feature = "extensions-draft-08")]
311 Proposal::AppEphemeral(_) => ProposalType::AppEphemeral,
312 Proposal::Custom(custom) => ProposalType::Custom(custom.proposal_type.to_owned()),
313 }
314 }
315
316 pub(crate) fn is_type(&self, proposal_type: ProposalType) -> bool {
317 self.proposal_type() == proposal_type
318 }
319
320 pub fn is_path_required(&self) -> bool {
322 self.proposal_type().is_path_required()
323 }
324
325 pub(crate) fn has_lower_priority_than(&self, new_proposal: &Proposal) -> bool {
326 match (self, new_proposal) {
327 (Proposal::Update(_), _) => true,
329 (Proposal::Remove(_), Proposal::Update(_)) => false,
331 (Proposal::Remove(_), Proposal::Remove(_)) => true,
333 (_, Proposal::SelfRemove) => true,
335 _ => {
336 debug_assert!(false);
337 false
338 }
339 }
340 }
341
342 pub(crate) fn as_remove(&self) -> Option<&RemoveProposal> {
344 if let Self::Remove(v) = self {
345 Some(v)
346 } else {
347 None
348 }
349 }
350
351 #[must_use]
355 pub fn is_remove(&self) -> bool {
356 matches!(self, Self::Remove(..))
357 }
358}
359
360#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
372pub struct AddProposal {
373 pub(crate) key_package: KeyPackage,
374}
375
376impl AddProposal {
377 pub fn key_package(&self) -> &KeyPackage {
379 &self.key_package
380 }
381}
382
383#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
396pub struct UpdateProposal {
397 pub(crate) leaf_node: LeafNode,
398}
399
400impl UpdateProposal {
401 pub fn leaf_node(&self) -> &LeafNode {
403 &self.leaf_node
404 }
405}
406
407#[derive(
419 Debug,
420 PartialEq,
421 Eq,
422 Clone,
423 Serialize,
424 Deserialize,
425 TlsDeserialize,
426 TlsDeserializeBytes,
427 TlsSerialize,
428 TlsSize,
429)]
430pub struct RemoveProposal {
431 pub(crate) removed: LeafNodeIndex,
432}
433
434impl RemoveProposal {
435 pub fn removed(&self) -> LeafNodeIndex {
437 self.removed
438 }
439}
440
441#[derive(
453 Debug,
454 PartialEq,
455 Eq,
456 Clone,
457 Serialize,
458 Deserialize,
459 TlsDeserialize,
460 TlsDeserializeBytes,
461 TlsSerialize,
462 TlsSize,
463)]
464pub struct PreSharedKeyProposal {
465 psk: PreSharedKeyId,
466}
467
468impl PreSharedKeyProposal {
469 pub(crate) fn into_psk_id(self) -> PreSharedKeyId {
471 self.psk
472 }
473}
474
475impl PreSharedKeyProposal {
476 pub fn new(psk: PreSharedKeyId) -> Self {
478 Self { psk }
479 }
480}
481
482#[derive(
499 Debug,
500 PartialEq,
501 Eq,
502 Clone,
503 Serialize,
504 Deserialize,
505 TlsDeserialize,
506 TlsDeserializeBytes,
507 TlsSerialize,
508 TlsSize,
509)]
510pub struct ReInitProposal {
511 pub(crate) group_id: GroupId,
512 pub(crate) version: ProtocolVersion,
513 pub(crate) ciphersuite: Ciphersuite,
514 pub(crate) extensions: Extensions,
515}
516
517#[derive(
529 Debug,
530 PartialEq,
531 Eq,
532 Clone,
533 Serialize,
534 Deserialize,
535 TlsDeserialize,
536 TlsDeserializeBytes,
537 TlsSerialize,
538 TlsSize,
539)]
540pub struct ExternalInitProposal {
541 kem_output: VLBytes,
542}
543
544impl ExternalInitProposal {
545 pub(crate) fn kem_output(&self) -> &[u8] {
547 self.kem_output.as_slice()
548 }
549}
550
551impl From<Vec<u8>> for ExternalInitProposal {
552 fn from(kem_output: Vec<u8>) -> Self {
553 ExternalInitProposal {
554 kem_output: kem_output.into(),
555 }
556 }
557}
558
/// A list of [`MessageRange`] entries.
///
/// Only compiled with the `extensions-draft-08` feature.
#[cfg(feature = "extensions-draft-08")]
#[derive(
    Clone,
    Debug,
    PartialEq,
    Serialize,
    Deserialize,
    TlsSize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
)]
pub struct AppAck {
    /// The message ranges held by this value.
    received_ranges: Vec<MessageRange>,
}
577
/// Payload of [`Proposal::AppEphemeral`]: a component id plus opaque data.
///
/// Only compiled with the `extensions-draft-08` feature.
#[cfg(feature = "extensions-draft-08")]
#[derive(
    Clone,
    Debug,
    PartialEq,
    Serialize,
    Deserialize,
    TlsSize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
)]
pub struct AppEphemeralProposal {
    /// Identifier of the component the data belongs to.
    component_id: ComponentId,
    /// Opaque payload, stored as variable-length bytes.
    data: VLBytes,
}
597
#[cfg(feature = "extensions-draft-08")]
impl AppEphemeralProposal {
    /// Creates a proposal from a component id and its raw payload bytes.
    pub fn new(component_id: ComponentId, data: Vec<u8>) -> Self {
        let data: VLBytes = data.into();
        Self { component_id, data }
    }

    /// Returns the component id stored in this proposal.
    pub fn component_id(&self) -> ComponentId {
        let Self { component_id, .. } = self;
        *component_id
    }

    /// Borrows the payload as a byte slice.
    pub fn data(&self) -> &[u8] {
        let Self { data, .. } = self;
        data.as_slice()
    }
}
617
618#[derive(
630 Debug,
631 PartialEq,
632 Eq,
633 Clone,
634 Serialize,
635 Deserialize,
636 TlsDeserialize,
637 TlsDeserializeBytes,
638 TlsSerialize,
639 TlsSize,
640)]
641pub struct GroupContextExtensionProposal {
642 extensions: Extensions,
643}
644
645impl GroupContextExtensionProposal {
646 pub(crate) fn new(extensions: Extensions) -> Self {
648 Self { extensions }
649 }
650
651 pub fn extensions(&self) -> &Extensions {
653 &self.extensions
654 }
655}
656
657#[derive(
680 PartialEq,
681 Clone,
682 Copy,
683 Debug,
684 TlsSerialize,
685 TlsDeserialize,
686 TlsDeserializeBytes,
687 TlsSize,
688 Serialize,
689 Deserialize,
690)]
691#[repr(u8)]
692pub enum ProposalOrRefType {
693 Proposal = 1,
695 Reference = 2,
697}
698
699#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
701#[repr(u8)]
702#[allow(missing_docs)]
703pub enum ProposalOrRef {
704 #[tls_codec(discriminant = 1)]
705 Proposal(Box<Proposal>),
706 Reference(Box<ProposalRef>),
707}
708
709impl ProposalOrRef {
710 pub(crate) fn proposal(p: Proposal) -> Self {
712 Self::Proposal(Box::new(p))
713 }
714
715 pub(crate) fn reference(p: ProposalRef) -> Self {
717 Self::Reference(Box::new(p))
718 }
719
720 pub(crate) fn as_proposal(&self) -> Option<&Proposal> {
721 if let Self::Proposal(v) = self {
722 Some(v)
723 } else {
724 None
725 }
726 }
727
728 pub(crate) fn as_reference(&self) -> Option<&ProposalRef> {
729 if let Self::Reference(v) = self {
730 Some(v)
731 } else {
732 None
733 }
734 }
735}
736
737impl From<Proposal> for ProposalOrRef {
738 fn from(value: Proposal) -> Self {
739 Self::proposal(value)
740 }
741}
742
743impl From<ProposalRef> for ProposalOrRef {
744 fn from(value: ProposalRef) -> Self {
745 Self::reference(value)
746 }
747}
748
749#[derive(Error, Debug)]
750pub(crate) enum ProposalRefError {
751 #[error("Expected `Proposal`, got `{wrong:?}`.")]
752 AuthenticatedContentHasWrongType { wrong: ContentType },
753 #[error(transparent)]
754 Other(#[from] LibraryError),
755}
756
757impl ProposalRef {
758 pub(crate) fn from_authenticated_content_by_ref(
759 crypto: &impl OpenMlsCrypto,
760 ciphersuite: Ciphersuite,
761 authenticated_content: &AuthenticatedContent,
762 ) -> Result<Self, ProposalRefError> {
763 if !matches!(
764 authenticated_content.content(),
765 FramedContentBody::Proposal(_)
766 ) {
767 return Err(ProposalRefError::AuthenticatedContentHasWrongType {
768 wrong: authenticated_content.content().content_type(),
769 });
770 };
771
772 let encoded = authenticated_content
773 .tls_serialize_detached()
774 .map_err(|error| ProposalRefError::Other(LibraryError::missing_bound_check(error)))?;
775
776 make_proposal_ref(&encoded, ciphersuite, crypto)
777 .map_err(|error| ProposalRefError::Other(LibraryError::unexpected_crypto_error(error)))
778 }
779
780 pub(crate) fn from_raw_proposal(
786 ciphersuite: Ciphersuite,
787 crypto: &impl OpenMlsCrypto,
788 proposal: &Proposal,
789 ) -> Result<Self, LibraryError> {
790 let mut data = b"Internal OpenMLS ProposalRef Label".to_vec();
792
793 let mut encoded = proposal
794 .tls_serialize_detached()
795 .map_err(LibraryError::missing_bound_check)?;
796
797 data.append(&mut encoded);
798
799 make_proposal_ref(&data, ciphersuite, crypto).map_err(LibraryError::unexpected_crypto_error)
800 }
801}
802
803#[derive(
811 Debug,
812 PartialEq,
813 Clone,
814 Serialize,
815 Deserialize,
816 TlsDeserialize,
817 TlsDeserializeBytes,
818 TlsSerialize,
819 TlsSize,
820)]
821pub(crate) struct MessageRange {
822 sender: KeyPackageRef,
823 first_generation: u32,
824 last_generation: u32,
825}
826
827#[cfg(feature = "extensions-draft-08")]
828mod app_data_update;
829#[cfg(feature = "extensions-draft-08")]
830pub use app_data_update::*;
831
832#[derive(
834 Debug,
835 PartialEq,
836 Clone,
837 Serialize,
838 Deserialize,
839 TlsSize,
840 TlsSerialize,
841 TlsDeserialize,
842 TlsDeserializeBytes,
843)]
844pub struct CustomProposal {
845 proposal_type: u16,
846 payload: Vec<u8>,
847}
848
849impl CustomProposal {
850 pub fn new(proposal_type: u16, payload: Vec<u8>) -> Self {
852 Self {
853 proposal_type,
854 payload,
855 }
856 }
857
858 pub fn proposal_type(&self) -> u16 {
860 self.proposal_type
861 }
862
863 pub fn payload(&self) -> &[u8] {
865 &self.payload
866 }
867}
868
#[cfg(test)]
mod tests {
    use tls_codec::{Deserialize, Serialize};

    use super::ProposalType;

    /// Codes without a dedicated variant must decode to
    /// `ProposalType::Custom` carrying the raw value, and must serialize back
    /// to the exact bytes they were read from.
    #[test]
    fn that_unknown_proposal_types_are_de_serialized_correctly() {
        // None of these is a code with a dedicated variant, and the `Custom`
        // assertion below also requires that none is classified as GREASE.
        let proposal_types = [0x0000u16, 0x0B0B, 0x7C7C, 0xF000, 0xFFFF];

        for proposal_type in proposal_types {
            // Wire form: the code as two big-endian bytes.
            let test = proposal_type.to_be_bytes().to_vec();

            // Decoding must consume the input exactly and succeed.
            let got = ProposalType::tls_deserialize_exact(&test).unwrap();

            match got {
                ProposalType::Custom(got_proposal_type) => {
                    assert_eq!(proposal_type, got_proposal_type);
                }
                // The message names the variant that is actually expected.
                other => panic!("Expected `ProposalType::Custom`, got `{other:?}`."),
            }

            // Round trip: re-encoding yields the original bytes.
            let got_serialized = got.tls_serialize_detached().unwrap();
            assert_eq!(test, got_serialized);
        }
    }
}