1use std::io::{Read, Write};
6
7use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tls_codec::{
11 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
12 Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
13};
14
15use crate::{
16 binary_tree::array_representation::LeafNodeIndex,
17 ciphersuite::hash_ref::{make_proposal_ref, KeyPackageRef, ProposalRef},
18 error::LibraryError,
19 extensions::Extensions,
20 framing::{
21 mls_auth_content::AuthenticatedContent, mls_content::FramedContentBody, ContentType,
22 },
23 group::{GroupContext, GroupId},
24 key_packages::*,
25 schedule::psk::*,
26 treesync::LeafNode,
27 versions::ProtocolVersion,
28};
29
30#[cfg(feature = "extensions-draft-08")]
31use crate::component::ComponentId;
32
/// The type of an MLS proposal, encoded on the wire as a big-endian `u16`.
///
/// Wire values, as assigned by the `From<u16>` / `From<ProposalType>`
/// conversions in this module:
///
/// * `0x0001` `Add`, `0x0002` `Update`, `0x0003` `Remove`,
///   `0x0004` `PreSharedKey`, `0x0005` `Reinit`, `0x0006` `ExternalInit`,
///   `0x0007` `GroupContextExtensions`
/// * `0x0008` `AppDataUpdate`, `0x0009` `AppEphemeral` (only with the
///   `extensions-draft-08` feature)
/// * `0x000a` `SelfRemove`
///
/// Any other value decodes to `Grease` if it is a GREASE value and to
/// `Custom` otherwise, so decoding never rejects an unknown type.
///
/// NOTE(review): the derived `Ord` and serde's variant indices follow the
/// declaration order below. `AppEphemeral` and `AppDataUpdate` are cfg-gated
/// and declared before `Grease`/`Custom`, so the indices of those last two
/// variants differ between builds with and without `extensions-draft-08`.
/// Index-based serde formats are therefore not compatible across feature
/// sets; confirm before relying on persisted encodings.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub enum ProposalType {
    Add,
    Update,
    Remove,
    PreSharedKey,
    Reinit,
    ExternalInit,
    GroupContextExtensions,
    SelfRemove,
    #[cfg(feature = "extensions-draft-08")]
    AppEphemeral,
    #[cfg(feature = "extensions-draft-08")]
    AppDataUpdate,
    /// A GREASE value; carries the raw wire value.
    Grease(u16),
    /// Any other, non-GREASE value without a dedicated variant; carries the
    /// raw wire value.
    Custom(u16),
}
93
94impl ProposalType {
95 pub(crate) fn is_default(self) -> bool {
98 match self {
99 ProposalType::Add
100 | ProposalType::Update
101 | ProposalType::Remove
102 | ProposalType::PreSharedKey
103 | ProposalType::Reinit
104 | ProposalType::ExternalInit
105 | ProposalType::GroupContextExtensions => true,
106 ProposalType::SelfRemove | ProposalType::Grease(_) | ProposalType::Custom(_) => false,
107 #[cfg(feature = "extensions-draft-08")]
108 ProposalType::AppEphemeral | ProposalType::AppDataUpdate => false,
109 }
110 }
111
112 pub fn is_grease(&self) -> bool {
117 matches!(self, ProposalType::Grease(_))
118 }
119}
120
121impl Size for ProposalType {
122 fn tls_serialized_len(&self) -> usize {
123 2
124 }
125}
126
127impl TlsDeserializeTrait for ProposalType {
128 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
129 where
130 Self: Sized,
131 {
132 let mut proposal_type = [0u8; 2];
133 bytes.read_exact(&mut proposal_type)?;
134
135 Ok(ProposalType::from(u16::from_be_bytes(proposal_type)))
136 }
137}
138
139impl TlsSerializeTrait for ProposalType {
140 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
141 writer.write_all(&u16::from(*self).to_be_bytes())?;
142
143 Ok(2)
144 }
145}
146
147impl DeserializeBytes for ProposalType {
148 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
149 where
150 Self: Sized,
151 {
152 let mut bytes_ref = bytes;
153 let proposal_type = ProposalType::tls_deserialize(&mut bytes_ref)?;
154 let remainder = &bytes[proposal_type.tls_serialized_len()..];
155 Ok((proposal_type, remainder))
156 }
157}
158
159impl ProposalType {
160 pub fn is_path_required(&self) -> bool {
162 matches!(
163 self,
164 Self::Update
165 | Self::Remove
166 | Self::ExternalInit
167 | Self::GroupContextExtensions
168 | Self::SelfRemove
169 )
170 }
171}
172
173impl From<u16> for ProposalType {
174 fn from(value: u16) -> Self {
175 match value {
176 1 => ProposalType::Add,
177 2 => ProposalType::Update,
178 3 => ProposalType::Remove,
179 4 => ProposalType::PreSharedKey,
180 5 => ProposalType::Reinit,
181 6 => ProposalType::ExternalInit,
182 7 => ProposalType::GroupContextExtensions,
183 #[cfg(feature = "extensions-draft-08")]
184 8 => ProposalType::AppDataUpdate,
185 #[cfg(feature = "extensions-draft-08")]
186 0x0009 => ProposalType::AppEphemeral,
187 0x000a => ProposalType::SelfRemove,
188 other if crate::grease::is_grease_value(other) => ProposalType::Grease(other),
189 other => ProposalType::Custom(other),
190 }
191 }
192}
193
194impl From<ProposalType> for u16 {
195 fn from(value: ProposalType) -> Self {
196 match value {
197 ProposalType::Add => 1,
198 ProposalType::Update => 2,
199 ProposalType::Remove => 3,
200 ProposalType::PreSharedKey => 4,
201 ProposalType::Reinit => 5,
202 ProposalType::ExternalInit => 6,
203 ProposalType::GroupContextExtensions => 7,
204 #[cfg(feature = "extensions-draft-08")]
205 ProposalType::AppDataUpdate => 8,
206 #[cfg(feature = "extensions-draft-08")]
207 ProposalType::AppEphemeral => 0x0009,
208 ProposalType::SelfRemove => 0x000a,
209 ProposalType::Grease(id) => id,
210 ProposalType::Custom(id) => id,
211 }
212 }
213}
214
/// A single MLS proposal together with its type-specific payload.
///
/// Each variant corresponds to one [`ProposalType`], as returned by
/// `Proposal::proposal_type`. Payloads are boxed, presumably to keep the enum
/// itself small (TODO confirm). `SelfRemove` carries no payload.
///
/// Only serde (de)serialization is derived here. The TLS wire encoding is not
/// derived on this type and is presumably implemented elsewhere in the crate.
///
/// NOTE(review): `SelfRemove` is declared after the cfg-gated
/// `AppDataUpdate`, so the serde variant index of `SelfRemove` and `Custom`
/// changes with the `extensions-draft-08` feature. Index-based serde formats
/// are therefore not compatible across feature sets; confirm before relying
/// on persisted encodings.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
#[repr(u16)]
pub enum Proposal {
    Add(Box<AddProposal>),
    Update(Box<UpdateProposal>),
    Remove(Box<RemoveProposal>),
    PreSharedKey(Box<PreSharedKeyProposal>),
    ReInit(Box<ReInitProposal>),
    ExternalInit(Box<ExternalInitProposal>),
    GroupContextExtensions(Box<GroupContextExtensionProposal>),
    #[cfg(feature = "extensions-draft-08")]
    AppDataUpdate(Box<AppDataUpdateProposal>),
    /// Payload-less proposal; maps to `ProposalType::SelfRemove`.
    SelfRemove,
    #[cfg(feature = "extensions-draft-08")]
    AppEphemeral(Box<AppEphemeralProposal>),
    /// Application-defined proposal; maps to `ProposalType::Custom` with the
    /// custom proposal's own type value.
    Custom(Box<CustomProposal>),
}
254
255impl Proposal {
256 pub(crate) fn remove(r: RemoveProposal) -> Self {
258 Self::Remove(Box::new(r))
259 }
260
261 pub(crate) fn add(a: AddProposal) -> Self {
263 Self::Add(Box::new(a))
264 }
265
266 pub(crate) fn custom(c: CustomProposal) -> Self {
268 Self::Custom(Box::new(c))
269 }
270
271 pub(crate) fn psk(p: PreSharedKeyProposal) -> Self {
273 Self::PreSharedKey(Box::new(p))
274 }
275
276 pub(crate) fn update(p: UpdateProposal) -> Self {
278 Self::Update(Box::new(p))
279 }
280
281 pub(crate) fn group_context_extensions(p: GroupContextExtensionProposal) -> Self {
283 Self::GroupContextExtensions(Box::new(p))
284 }
285
286 pub(crate) fn external_init(p: ExternalInitProposal) -> Self {
288 Self::ExternalInit(Box::new(p))
289 }
290
291 #[cfg(test)]
292 pub(crate) fn re_init(p: ReInitProposal) -> Self {
294 Self::ReInit(Box::new(p))
295 }
296
297 pub fn proposal_type(&self) -> ProposalType {
299 match self {
300 Proposal::Add(_) => ProposalType::Add,
301 Proposal::Update(_) => ProposalType::Update,
302 Proposal::Remove(_) => ProposalType::Remove,
303 Proposal::PreSharedKey(_) => ProposalType::PreSharedKey,
304 Proposal::ReInit(_) => ProposalType::Reinit,
305 Proposal::ExternalInit(_) => ProposalType::ExternalInit,
306 Proposal::GroupContextExtensions(_) => ProposalType::GroupContextExtensions,
307 #[cfg(feature = "extensions-draft-08")]
308 Proposal::AppDataUpdate(_) => ProposalType::AppDataUpdate,
309 Proposal::SelfRemove => ProposalType::SelfRemove,
310 #[cfg(feature = "extensions-draft-08")]
311 Proposal::AppEphemeral(_) => ProposalType::AppEphemeral,
312 Proposal::Custom(custom) => ProposalType::Custom(custom.proposal_type.to_owned()),
313 }
314 }
315
316 pub(crate) fn is_type(&self, proposal_type: ProposalType) -> bool {
317 self.proposal_type() == proposal_type
318 }
319
320 pub fn is_path_required(&self) -> bool {
322 self.proposal_type().is_path_required()
323 }
324
325 pub(crate) fn has_lower_priority_than(&self, new_proposal: &Proposal) -> bool {
326 match (self, new_proposal) {
327 (Proposal::Update(_), _) => true,
329 (Proposal::Remove(_), Proposal::Update(_)) => false,
331 (Proposal::Remove(_), Proposal::Remove(_)) => true,
333 (_, Proposal::SelfRemove) => true,
335 _ => {
336 debug_assert!(false);
337 false
338 }
339 }
340 }
341
342 pub(crate) fn as_remove(&self) -> Option<&RemoveProposal> {
344 if let Self::Remove(v) = self {
345 Some(v)
346 } else {
347 None
348 }
349 }
350
351 #[must_use]
355 pub fn is_remove(&self) -> bool {
356 matches!(self, Self::Remove(..))
357 }
358}
359
/// Proposal to add a new member, described by its [`KeyPackage`].
///
/// Only TLS serialization (`TlsSerialize`/`TlsSize`) is derived here. TLS
/// deserialization is not derived and is presumably implemented elsewhere
/// (TODO confirm).
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct AddProposal {
    /// Key package of the member to be added.
    pub(crate) key_package: KeyPackage,
}
375
376impl AddProposal {
377 pub fn key_package(&self) -> &KeyPackage {
379 &self.key_package
380 }
381}
382
/// Proposal by a member to replace its own leaf node with `leaf_node`.
///
/// Only TLS serialization (`TlsSerialize`/`TlsSize`) is derived here. TLS
/// deserialization is not derived and is presumably implemented elsewhere
/// (TODO confirm).
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct UpdateProposal {
    /// The proposed replacement leaf node.
    pub(crate) leaf_node: LeafNode,
}
399
400impl UpdateProposal {
401 pub fn leaf_node(&self) -> &LeafNode {
403 &self.leaf_node
404 }
405}
406
/// Proposal to remove the member occupying leaf `removed`.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct RemoveProposal {
    /// Leaf index of the member to be removed.
    pub(crate) removed: LeafNodeIndex,
}
433
434impl RemoveProposal {
435 pub fn removed(&self) -> LeafNodeIndex {
437 self.removed
438 }
439}
440
/// Proposal to include the pre-shared key identified by `psk`.
///
/// The field is private; use `PreSharedKeyProposal::new` to build one.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct PreSharedKeyProposal {
    /// Identifier of the pre-shared key.
    psk: PreSharedKeyId,
}
467
468impl PreSharedKeyProposal {
469 pub(crate) fn into_psk_id(self) -> PreSharedKeyId {
471 self.psk
472 }
473}
474
475impl PreSharedKeyProposal {
476 pub fn new(psk: PreSharedKeyId) -> Self {
478 Self { psk }
479 }
480}
481
/// Proposal to re-initialize the group with the parameters given here
/// (the MLS `ReInit` proposal).
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct ReInitProposal {
    /// Group ID for the re-initialized group.
    pub(crate) group_id: GroupId,
    /// Protocol version for the re-initialized group.
    pub(crate) version: ProtocolVersion,
    /// Ciphersuite for the re-initialized group.
    pub(crate) ciphersuite: Ciphersuite,
    /// Group context extensions for the re-initialized group.
    pub(crate) extensions: Extensions<GroupContext>,
}
516
/// The MLS `ExternalInit` proposal, carrying the KEM output an external
/// joiner supplies.
///
/// Build one from raw bytes via `From<Vec<u8>>`.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct ExternalInitProposal {
    /// Raw KEM output, stored as a variable-length byte vector.
    kem_output: VLBytes,
}
543
544impl ExternalInitProposal {
545 pub(crate) fn kem_output(&self) -> &[u8] {
547 self.kem_output.as_slice()
548 }
549}
550
551impl From<Vec<u8>> for ExternalInitProposal {
552 fn from(kem_output: Vec<u8>) -> Self {
553 ExternalInitProposal {
554 kem_output: kem_output.into(),
555 }
556 }
557}
558
/// Application acknowledgement payload (only with the `extensions-draft-08`
/// feature): a list of `MessageRange`s, presumably describing which
/// application messages were received from which senders (TODO confirm
/// against the extensions draft).
///
/// NOTE(review): the field is private and no constructor or accessor is
/// visible in this section; check whether the type is used at all.
#[cfg(feature = "extensions-draft-08")]
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct AppAck {
    /// Per-sender generation ranges being acknowledged.
    received_ranges: Vec<MessageRange>,
}
577
/// Ephemeral application proposal (only with the `extensions-draft-08`
/// feature): opaque `data` addressed to the application component
/// `component_id`.
#[cfg(feature = "extensions-draft-08")]
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct AppEphemeralProposal {
    /// Component this proposal is addressed to.
    component_id: ComponentId,
    /// Opaque application data, stored as a variable-length byte vector.
    data: VLBytes,
}
597
#[cfg(feature = "extensions-draft-08")]
impl AppEphemeralProposal {
    /// Builds an ephemeral application proposal that carries `data` for the
    /// component `component_id`.
    pub fn new(component_id: ComponentId, data: Vec<u8>) -> Self {
        let data = data.into();
        Self { component_id, data }
    }

