1use std::io::{Read, Write};
6
7use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tls_codec::{
11 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
12 Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
13};
14
15use crate::{
16 binary_tree::array_representation::LeafNodeIndex,
17 ciphersuite::hash_ref::{make_proposal_ref, KeyPackageRef, ProposalRef},
18 error::LibraryError,
19 extensions::Extensions,
20 framing::{
21 mls_auth_content::AuthenticatedContent, mls_content::FramedContentBody, ContentType,
22 },
23 group::{GroupContext, GroupId},
24 key_packages::*,
25 schedule::psk::*,
26 treesync::LeafNode,
27 versions::ProtocolVersion,
28};
29
30#[cfg(feature = "extensions-draft-08")]
31use crate::component::ComponentId;
32
33#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug, Serialize, Deserialize, Hash)]
76#[allow(missing_docs)]
77pub enum ProposalType {
78 Add,
79 Update,
80 Remove,
81 PreSharedKey,
82 Reinit,
83 ExternalInit,
84 GroupContextExtensions,
85 SelfRemove,
86 #[cfg(feature = "extensions-draft-08")]
87 AppEphemeral,
88 #[cfg(feature = "extensions-draft-08")]
89 AppDataUpdate,
90 Grease(u16),
91 Custom(u16),
92}
93
94impl ProposalType {
95 pub(crate) fn is_default(self) -> bool {
98 match self {
99 ProposalType::Add
100 | ProposalType::Update
101 | ProposalType::Remove
102 | ProposalType::PreSharedKey
103 | ProposalType::Reinit
104 | ProposalType::ExternalInit
105 | ProposalType::GroupContextExtensions => true,
106 ProposalType::SelfRemove | ProposalType::Grease(_) | ProposalType::Custom(_) => false,
107 #[cfg(feature = "extensions-draft-08")]
108 ProposalType::AppEphemeral | ProposalType::AppDataUpdate => false,
109 }
110 }
111
112 pub fn is_grease(&self) -> bool {
117 matches!(self, ProposalType::Grease(_))
118 }
119}
120
121impl Size for ProposalType {
122 fn tls_serialized_len(&self) -> usize {
123 2
124 }
125}
126
127impl TlsDeserializeTrait for ProposalType {
128 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
129 where
130 Self: Sized,
131 {
132 let mut proposal_type = [0u8; 2];
133 bytes.read_exact(&mut proposal_type)?;
134
135 Ok(ProposalType::from(u16::from_be_bytes(proposal_type)))
136 }
137}
138
139impl TlsSerializeTrait for ProposalType {
140 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
141 writer.write_all(&u16::from(*self).to_be_bytes())?;
142
143 Ok(2)
144 }
145}
146
147impl DeserializeBytes for ProposalType {
148 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
149 where
150 Self: Sized,
151 {
152 let mut bytes_ref = bytes;
153 let proposal_type = ProposalType::tls_deserialize(&mut bytes_ref)?;
154 let remainder = &bytes[proposal_type.tls_serialized_len()..];
155 Ok((proposal_type, remainder))
156 }
157}
158
159impl ProposalType {
160 pub fn is_path_required(&self) -> bool {
162 matches!(
163 self,
164 Self::Update
165 | Self::Remove
166 | Self::ExternalInit
167 | Self::GroupContextExtensions
168 | Self::SelfRemove
169 )
170 }
171}
172
173impl From<u16> for ProposalType {
174 fn from(value: u16) -> Self {
175 match value {
176 1 => ProposalType::Add,
177 2 => ProposalType::Update,
178 3 => ProposalType::Remove,
179 4 => ProposalType::PreSharedKey,
180 5 => ProposalType::Reinit,
181 6 => ProposalType::ExternalInit,
182 7 => ProposalType::GroupContextExtensions,
183 #[cfg(feature = "extensions-draft-08")]
184 8 => ProposalType::AppDataUpdate,
185 #[cfg(feature = "extensions-draft-08")]
186 0x0009 => ProposalType::AppEphemeral,
187 0x000a => ProposalType::SelfRemove,
188 other if crate::grease::is_grease_value(other) => ProposalType::Grease(other),
189 other => ProposalType::Custom(other),
190 }
191 }
192}
193
194impl From<ProposalType> for u16 {
195 fn from(value: ProposalType) -> Self {
196 match value {
197 ProposalType::Add => 1,
198 ProposalType::Update => 2,
199 ProposalType::Remove => 3,
200 ProposalType::PreSharedKey => 4,
201 ProposalType::Reinit => 5,
202 ProposalType::ExternalInit => 6,
203 ProposalType::GroupContextExtensions => 7,
204 #[cfg(feature = "extensions-draft-08")]
205 ProposalType::AppDataUpdate => 8,
206 #[cfg(feature = "extensions-draft-08")]
207 ProposalType::AppEphemeral => 0x0009,
208 ProposalType::SelfRemove => 0x000a,
209 ProposalType::Grease(id) => id,
210 ProposalType::Custom(id) => id,
211 }
212 }
213}
214
215#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
235#[allow(missing_docs)]
236#[repr(u16)]
237pub enum Proposal {
238 Add(Box<AddProposal>),
239 Update(Box<UpdateProposal>),
240 Remove(Box<RemoveProposal>),
241 PreSharedKey(Box<PreSharedKeyProposal>),
242 ReInit(Box<ReInitProposal>),
243 ExternalInit(Box<ExternalInitProposal>),
244 GroupContextExtensions(Box<GroupContextExtensionProposal>),
245 #[cfg(feature = "extensions-draft-08")]
247 AppDataUpdate(Box<AppDataUpdateProposal>),
248 SelfRemove,
250 #[cfg(feature = "extensions-draft-08")]
251 AppEphemeral(Box<AppEphemeralProposal>),
252 Custom(Box<CustomProposal>),
253}
254
255impl Proposal {
256 pub(crate) fn remove(r: RemoveProposal) -> Self {
258 Self::Remove(Box::new(r))
259 }
260
261 pub(crate) fn add(a: AddProposal) -> Self {
263 Self::Add(Box::new(a))
264 }
265
266 pub(crate) fn custom(c: CustomProposal) -> Self {
268 Self::Custom(Box::new(c))
269 }
270
271 pub(crate) fn psk(p: PreSharedKeyProposal) -> Self {
273 Self::PreSharedKey(Box::new(p))
274 }
275
276 pub(crate) fn update(p: UpdateProposal) -> Self {
278 Self::Update(Box::new(p))
279 }
280
281 pub(crate) fn group_context_extensions(p: GroupContextExtensionProposal) -> Self {
283 Self::GroupContextExtensions(Box::new(p))
284 }
285
286 pub(crate) fn external_init(p: ExternalInitProposal) -> Self {
288 Self::ExternalInit(Box::new(p))
289 }
290
291 #[cfg(test)]
292 pub(crate) fn re_init(p: ReInitProposal) -> Self {
294 Self::ReInit(Box::new(p))
295 }
296
297 pub fn proposal_type(&self) -> ProposalType {
299 match self {
300 Proposal::Add(_) => ProposalType::Add,
301 Proposal::Update(_) => ProposalType::Update,
302 Proposal::Remove(_) => ProposalType::Remove,
303 Proposal::PreSharedKey(_) => ProposalType::PreSharedKey,
304 Proposal::ReInit(_) => ProposalType::Reinit,
305 Proposal::ExternalInit(_) => ProposalType::ExternalInit,
306 Proposal::GroupContextExtensions(_) => ProposalType::GroupContextExtensions,
307 #[cfg(feature = "extensions-draft-08")]
308 Proposal::AppDataUpdate(_) => ProposalType::AppDataUpdate,
309 Proposal::SelfRemove => ProposalType::SelfRemove,
310 #[cfg(feature = "extensions-draft-08")]
311 Proposal::AppEphemeral(_) => ProposalType::AppEphemeral,
312 Proposal::Custom(custom) => ProposalType::Custom(custom.proposal_type.to_owned()),
313 }
314 }
315
316 pub(crate) fn is_type(&self, proposal_type: ProposalType) -> bool {
317 self.proposal_type() == proposal_type
318 }
319
320 pub fn is_path_required(&self) -> bool {
322 self.proposal_type().is_path_required()
323 }
324
325 pub(crate) fn has_lower_priority_than(&self, new_proposal: &Proposal) -> bool {
326 match (self, new_proposal) {
327 (Proposal::Update(_), _) => true,
329 (Proposal::Remove(_), Proposal::Update(_)) => false,
331 (Proposal::Remove(_), Proposal::Remove(_)) => true,
333 (_, Proposal::SelfRemove) => true,
335 _ => {
336 debug_assert!(false);
337 false
338 }
339 }
340 }
341
342 pub(crate) fn as_remove(&self) -> Option<&RemoveProposal> {
344 if let Self::Remove(v) = self {
345 Some(v)
346 } else {
347 None
348 }
349 }
350
351 #[must_use]
355 pub fn is_remove(&self) -> bool {
356 matches!(self, Self::Remove(..))
357 }
358}
359
360#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
372pub struct AddProposal {
373 pub(crate) key_package: KeyPackage,
374}
375
376impl AddProposal {
377 pub fn key_package(&self) -> &KeyPackage {
379 &self.key_package
380 }
381}
382
383impl From<KeyPackage> for AddProposal {
384 fn from(key_package: KeyPackage) -> AddProposal {
385 AddProposal { key_package }
386 }
387}
388
389#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
402pub struct UpdateProposal {
403 pub(crate) leaf_node: LeafNode,
404}
405
406impl UpdateProposal {
407 pub fn leaf_node(&self) -> &LeafNode {
409 &self.leaf_node
410 }
411}
412
413#[derive(
425 Debug,
426 PartialEq,
427 Eq,
428 Clone,
429 Serialize,
430 Deserialize,
431 TlsDeserialize,
432 TlsDeserializeBytes,
433 TlsSerialize,
434 TlsSize,
435)]
436pub struct RemoveProposal {
437 pub(crate) removed: LeafNodeIndex,
438}
439
440impl RemoveProposal {
441 pub fn removed(&self) -> LeafNodeIndex {
443 self.removed
444 }
445}
446
447#[derive(
459 Debug,
460 PartialEq,
461 Eq,
462 Clone,
463 Serialize,
464 Deserialize,
465 TlsDeserialize,
466 TlsDeserializeBytes,
467 TlsSerialize,
468 TlsSize,
469)]
470pub struct PreSharedKeyProposal {
471 psk: PreSharedKeyId,
472}
473
474impl PreSharedKeyProposal {
475 pub(crate) fn into_psk_id(self) -> PreSharedKeyId {
477 self.psk
478 }
479}
480
481impl PreSharedKeyProposal {
482 pub fn new(psk: PreSharedKeyId) -> Self {
484 Self { psk }
485 }
486}
487
488#[derive(
505 Debug,
506 PartialEq,
507 Eq,
508 Clone,
509 Serialize,
510 Deserialize,
511 TlsDeserialize,
512 TlsDeserializeBytes,
513 TlsSerialize,
514 TlsSize,
515)]
516pub struct ReInitProposal {
517 pub(crate) group_id: GroupId,
518 pub(crate) version: ProtocolVersion,
519 pub(crate) ciphersuite: Ciphersuite,
520 pub(crate) extensions: Extensions<GroupContext>,
521}
522
523#[derive(
535 Debug,
536 PartialEq,
537 Eq,
538 Clone,
539 Serialize,
540 Deserialize,
541 TlsDeserialize,
542 TlsDeserializeBytes,
543 TlsSerialize,
544 TlsSize,
545)]
546pub struct ExternalInitProposal {
547 kem_output: VLBytes,
548}
549
550impl ExternalInitProposal {
551 pub(crate) fn kem_output(&self) -> &[u8] {
553 self.kem_output.as_slice()
554 }
555}
556
557impl From<Vec<u8>> for ExternalInitProposal {
558 fn from(kem_output: Vec<u8>) -> Self {
559 ExternalInitProposal {
560 kem_output: kem_output.into(),
561 }
562 }
563}
564
/// Acknowledgement payload consisting of a list of message ranges.
///
/// Only compiled with the `extensions-draft-08` feature.
#[cfg(feature = "extensions-draft-08")]
#[derive(
    Debug, PartialEq, Clone,
    Serialize, Deserialize,
    TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
)]
pub struct AppAck {
    /// The acknowledged ranges, in encoding order.
    received_ranges: Vec<MessageRange>,
}
583
/// Proposal payload for `Proposal::AppEphemeral`: a component ID plus opaque data.
///
/// Only compiled with the `extensions-draft-08` feature.
#[cfg(feature = "extensions-draft-08")]
#[derive(
    Debug, PartialEq, Clone,
    Serialize, Deserialize,
    TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
)]
pub struct AppEphemeralProposal {
    /// Component this payload belongs to.
    component_id: ComponentId,
    /// Opaque application data, encoded as variable-length bytes.
    data: VLBytes,
}
603
#[cfg(feature = "extensions-draft-08")]
impl AppEphemeralProposal {
    /// Builds a proposal for `component_id` carrying `data`.
    pub fn new(component_id: ComponentId, data: Vec<u8>) -> Self {
        let data = VLBytes::from(data);
        AppEphemeralProposal { component_id, data }
    }

