1use std::collections::{hash_map::Entry, HashMap, HashSet};
2
3use openmls_traits::crypto::OpenMlsCrypto;
4use openmls_traits::types::Ciphersuite;
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 binary_tree::array_representation::LeafNodeIndex,
9 ciphersuite::hash_ref::ProposalRef,
10 error::LibraryError,
11 framing::{mls_auth_content::AuthenticatedContent, mls_content::FramedContentBody, Sender},
12 group::errors::*,
13 messages::proposals::{
14 AddProposal, PreSharedKeyProposal, Proposal, ProposalOrRef, ProposalOrRefType,
15 ProposalType, RemoveProposal, UpdateProposal,
16 },
17 utils::vector_converter,
18};
19
20#[derive(Debug, Clone)]
21pub(crate) struct SelfRemoveInStore {
22 pub(crate) sender: LeafNodeIndex,
23 pub(crate) proposal_ref: ProposalRef,
24}
25
26#[derive(Debug, Default, Serialize, Deserialize, PartialEq)]
29#[cfg_attr(any(test, feature = "test-utils"), derive(Clone))]
30pub struct ProposalStore {
31 queued_proposals: Vec<QueuedProposal>,
32}
33
34impl ProposalStore {
35 pub fn new() -> Self {
37 Self {
38 queued_proposals: Vec::new(),
39 }
40 }
41 #[cfg(test)]
42 pub(crate) fn from_queued_proposal(queued_proposal: QueuedProposal) -> Self {
43 Self {
44 queued_proposals: vec![queued_proposal],
45 }
46 }
47 pub(crate) fn add(&mut self, queued_proposal: QueuedProposal) {
48 self.queued_proposals.push(queued_proposal);
49 }
50 pub(crate) fn proposals(&self) -> impl Iterator<Item = &QueuedProposal> {
51 self.queued_proposals.iter()
52 }
53 pub(crate) fn is_empty(&self) -> bool {
54 self.queued_proposals.is_empty()
55 }
56 pub(crate) fn empty(&mut self) {
57 self.queued_proposals.clear();
58 }
59
60 pub(crate) fn remove(&mut self, proposal_ref: &ProposalRef) -> Option<()> {
63 let index = self
64 .queued_proposals
65 .iter()
66 .position(|p| &p.proposal_reference() == proposal_ref)?;
67 self.queued_proposals.remove(index);
68 Some(())
69 }
70
71 pub(crate) fn self_removes(&self) -> Vec<SelfRemoveInStore> {
72 self.queued_proposals
73 .iter()
74 .filter_map(|queued_proposal| {
75 match (queued_proposal.proposal(), queued_proposal.sender()) {
76 (Proposal::SelfRemove, Sender::Member(sender_index)) => {
77 Some(SelfRemoveInStore {
78 sender: *sender_index,
79 proposal_ref: queued_proposal.proposal_reference(),
80 })
81 }
82 _ => None,
83 }
84 })
85 .collect()
86 }
87}
88
89#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
92pub struct QueuedProposal {
93 proposal: Proposal,
94 proposal_reference: ProposalRef,
95 sender: Sender,
96 proposal_or_ref_type: ProposalOrRefType,
97}
98
99impl QueuedProposal {
100 pub(crate) fn from_authenticated_content_by_ref(
102 ciphersuite: Ciphersuite,
103 crypto: &impl OpenMlsCrypto,
104 public_message: AuthenticatedContent,
105 ) -> Result<Self, LibraryError> {
106 Self::from_authenticated_content(
107 ciphersuite,
108 crypto,
109 public_message,
110 ProposalOrRefType::Reference,
111 )
112 }
113
114 pub(crate) fn from_authenticated_content(
116 ciphersuite: Ciphersuite,
117 crypto: &impl OpenMlsCrypto,
118 public_message: AuthenticatedContent,
119 proposal_or_ref_type: ProposalOrRefType,
120 ) -> Result<Self, LibraryError> {
121 let proposal_reference =
122 ProposalRef::from_authenticated_content_by_ref(crypto, ciphersuite, &public_message)
123 .map_err(|_| LibraryError::custom("Could not calculate `ProposalRef`."))?;
124
125 let (body, sender) = public_message.into_body_and_sender();
126
127 let proposal = match body {
128 FramedContentBody::Proposal(p) => p,
129 _ => return Err(LibraryError::custom("Wrong content type")),
130 };
131
132 Ok(Self {
133 proposal,
134 proposal_reference,
135 sender,
136 proposal_or_ref_type,
137 })
138 }
139
140 pub(crate) fn from_proposal_and_sender(
146 ciphersuite: Ciphersuite,
147 crypto: &impl OpenMlsCrypto,
148 proposal: Proposal,
149 sender: &Sender,
150 ) -> Result<Self, LibraryError> {
151 let proposal_reference = ProposalRef::from_raw_proposal(ciphersuite, crypto, &proposal)?;
152 Ok(Self {
153 proposal,
154 proposal_reference,
155 sender: sender.clone(),
156 proposal_or_ref_type: ProposalOrRefType::Proposal,
157 })
158 }
159
160 pub fn proposal(&self) -> &Proposal {
162 &self.proposal
163 }
164 pub(crate) fn proposal_reference(&self) -> ProposalRef {
166 self.proposal_reference.clone()
167 }
168
169 pub(crate) fn proposal_reference_ref(&self) -> &ProposalRef {
171 &self.proposal_reference
172 }
173
174 pub fn proposal_or_ref_type(&self) -> ProposalOrRefType {
176 self.proposal_or_ref_type
177 }
178 pub fn sender(&self) -> &Sender {
180 &self.sender
181 }
182}
183
184struct OrderedProposalRefs {
187 proposal_refs: HashSet<ProposalRef>,
188 ordered_proposal_refs: Vec<ProposalRef>,
189}
190
191impl OrderedProposalRefs {
192 fn new() -> Self {
193 Self {
194 proposal_refs: HashSet::new(),
195 ordered_proposal_refs: Vec::new(),
196 }
197 }
198
199 fn add(&mut self, proposal_ref: ProposalRef) {
202 if self.proposal_refs.insert(proposal_ref.clone()) {
205 self.ordered_proposal_refs.push(proposal_ref);
206 }
207 }
208
209 fn iter(&self) -> impl Iterator<Item = &ProposalRef> {
212 self.ordered_proposal_refs.iter()
213 }
214}
215
216#[derive(Default, Debug, Serialize, Deserialize)]
222#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
223pub(crate) struct ProposalQueue {
224 proposal_references: Vec<ProposalRef>,
227 #[serde(with = "vector_converter")]
230 queued_proposals: HashMap<ProposalRef, QueuedProposal>,
231}
232
233impl ProposalQueue {
234 pub(crate) fn is_empty(&self) -> bool {
237 self.proposal_references.is_empty()
238 }
239
240 pub(crate) fn from_committed_proposals(
245 ciphersuite: Ciphersuite,
246 crypto: &impl OpenMlsCrypto,
247 committed_proposals: Vec<ProposalOrRef>,
248 proposal_store: &ProposalStore,
249 sender: &Sender,
250 ) -> Result<Self, FromCommittedProposalsError> {
251 log::debug!("from_committed_proposals");
252 let mut proposals_by_reference_queue: HashMap<ProposalRef, QueuedProposal> = HashMap::new();
255 for queued_proposal in proposal_store.proposals() {
256 proposals_by_reference_queue.insert(
257 queued_proposal.proposal_reference(),
258 queued_proposal.clone(),
259 );
260 }
261 log::trace!(" known proposals:\n{proposals_by_reference_queue:#?}");
262 let mut proposal_queue = ProposalQueue::default();
264
265 log::trace!(" committed proposals ...");
267 for proposal_or_ref in committed_proposals.into_iter() {
268 log::trace!(" proposal_or_ref:\n{proposal_or_ref:#?}");
269 let queued_proposal = match proposal_or_ref {
270 ProposalOrRef::Proposal(proposal) => {
271 if proposal
273 .as_remove()
274 .and_then(|remove_proposal| {
275 sender.as_member().filter(|leaf_index| {
276 remove_proposal.removed() == *leaf_index
278 })
279 })
280 .is_some()
281 {
282 return Err(FromCommittedProposalsError::SelfRemoval);
283 };
284
285 QueuedProposal::from_proposal_and_sender(
286 ciphersuite,
287 crypto,
288 *proposal,
289 sender,
290 )?
291 }
292 ProposalOrRef::Reference(ref proposal_reference) => {
293 match proposals_by_reference_queue.get(proposal_reference) {
294 Some(queued_proposal) => {
295 if let Proposal::Remove(ref remove_proposal) = queued_proposal.proposal
297 {
298 if let Sender::Member(leaf_index) = sender {
299 if remove_proposal.removed() == *leaf_index {
300 return Err(FromCommittedProposalsError::SelfRemoval);
301 }
302 }
303 }
304
305 queued_proposal.clone()
306 }
307 None => return Err(FromCommittedProposalsError::ProposalNotFound),
308 }
309 }
310 };
311 proposal_queue.add(queued_proposal);
312 }
313
314 Ok(proposal_queue)
315 }
316
317 pub fn get(&self, proposal_reference: &ProposalRef) -> Option<&QueuedProposal> {
319 self.queued_proposals.get(proposal_reference)
320 }
321
322 pub(crate) fn add(&mut self, queued_proposal: QueuedProposal) {
324 let proposal_reference = queued_proposal.proposal_reference();
325 if let Entry::Vacant(entry) = self.queued_proposals.entry(proposal_reference.clone()) {
327 self.proposal_references.push(proposal_reference);
329 entry.insert(queued_proposal);
331 }
332 }
333
334 pub(crate) fn filtered_by_type(
337 &self,
338 proposal_type: ProposalType,
339 ) -> impl Iterator<Item = &QueuedProposal> {
340 self.proposal_references
342 .iter()
343 .filter(move |&pr| match self.queued_proposals.get(pr) {
344 Some(p) => p.proposal.is_type(proposal_type),
345 None => false,
346 })
347 .filter_map(move |reference| self.get(reference))
348 }
349
350 pub(crate) fn queued_proposals(&self) -> impl Iterator<Item = &QueuedProposal> {
353 self.proposal_references
355 .iter()
356 .filter_map(move |reference| self.get(reference))
357 }
358
359 pub(crate) fn add_proposals(&self) -> impl Iterator<Item = QueuedAddProposal<'_>> {
362 self.queued_proposals().filter_map(|queued_proposal| {
363 if let Proposal::Add(add_proposal) = queued_proposal.proposal() {
364 let sender = queued_proposal.sender();
365 Some(QueuedAddProposal {
366 add_proposal,
367 sender,
368 })
369 } else {
370 None
371 }
372 })
373 }
374
375 pub(crate) fn remove_proposals(&self) -> impl Iterator<Item = QueuedRemoveProposal<'_>> {
378 self.queued_proposals().filter_map(|queued_proposal| {
379 if let Proposal::Remove(remove_proposal) = queued_proposal.proposal() {
380 let sender = queued_proposal.sender();
381 Some(QueuedRemoveProposal {
382 remove_proposal,
383 sender,
384 })
385 } else {
386 None
387 }
388 })
389 }
390
391 pub(crate) fn update_proposals(&self) -> impl Iterator<Item = QueuedUpdateProposal<'_>> {
394 self.queued_proposals().filter_map(|queued_proposal| {
395 if let Proposal::Update(update_proposal) = queued_proposal.proposal() {
396 let sender = queued_proposal.sender();
397 Some(QueuedUpdateProposal {
398 update_proposal,
399 sender,
400 })
401 } else {
402 None
403 }
404 })
405 }
406
407 pub(crate) fn psk_proposals(&self) -> impl Iterator<Item = QueuedPskProposal<'_>> {
410 self.queued_proposals().filter_map(|queued_proposal| {
411 if let Proposal::PreSharedKey(psk_proposal) = queued_proposal.proposal() {
412 let sender = queued_proposal.sender();
413 Some(QueuedPskProposal {
414 psk_proposal,
415 sender,
416 })
417 } else {
418 None
419 }
420 })
421 }
422
423 pub(crate) fn filter_proposals(
446 iter: impl IntoIterator<Item = QueuedProposal>,
447 own_index: LeafNodeIndex,
448 ) -> Result<(Self, bool), ProposalQueueError> {
449 let mut adds: OrderedProposalRefs = OrderedProposalRefs::new();
452 let mut valid_proposals: OrderedProposalRefs = OrderedProposalRefs::new();
453 let mut proposal_pool: HashMap<ProposalRef, QueuedProposal> = HashMap::new();
454 let mut contains_own_updates = false;
455 let mut contains_external_init = false;
456
457 let mut member_specific_proposals: HashMap<LeafNodeIndex, QueuedProposal> = HashMap::new();
458 let mut register_member_specific_proposal =
459 |member: LeafNodeIndex, proposal: QueuedProposal| {
460 match member_specific_proposals.entry(member) {
462 Entry::Vacant(vacant_entry) => {
464 vacant_entry.insert(proposal);
465 }
466 Entry::Occupied(mut occupied_entry)
469 if occupied_entry
470 .get()
471 .proposal()
472 .has_lower_priority_than(&proposal.proposal) =>
473 {
474 occupied_entry.insert(proposal);
475 }
476 Entry::Occupied(_) => {}
478 }
479 };
480
481 for queued_proposal in iter {
483 proposal_pool.insert(
484 queued_proposal.proposal_reference(),
485 queued_proposal.clone(),
486 );
487 match queued_proposal.proposal {
488 Proposal::Add(_) => {
489 adds.add(queued_proposal.proposal_reference());
490 }
491 Proposal::Update(_) => {
492 let Sender::Member(sender_index) = queued_proposal.sender() else {
495 return Err(ProposalQueueError::UpdateFromExternalSender);
496 };
497 if sender_index == &own_index {
498 contains_own_updates = true;
499 continue;
500 }
501 register_member_specific_proposal(*sender_index, queued_proposal);
502 }
503 Proposal::Remove(ref remove_proposal) => {
504 let removed = remove_proposal.removed();
505 register_member_specific_proposal(removed, queued_proposal);
506 }
507 Proposal::PreSharedKey(_) => {
508 valid_proposals.add(queued_proposal.proposal_reference());
509 }
510 Proposal::ReInit(_) => {
511 }
513 Proposal::ExternalInit(_) => {
514 if !contains_external_init {
516 valid_proposals.add(queued_proposal.proposal_reference());
517 contains_external_init = true;
518 }
519 }
520 Proposal::GroupContextExtensions(_) => {
521 valid_proposals.add(queued_proposal.proposal_reference());
522 }
523 Proposal::AppAck(_) => unimplemented!("See #291"),
524 Proposal::SelfRemove => {
525 let Sender::Member(removed) = queued_proposal.sender() else {
526 return Err(ProposalQueueError::SelfRemoveFromNonMember);
527 };
528 register_member_specific_proposal(*removed, queued_proposal);
529 }
530 Proposal::Custom(_) => {
531 valid_proposals.add(queued_proposal.proposal_reference());
534 }
535 }
536 }
537
538 for proposal in member_specific_proposals.values() {
540 valid_proposals.add(proposal.proposal_reference());
541 }
542
543 let mut proposal_queue = ProposalQueue::default();
545 for proposal_reference in adds.iter().chain(valid_proposals.iter()) {
546 let queued_proposal = proposal_pool
547 .get(proposal_reference)
548 .cloned()
549 .ok_or(ProposalQueueError::ProposalNotFound)?;
550 proposal_queue.add(queued_proposal);
551 }
552 Ok((proposal_queue, contains_own_updates))
553 }
554
555 #[cfg(test)]
558 pub(crate) fn contains(&self, proposal_reference_list: &[ProposalRef]) -> bool {
559 for proposal_reference in proposal_reference_list {
560 if !self.queued_proposals.contains_key(proposal_reference) {
561 return false;
562 }
563 }
564 true
565 }
566
567 pub(crate) fn commit_list(&self) -> Vec<ProposalOrRef> {
569 self.proposal_references
571 .iter()
572 .filter_map(|proposal_reference| self.queued_proposals.get(proposal_reference))
573 .map(|queued_proposal| {
574 match queued_proposal.proposal_or_ref_type {
576 ProposalOrRefType::Proposal => {
577 ProposalOrRef::proposal(queued_proposal.proposal.clone())
578 }
579 ProposalOrRefType::Reference => {
580 ProposalOrRef::reference(queued_proposal.proposal_reference.clone())
581 }
582 }
583 })
584 .collect::<Vec<ProposalOrRef>>()
585 }
586}
587
588impl Extend<QueuedProposal> for ProposalQueue {
589 fn extend<T: IntoIterator<Item = QueuedProposal>>(&mut self, iter: T) {
590 for proposal in iter {
591 self.add(proposal)
592 }
593 }
594}
595
596impl IntoIterator for ProposalQueue {
597 type Item = QueuedProposal;
598
599 type IntoIter = std::collections::hash_map::IntoValues<ProposalRef, QueuedProposal>;
600
601 fn into_iter(self) -> Self::IntoIter {
602 self.queued_proposals.into_values()
603 }
604}
605
606impl<'a> IntoIterator for &'a ProposalQueue {
607 type Item = &'a QueuedProposal;
608
609 type IntoIter = std::collections::hash_map::Values<'a, ProposalRef, QueuedProposal>;
610
611 fn into_iter(self) -> Self::IntoIter {
612 self.queued_proposals.values()
613 }
614}
615
616impl FromIterator<QueuedProposal> for ProposalQueue {
617 fn from_iter<T: IntoIterator<Item = QueuedProposal>>(iter: T) -> Self {
618 let mut out = Self::default();
619 out.extend(iter);
620 out
621 }
622}
623
624#[derive(PartialEq, Debug)]
626pub struct QueuedAddProposal<'a> {
627 add_proposal: &'a AddProposal,
628 sender: &'a Sender,
629}
630
631impl QueuedAddProposal<'_> {
632 pub fn add_proposal(&self) -> &AddProposal {
634 self.add_proposal
635 }
636
637 pub fn sender(&self) -> &Sender {
639 self.sender
640 }
641}
642
643#[derive(PartialEq, Eq, Debug)]
645pub struct QueuedRemoveProposal<'a> {
646 remove_proposal: &'a RemoveProposal,
647 sender: &'a Sender,
648}
649
650impl QueuedRemoveProposal<'_> {
651 pub fn remove_proposal(&self) -> &RemoveProposal {
653 self.remove_proposal
654 }
655
656 pub fn sender(&self) -> &Sender {
658 self.sender
659 }
660}
661
662#[derive(PartialEq, Eq, Debug)]
664pub struct QueuedUpdateProposal<'a> {
665 update_proposal: &'a UpdateProposal,
666 sender: &'a Sender,
667}
668
669impl QueuedUpdateProposal<'_> {
670 pub fn update_proposal(&self) -> &UpdateProposal {
672 self.update_proposal
673 }
674
675 pub fn sender(&self) -> &Sender {
677 self.sender
678 }
679}
680
681#[derive(PartialEq, Eq, Debug)]
683pub struct QueuedPskProposal<'a> {
684 psk_proposal: &'a PreSharedKeyProposal,
685 sender: &'a Sender,
686}
687
688impl QueuedPskProposal<'_> {
689 pub fn psk_proposal(&self) -> &PreSharedKeyProposal {
691 self.psk_proposal
692 }
693
694 pub fn sender(&self) -> &Sender {
696 self.sender
697 }
698}