1use std::io::{Read, Write};
6
7use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tls_codec::{
11 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
12 Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
13};
14
15use crate::{
16 binary_tree::array_representation::LeafNodeIndex,
17 ciphersuite::hash_ref::{make_proposal_ref, KeyPackageRef, ProposalRef},
18 error::LibraryError,
19 extensions::Extensions,
20 framing::{
21 mls_auth_content::AuthenticatedContent, mls_content::FramedContentBody, ContentType,
22 },
23 group::GroupId,
24 key_packages::*,
25 prelude::LeafNode,
26 schedule::psk::*,
27 versions::ProtocolVersion,
28};
29
30#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug, Serialize, Deserialize, Hash)]
73#[allow(missing_docs)]
74pub enum ProposalType {
75 Add,
76 Update,
77 Remove,
78 PreSharedKey,
79 Reinit,
80 ExternalInit,
81 GroupContextExtensions,
82 AppAck,
83 SelfRemove,
84 Custom(u16),
85}
86
87impl ProposalType {
88 pub(crate) fn is_default(self) -> bool {
91 match self {
92 ProposalType::Add
93 | ProposalType::Update
94 | ProposalType::Remove
95 | ProposalType::PreSharedKey
96 | ProposalType::Reinit
97 | ProposalType::ExternalInit
98 | ProposalType::GroupContextExtensions => true,
99 ProposalType::SelfRemove | ProposalType::AppAck | ProposalType::Custom(_) => false,
100 }
101 }
102}
103
104impl Size for ProposalType {
105 fn tls_serialized_len(&self) -> usize {
106 2
107 }
108}
109
110impl TlsDeserializeTrait for ProposalType {
111 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
112 where
113 Self: Sized,
114 {
115 let mut proposal_type = [0u8; 2];
116 bytes.read_exact(&mut proposal_type)?;
117
118 Ok(ProposalType::from(u16::from_be_bytes(proposal_type)))
119 }
120}
121
122impl TlsSerializeTrait for ProposalType {
123 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
124 writer.write_all(&u16::from(*self).to_be_bytes())?;
125
126 Ok(2)
127 }
128}
129
130impl DeserializeBytes for ProposalType {
131 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
132 where
133 Self: Sized,
134 {
135 let mut bytes_ref = bytes;
136 let proposal_type = ProposalType::tls_deserialize(&mut bytes_ref)?;
137 let remainder = &bytes[proposal_type.tls_serialized_len()..];
138 Ok((proposal_type, remainder))
139 }
140}
141
142impl ProposalType {
143 pub fn is_path_required(&self) -> bool {
145 matches!(
146 self,
147 Self::Update
148 | Self::Remove
149 | Self::ExternalInit
150 | Self::GroupContextExtensions
151 | Self::SelfRemove
152 )
153 }
154}
155
156impl From<u16> for ProposalType {
157 fn from(value: u16) -> Self {
158 match value {
159 1 => ProposalType::Add,
160 2 => ProposalType::Update,
161 3 => ProposalType::Remove,
162 4 => ProposalType::PreSharedKey,
163 5 => ProposalType::Reinit,
164 6 => ProposalType::ExternalInit,
165 7 => ProposalType::GroupContextExtensions,
166 0x000a => ProposalType::SelfRemove,
167 0x000b => ProposalType::AppAck,
168 other => ProposalType::Custom(other),
169 }
170 }
171}
172
173impl From<ProposalType> for u16 {
174 fn from(value: ProposalType) -> Self {
175 match value {
176 ProposalType::Add => 1,
177 ProposalType::Update => 2,
178 ProposalType::Remove => 3,
179 ProposalType::PreSharedKey => 4,
180 ProposalType::Reinit => 5,
181 ProposalType::ExternalInit => 6,
182 ProposalType::GroupContextExtensions => 7,
183 ProposalType::SelfRemove => 0x000a,
184 ProposalType::AppAck => 0x000b,
185 ProposalType::Custom(id) => id,
186 }
187 }
188}
189
190#[allow(clippy::large_enum_variant)]
210#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
211#[allow(missing_docs)]
212#[repr(u16)]
213pub enum Proposal {
214 Add(AddProposal),
215 Update(UpdateProposal),
216 Remove(RemoveProposal),
217 PreSharedKey(PreSharedKeyProposal),
218 ReInit(ReInitProposal),
219 ExternalInit(ExternalInitProposal),
220 GroupContextExtensions(GroupContextExtensionProposal),
221 AppAck(AppAckProposal),
225 SelfRemove,
227 Custom(CustomProposal),
228}
229
230impl Proposal {
231 pub fn proposal_type(&self) -> ProposalType {
233 match self {
234 Proposal::Add(_) => ProposalType::Add,
235 Proposal::Update(_) => ProposalType::Update,
236 Proposal::Remove(_) => ProposalType::Remove,
237 Proposal::PreSharedKey(_) => ProposalType::PreSharedKey,
238 Proposal::ReInit(_) => ProposalType::Reinit,
239 Proposal::ExternalInit(_) => ProposalType::ExternalInit,
240 Proposal::GroupContextExtensions(_) => ProposalType::GroupContextExtensions,
241 Proposal::AppAck(_) => ProposalType::AppAck,
242 Proposal::SelfRemove => ProposalType::SelfRemove,
243 Proposal::Custom(CustomProposal {
244 proposal_type,
245 payload: _,
246 }) => ProposalType::Custom(proposal_type.to_owned()),
247 }
248 }
249
250 pub(crate) fn is_type(&self, proposal_type: ProposalType) -> bool {
251 self.proposal_type() == proposal_type
252 }
253
254 pub fn is_path_required(&self) -> bool {
256 self.proposal_type().is_path_required()
257 }
258
259 pub(crate) fn has_lower_priority_than(&self, new_proposal: &Proposal) -> bool {
260 match (self, new_proposal) {
261 (Proposal::Update(_), _) => true,
263 (Proposal::Remove(_), Proposal::Update(_)) => false,
265 (Proposal::Remove(_), Proposal::Remove(_)) => true,
267 (_, Proposal::SelfRemove) => true,
269 _ => {
271 debug_assert!(false);
272 false
273 }
274 }
275 }
276}
277
278#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
290pub struct AddProposal {
291 pub(crate) key_package: KeyPackage,
292}
293
294impl AddProposal {
295 pub fn key_package(&self) -> &KeyPackage {
297 &self.key_package
298 }
299}
300
301#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
314pub struct UpdateProposal {
315 pub(crate) leaf_node: LeafNode,
316}
317
318impl UpdateProposal {
319 pub fn leaf_node(&self) -> &LeafNode {
321 &self.leaf_node
322 }
323}
324
325#[derive(
337 Debug,
338 PartialEq,
339 Eq,
340 Clone,
341 Serialize,
342 Deserialize,
343 TlsDeserialize,
344 TlsDeserializeBytes,
345 TlsSerialize,
346 TlsSize,
347)]
348pub struct RemoveProposal {
349 pub(crate) removed: LeafNodeIndex,
350}
351
352impl RemoveProposal {
353 pub fn removed(&self) -> LeafNodeIndex {
355 self.removed
356 }
357}
358
359#[derive(
371 Debug,
372 PartialEq,
373 Eq,
374 Clone,
375 Serialize,
376 Deserialize,
377 TlsDeserialize,
378 TlsDeserializeBytes,
379 TlsSerialize,
380 TlsSize,
381)]
382pub struct PreSharedKeyProposal {
383 psk: PreSharedKeyId,
384}
385
386impl PreSharedKeyProposal {
387 pub(crate) fn into_psk_id(self) -> PreSharedKeyId {
389 self.psk
390 }
391}
392
393impl PreSharedKeyProposal {
394 pub fn new(psk: PreSharedKeyId) -> Self {
396 Self { psk }
397 }
398}
399
400#[derive(
417 Debug,
418 PartialEq,
419 Eq,
420 Clone,
421 Serialize,
422 Deserialize,
423 TlsDeserialize,
424 TlsDeserializeBytes,
425 TlsSerialize,
426 TlsSize,
427)]
428pub struct ReInitProposal {
429 pub(crate) group_id: GroupId,
430 pub(crate) version: ProtocolVersion,
431 pub(crate) ciphersuite: Ciphersuite,
432 pub(crate) extensions: Extensions,
433}
434
435#[derive(
447 Debug,
448 PartialEq,
449 Eq,
450 Clone,
451 Serialize,
452 Deserialize,
453 TlsDeserialize,
454 TlsDeserializeBytes,
455 TlsSerialize,
456 TlsSize,
457)]
458pub struct ExternalInitProposal {
459 kem_output: VLBytes,
460}
461
462impl ExternalInitProposal {
463 pub(crate) fn kem_output(&self) -> &[u8] {
465 self.kem_output.as_slice()
466 }
467}
468
469impl From<Vec<u8>> for ExternalInitProposal {
470 fn from(kem_output: Vec<u8>) -> Self {
471 ExternalInitProposal {
472 kem_output: kem_output.into(),
473 }
474 }
475}
476
477#[derive(
481 Debug,
482 PartialEq,
483 Clone,
484 Serialize,
485 Deserialize,
486 TlsDeserialize,
487 TlsDeserializeBytes,
488 TlsSerialize,
489 TlsSize,
490)]
491pub struct AppAckProposal {
492 received_ranges: Vec<MessageRange>,
493}
494
495#[derive(
507 Debug,
508 PartialEq,
509 Eq,
510 Clone,
511 Serialize,
512 Deserialize,
513 TlsDeserialize,
514 TlsDeserializeBytes,
515 TlsSerialize,
516 TlsSize,
517)]
518pub struct GroupContextExtensionProposal {
519 extensions: Extensions,
520}
521
522impl GroupContextExtensionProposal {
523 pub(crate) fn new(extensions: Extensions) -> Self {
525 Self { extensions }
526 }
527
528 pub fn extensions(&self) -> &Extensions {
530 &self.extensions
531 }
532}
533
534#[derive(
557 PartialEq,
558 Clone,
559 Copy,
560 Debug,
561 TlsSerialize,
562 TlsDeserialize,
563 TlsDeserializeBytes,
564 TlsSize,
565 Serialize,
566 Deserialize,
567)]
568#[repr(u8)]
569pub enum ProposalOrRefType {
570 Proposal = 1,
572 Reference = 2,
574}
575
576#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
578#[repr(u8)]
579#[allow(missing_docs)]
580#[allow(clippy::large_enum_variant)]
581pub(crate) enum ProposalOrRef {
582 #[tls_codec(discriminant = 1)]
583 Proposal(Proposal),
584 Reference(ProposalRef),
585}
586
587#[derive(Error, Debug)]
588pub(crate) enum ProposalRefError {
589 #[error("Expected `Proposal`, got `{wrong:?}`.")]
590 AuthenticatedContentHasWrongType { wrong: ContentType },
591 #[error(transparent)]
592 Other(#[from] LibraryError),
593}
594
595impl ProposalRef {
596 pub(crate) fn from_authenticated_content_by_ref(
597 crypto: &impl OpenMlsCrypto,
598 ciphersuite: Ciphersuite,
599 authenticated_content: &AuthenticatedContent,
600 ) -> Result<Self, ProposalRefError> {
601 if !matches!(
602 authenticated_content.content(),
603 FramedContentBody::Proposal(_)
604 ) {
605 return Err(ProposalRefError::AuthenticatedContentHasWrongType {
606 wrong: authenticated_content.content().content_type(),
607 });
608 };
609
610 let encoded = authenticated_content
611 .tls_serialize_detached()
612 .map_err(|error| ProposalRefError::Other(LibraryError::missing_bound_check(error)))?;
613
614 make_proposal_ref(&encoded, ciphersuite, crypto)
615 .map_err(|error| ProposalRefError::Other(LibraryError::unexpected_crypto_error(error)))
616 }
617
618 pub(crate) fn from_raw_proposal(
624 ciphersuite: Ciphersuite,
625 crypto: &impl OpenMlsCrypto,
626 proposal: &Proposal,
627 ) -> Result<Self, LibraryError> {
628 let mut data = b"Internal OpenMLS ProposalRef Label".to_vec();
630
631 let mut encoded = proposal
632 .tls_serialize_detached()
633 .map_err(LibraryError::missing_bound_check)?;
634
635 data.append(&mut encoded);
636
637 make_proposal_ref(&data, ciphersuite, crypto).map_err(LibraryError::unexpected_crypto_error)
638 }
639}
640
641#[derive(
649 Debug,
650 PartialEq,
651 Clone,
652 Serialize,
653 Deserialize,
654 TlsDeserialize,
655 TlsDeserializeBytes,
656 TlsSerialize,
657 TlsSize,
658)]
659pub(crate) struct MessageRange {
660 sender: KeyPackageRef,
661 first_generation: u32,
662 last_generation: u32,
663}
664
665#[derive(
667 Debug,
668 PartialEq,
669 Clone,
670 Serialize,
671 Deserialize,
672 TlsSize,
673 TlsSerialize,
674 TlsDeserialize,
675 TlsDeserializeBytes,
676)]
677pub struct CustomProposal {
678 proposal_type: u16,
679 payload: Vec<u8>,
680}
681
682impl CustomProposal {
683 pub fn new(proposal_type: u16, payload: Vec<u8>) -> Self {
685 Self {
686 proposal_type,
687 payload,
688 }
689 }
690
691 pub fn proposal_type(&self) -> u16 {
693 self.proposal_type
694 }
695
696 pub fn payload(&self) -> &[u8] {
698 &self.payload
699 }
700}
701
#[cfg(test)]
mod tests {
    use tls_codec::{Deserialize, Serialize};