    /// Returns the component ID stored in this proposal.
    pub fn component_id(&self) -> ComponentId {
        let Self { component_id, .. } = self;
        *component_id
    }

    /// Borrows the opaque data stored in this proposal.
    pub fn data(&self) -> &[u8] {
        let Self { data, .. } = self;
        data.as_slice()
    }
}
623
624#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
636pub struct GroupContextExtensionProposal {
637 extensions: Extensions<GroupContext>,
638}
639
640impl Size for GroupContextExtensionProposal {
641 fn tls_serialized_len(&self) -> usize {
642 self.extensions.tls_serialized_len()
643 }
644}
645
646impl TlsSerializeTrait for GroupContextExtensionProposal {
647 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
648 self.extensions.tls_serialize(writer)
649 }
650}
651
652impl GroupContextExtensionProposal {
653 pub(crate) fn new(extensions: Extensions<GroupContext>) -> Self {
655 Self { extensions }
656 }
657
658 pub fn extensions(&self) -> &Extensions<GroupContext> {
660 &self.extensions
661 }
662
663 pub fn into_extensions(self) -> Extensions<GroupContext> {
665 self.extensions
666 }
667}
668
669#[derive(
692 PartialEq,
693 Clone,
694 Copy,
695 Debug,
696 TlsSerialize,
697 TlsDeserialize,
698 TlsDeserializeBytes,
699 TlsSize,
700 Serialize,
701 Deserialize,
702)]
703#[repr(u8)]
704pub enum ProposalOrRefType {
705 Proposal = 1,
707 Reference = 2,
709}
710
711#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
713#[repr(u8)]
714#[allow(missing_docs)]
715pub enum ProposalOrRef {
716 #[tls_codec(discriminant = 1)]
717 Proposal(Box<Proposal>),
718 Reference(Box<ProposalRef>),
719}
720
721impl ProposalOrRef {
722 pub(crate) fn proposal(p: Proposal) -> Self {
724 Self::Proposal(Box::new(p))
725 }
726
727 pub(crate) fn reference(p: ProposalRef) -> Self {
729 Self::Reference(Box::new(p))
730 }
731
732 pub(crate) fn as_proposal(&self) -> Option<&Proposal> {
733 if let Self::Proposal(v) = self {
734 Some(v)
735 } else {
736 None
737 }
738 }
739
740 pub(crate) fn as_reference(&self) -> Option<&ProposalRef> {
741 if let Self::Reference(v) = self {
742 Some(v)
743 } else {
744 None
745 }
746 }
747}
748
749impl From<Proposal> for ProposalOrRef {
750 fn from(value: Proposal) -> Self {
751 Self::proposal(value)
752 }
753}
754
755impl From<ProposalRef> for ProposalOrRef {
756 fn from(value: ProposalRef) -> Self {
757 Self::reference(value)
758 }
759}
760
761#[derive(Error, Debug)]
762pub(crate) enum ProposalRefError {
763 #[error("Expected `Proposal`, got `{wrong:?}`.")]
764 AuthenticatedContentHasWrongType { wrong: ContentType },
765 #[error(transparent)]
766 Other(#[from] LibraryError),
767}
768
769impl ProposalRef {
770 pub(crate) fn from_authenticated_content_by_ref(
771 crypto: &impl OpenMlsCrypto,
772 ciphersuite: Ciphersuite,
773 authenticated_content: &AuthenticatedContent,
774 ) -> Result<Self, ProposalRefError> {
775 if !matches!(
776 authenticated_content.content(),
777 FramedContentBody::Proposal(_)
778 ) {
779 return Err(ProposalRefError::AuthenticatedContentHasWrongType {
780 wrong: authenticated_content.content().content_type(),
781 });
782 };
783
784 let encoded = authenticated_content
785 .tls_serialize_detached()
786 .map_err(|error| ProposalRefError::Other(LibraryError::missing_bound_check(error)))?;
787
788 make_proposal_ref(&encoded, ciphersuite, crypto)
789 .map_err(|error| ProposalRefError::Other(LibraryError::unexpected_crypto_error(error)))
790 }
791
792 pub(crate) fn from_raw_proposal(
798 ciphersuite: Ciphersuite,
799 crypto: &impl OpenMlsCrypto,
800 proposal: &Proposal,
801 ) -> Result<Self, LibraryError> {
802 let mut data = b"Internal OpenMLS ProposalRef Label".to_vec();
804
805 let mut encoded = proposal
806 .tls_serialize_detached()
807 .map_err(LibraryError::missing_bound_check)?;
808
809 data.append(&mut encoded);
810
811 make_proposal_ref(&data, ciphersuite, crypto).map_err(LibraryError::unexpected_crypto_error)
812 }
813}
814
815#[derive(
823 Debug,
824 PartialEq,
825 Clone,
826 Serialize,
827 Deserialize,
828 TlsDeserialize,
829 TlsDeserializeBytes,
830 TlsSerialize,
831 TlsSize,
832)]
833pub(crate) struct MessageRange {
834 sender: KeyPackageRef,
835 first_generation: u32,
836 last_generation: u32,
837}
838
// Feature-gated submodule. Its items are re-exported from here; presumably it
// defines `AppDataUpdateProposal`, which `Proposal::AppDataUpdate` uses
// (TODO confirm).
#[cfg(feature = "extensions-draft-08")]
mod app_data_update;
#[cfg(feature = "extensions-draft-08")]
pub use self::app_data_update::*;
843
844#[derive(
846 Debug,
847 PartialEq,
848 Clone,
849 Serialize,
850 Deserialize,
851 TlsSize,
852 TlsSerialize,
853 TlsDeserialize,
854 TlsDeserializeBytes,
855)]
856pub struct CustomProposal {
857 proposal_type: u16,
858 payload: Vec<u8>,
859}
860
861impl CustomProposal {
862 pub fn new(proposal_type: u16, payload: Vec<u8>) -> Self {
864 Self {
865 proposal_type,
866 payload,
867 }
868 }
869
870 pub fn proposal_type(&self) -> u16 {
872 self.proposal_type
873 }
874
875 pub fn payload(&self) -> &[u8] {
877 &self.payload
878 }
879}
880
#[cfg(test)]
mod tests {
    use tls_codec::{Deserialize, DeserializeBytes, Serialize};