    /// Returns the component this proposal is addressed to.
    pub fn component_id(&self) -> ComponentId {
        let Self { component_id, .. } = self;
        *component_id
    }

    /// Borrows the opaque application data.
    pub fn data(&self) -> &[u8] {
        let Self { data, .. } = self;
        data.as_slice()
    }
}
617
/// Proposal to set the group context extensions to `extensions`.
///
/// The TLS encoding is just the extension list itself: `Size` and
/// `TlsSerialize` are hand-written in this module and delegate to the field.
/// TLS deserialization is not visible here and is presumably implemented
/// elsewhere (TODO confirm).
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct GroupContextExtensionProposal {
    /// Proposed group context extensions.
    extensions: Extensions<GroupContext>,
}
633
634impl Size for GroupContextExtensionProposal {
635 fn tls_serialized_len(&self) -> usize {
636 self.extensions.tls_serialized_len()
637 }
638}
639
640impl TlsSerializeTrait for GroupContextExtensionProposal {
641 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
642 self.extensions.tls_serialize(writer)
643 }
644}
645
646impl GroupContextExtensionProposal {
647 pub(crate) fn new(extensions: Extensions<GroupContext>) -> Self {
649 Self { extensions }
650 }
651
652 pub fn extensions(&self) -> &Extensions<GroupContext> {
654 &self.extensions
655 }
656
657 pub fn into_extensions(self) -> Extensions<GroupContext> {
659 self.extensions
660 }
661}
662
/// Discriminant saying whether an entry holds a full proposal or a
/// reference to one.
///
/// With `repr(u8)` and the derived TLS codec it is encoded as one byte:
/// `1` for `Proposal` and `2` for `Reference`.
#[derive(
    PartialEq,
    Clone,
    Copy,
    Debug,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSize,
    Serialize,
    Deserialize,
)]
#[repr(u8)]
pub enum ProposalOrRefType {
    /// The entry carries the proposal inline.
    Proposal = 1,
    /// The entry carries a reference to a proposal.
    Reference = 2,
}
704
/// Either a full [`Proposal`] or a [`ProposalRef`] pointing to one.
///
/// The derived TLS encoding prefixes the payload with a one-byte
/// discriminant. `Proposal` is explicitly `1`; `Reference` presumably takes
/// the next value, `2`, matching `ProposalOrRefType` (TODO confirm).
/// Only serialization is derived; deserialization is presumably implemented
/// elsewhere.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
#[repr(u8)]
#[allow(missing_docs)]
pub enum ProposalOrRef {
    #[tls_codec(discriminant = 1)]
    Proposal(Box<Proposal>),
    Reference(Box<ProposalRef>),
}
714
715impl ProposalOrRef {
716 pub(crate) fn proposal(p: Proposal) -> Self {
718 Self::Proposal(Box::new(p))
719 }
720
721 pub(crate) fn reference(p: ProposalRef) -> Self {
723 Self::Reference(Box::new(p))
724 }
725
726 pub(crate) fn as_proposal(&self) -> Option<&Proposal> {
727 if let Self::Proposal(v) = self {
728 Some(v)
729 } else {
730 None
731 }
732 }
733
734 pub(crate) fn as_reference(&self) -> Option<&ProposalRef> {
735 if let Self::Reference(v) = self {
736 Some(v)
737 } else {
738 None
739 }
740 }
741}
742
743impl From<Proposal> for ProposalOrRef {
744 fn from(value: Proposal) -> Self {
745 Self::proposal(value)
746 }
747}
748
749impl From<ProposalRef> for ProposalOrRef {
750 fn from(value: ProposalRef) -> Self {
751 Self::reference(value)
752 }
753}
754
/// Errors returned when computing a [`ProposalRef`] from authenticated
/// content.
#[derive(Error, Debug)]
pub(crate) enum ProposalRefError {
    /// The authenticated content's body is not a proposal; `wrong` is the
    /// content type that was found instead.
    #[error("Expected `Proposal`, got `{wrong:?}`.")]
    AuthenticatedContentHasWrongType { wrong: ContentType },
    /// An internal library error, e.g. a serialization or crypto failure
    /// while deriving the reference.
    #[error(transparent)]
    Other(#[from] LibraryError),
}
762
763impl ProposalRef {
764 pub(crate) fn from_authenticated_content_by_ref(
765 crypto: &impl OpenMlsCrypto,
766 ciphersuite: Ciphersuite,
767 authenticated_content: &AuthenticatedContent,
768 ) -> Result<Self, ProposalRefError> {
769 if !matches!(
770 authenticated_content.content(),
771 FramedContentBody::Proposal(_)
772 ) {
773 return Err(ProposalRefError::AuthenticatedContentHasWrongType {
774 wrong: authenticated_content.content().content_type(),
775 });
776 };
777
778 let encoded = authenticated_content
779 .tls_serialize_detached()
780 .map_err(|error| ProposalRefError::Other(LibraryError::missing_bound_check(error)))?;
781
782 make_proposal_ref(&encoded, ciphersuite, crypto)
783 .map_err(|error| ProposalRefError::Other(LibraryError::unexpected_crypto_error(error)))
784 }
785
786 pub(crate) fn from_raw_proposal(
792 ciphersuite: Ciphersuite,
793 crypto: &impl OpenMlsCrypto,
794 proposal: &Proposal,
795 ) -> Result<Self, LibraryError> {
796 let mut data = b"Internal OpenMLS ProposalRef Label".to_vec();
798
799 let mut encoded = proposal
800 .tls_serialize_detached()
801 .map_err(LibraryError::missing_bound_check)?;
802
803 data.append(&mut encoded);
804
805 make_proposal_ref(&data, ciphersuite, crypto).map_err(LibraryError::unexpected_crypto_error)
806 }
807}
808
/// A range of message generations from one sender, identified by its
/// [`KeyPackageRef`]. It is used as the element type of `AppAck`.
///
/// NOTE(review): whether `last_generation` is inclusive is not visible here;
/// confirm against the extensions draft. Within this section the only user
/// is the feature-gated `AppAck`, while this type is compiled
/// unconditionally. Check for dead-code warnings without
/// `extensions-draft-08`.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub(crate) struct MessageRange {
    /// Sender whose messages this range covers.
    sender: KeyPackageRef,
    /// First generation in the range.
    first_generation: u32,
    /// Last generation in the range.
    last_generation: u32,
}
832
833#[cfg(feature = "extensions-draft-08")]
834mod app_data_update;
835#[cfg(feature = "extensions-draft-08")]
836pub use app_data_update::*;
837
/// A proposal of an application-defined type, carried as an opaque payload.
///
/// `Proposal::proposal_type` reports such a proposal as
/// `ProposalType::Custom(proposal_type)`.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsSize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
)]
pub struct CustomProposal {
    /// Raw `u16` proposal type value.
    proposal_type: u16,
    /// Undecoded proposal body.
    payload: Vec<u8>,
}
854
855impl CustomProposal {
856 pub fn new(proposal_type: u16, payload: Vec<u8>) -> Self {
858 Self {
859 proposal_type,
860 payload,
861 }
862 }
863
864 pub fn proposal_type(&self) -> u16 {
866 self.proposal_type
867 }
868
869 pub fn payload(&self) -> &[u8] {
871 &self.payload
872 }
873}
874
#[cfg(test)]
mod tests {
    use tls_codec::{Deserialize, Serialize};