    use super::ProposalType;

    /// Code points without a named variant must decode to `Custom` and
    /// re-encode to the same two bytes.
    #[test]
    fn that_unknown_proposal_types_are_de_serialized_correctly() {
        let proposal_types = [0x0000u16, 0x0A0A, 0x7A7A, 0xF000, 0xFFFF];

        for proposal_type in proposal_types.into_iter() {
            let test = proposal_type.to_be_bytes().to_vec();

            let got = ProposalType::tls_deserialize_exact(&test).unwrap();

            match got {
                ProposalType::Custom(got_proposal_type) => {
                    assert_eq!(proposal_type, got_proposal_type);
                }
                other => panic!("Expected `ProposalType::Custom`, got `{other:?}`."),
            }

            let got_serialized = got.tls_serialize_detached().unwrap();
            assert_eq!(test, got_serialized);
        }
    }

    /// Every named code point must decode to its variant and re-encode
    /// to the same two bytes.
    #[test]
    fn that_known_proposal_types_are_de_serialized_correctly() {
        let known = [
            (0x0001u16, ProposalType::Add),
            (0x0002, ProposalType::Update),
            (0x0003, ProposalType::Remove),
            (0x0004, ProposalType::PreSharedKey),
            (0x0005, ProposalType::Reinit),
            (0x0006, ProposalType::ExternalInit),
            (0x0007, ProposalType::GroupContextExtensions),
            (0x000a, ProposalType::SelfRemove),
            (0x000b, ProposalType::AppAck),
        ];

        for (code_point, expected) in known {
            let encoded = code_point.to_be_bytes().to_vec();

            let got = ProposalType::tls_deserialize_exact(&encoded).unwrap();
            assert_eq!(expected, got);

            let got_serialized = got.tls_serialize_detached().unwrap();
            assert_eq!(encoded, got_serialized);
        }
    }

    /// Converting any `u16` to `ProposalType` and back must be lossless.
    #[test]
    fn that_u16_conversion_round_trips() {
        for code_point in 0..=u16::MAX {
            assert_eq!(code_point, u16::from(ProposalType::from(code_point)));
        }
    }
}