    use super::ProposalType;

    /// Codepoints that are neither named proposal types nor GREASE values must
    /// decode to `ProposalType::Custom` and re-encode to the same bytes.
    #[test]
    fn that_unknown_proposal_types_are_de_serialized_correctly() {
        let proposal_types = [0x0000u16, 0x0B0B, 0x7C7C, 0xF000, 0xFFFF];

        for proposal_type in proposal_types.into_iter() {
            let test = proposal_type.to_be_bytes().to_vec();

            let got = ProposalType::tls_deserialize_exact(&test).unwrap();

            match got {
                ProposalType::Custom(got_proposal_type) => {
                    assert_eq!(proposal_type, got_proposal_type);
                }
                // The variant is called `Custom`; the old message named a
                // non-existent `Unknown` variant.
                other => panic!("Expected `ProposalType::Custom`, got `{other:?}`."),
            }

            let got_serialized = got.tls_serialize_detached().unwrap();
            assert_eq!(test, got_serialized);
        }
    }

    /// Feature-independent named codepoints decode to their variants and
    /// convert back to the same `u16`.
    #[test]
    fn that_known_proposal_types_round_trip() {
        let known = [
            (0x0001u16, ProposalType::Add),
            (0x0002, ProposalType::Update),
            (0x0003, ProposalType::Remove),
            (0x0004, ProposalType::PreSharedKey),
            (0x0005, ProposalType::Reinit),
            (0x0006, ProposalType::ExternalInit),
            (0x0007, ProposalType::GroupContextExtensions),
            (0x000a, ProposalType::SelfRemove),
        ];

        for (codepoint, expected) in known {
            assert_eq!(ProposalType::from(codepoint), expected);
            assert_eq!(u16::from(expected), codepoint);
        }
    }

    /// `tls_deserialize_bytes` consumes exactly two bytes and returns the rest;
    /// short input is an error rather than a panic.
    #[test]
    fn that_deserialize_bytes_returns_the_unconsumed_remainder() {
        let input = [0x00u8, 0x03, 0xAA, 0xBB];
        let (proposal_type, remainder) = ProposalType::tls_deserialize_bytes(&input).unwrap();
        assert_eq!(proposal_type, ProposalType::Remove);
        assert_eq!(remainder, &[0xAAu8, 0xBB]);

        assert!(ProposalType::tls_deserialize_bytes(&[0x00u8]).is_err());
    }
}