    use super::ProposalType;

    /// Decodes `value` as a TLS-encoded proposal type and re-encodes it. It
    /// asserts that the round trip reproduces the original bytes and returns
    /// the decoded type.
    fn round_trip(value: u16) -> ProposalType {
        let encoded = value.to_be_bytes().to_vec();
        let decoded = ProposalType::tls_deserialize_exact(&encoded).unwrap();
        let re_encoded = decoded.tls_serialize_detached().unwrap();
        assert_eq!(encoded, re_encoded, "round trip changed {value:#06x}");
        decoded
    }

    #[test]
    fn that_unknown_proposal_types_are_de_serialized_correctly() {
        // None of these is a known proposal type or a GREASE value.
        let proposal_types = [0x0000u16, 0x0B0B, 0x7C7C, 0xF000, 0xFFFF];

        for proposal_type in proposal_types {
            match round_trip(proposal_type) {
                ProposalType::Custom(got_proposal_type) => {
                    assert_eq!(proposal_type, got_proposal_type);
                }
                other => panic!("Expected `ProposalType::Custom`, got `{other:?}`."),
            }
        }
    }

    #[test]
    fn that_grease_proposal_types_are_de_serialized_correctly() {
        // GREASE values from RFC 9420 (0x0A0A, 0x1A1A, ..., 0xEAEA).
        for proposal_type in [0x0A0Au16, 0x1A1A, 0xEAEA] {
            match round_trip(proposal_type) {
                ProposalType::Grease(got_proposal_type) => {
                    assert_eq!(proposal_type, got_proposal_type);
                }
                other => panic!("Expected `ProposalType::Grease`, got `{other:?}`."),
            }
        }
    }

    #[test]
    fn that_known_proposal_types_are_de_serialized_correctly() {
        // Values 8 and 9 are feature-dependent and covered by the exhaustive
        // round-trip test below.
        let known = [
            (0x0001u16, ProposalType::Add),
            (0x0002, ProposalType::Update),
            (0x0003, ProposalType::Remove),
            (0x0004, ProposalType::PreSharedKey),
            (0x0005, ProposalType::Reinit),
            (0x0006, ProposalType::ExternalInit),
            (0x0007, ProposalType::GroupContextExtensions),
            (0x000a, ProposalType::SelfRemove),
        ];

        for (value, expected) in known {
            assert_eq!(round_trip(value), expected);
        }
    }

    #[test]
    fn that_every_proposal_type_round_trips() {
        for value in 0..=u16::MAX {
            round_trip(value);
        }
    }
}