1use std::mem;
4
5use errors::{CommitToPendingProposalsError, MergePendingCommitError};
6use openmls_traits::{crypto::OpenMlsCrypto, signatures::Signer, storage::StorageProvider as _};
7
8use crate::{
9 framing::mls_content::FramedContentBody,
10 group::{errors::MergeCommitError, StageCommitError, ValidationError},
11 messages::group_info::GroupInfo,
12 storage::OpenMlsProvider,
13 tree::sender_ratchet::SenderRatchetConfiguration,
14};
15
16#[cfg(feature = "extensions-draft-08")]
17use crate::{
18 component::{ComponentData, ComponentId},
19 extensions::AppDataDictionary,
20 messages::proposals_in::{ProposalIn, ProposalOrRefIn},
21};
22
23#[cfg(feature = "extensions-draft-08")]
24use std::collections::BTreeMap;
25
26use super::{errors::ProcessMessageError, *};
27
#[cfg(feature = "extensions-draft-08")]
/// Records pending changes to an application-data dictionary.
///
/// The updater keeps an optional read-only view of the dictionary that the
/// changes are relative to, so callers can look up existing values with
/// `old_value`. Every `set`/`remove` call is recorded as a pending change,
/// and the accumulated changes are taken out with `changes`.
pub struct AppDataDictionaryUpdater<'a> {
    /// Changes recorded so far; stays `None` until the first `set` or `remove`.
    new_entries: Option<AppDataUpdates>,
    /// The dictionary the recorded changes are relative to, if there is one.
    old_dict: Option<&'a AppDataDictionary>,
}
34
#[cfg(feature = "extensions-draft-08")]
/// A batch of pending application-data dictionary changes, keyed by component.
///
/// A `Some(bytes)` value holds the new value recorded for that component; a
/// `None` value records that the component's entry is to be removed. Because
/// the entries live in a `BTreeMap`, iteration yields them in ascending
/// `ComponentId` order.
#[derive(Debug, Default)]
pub struct AppDataUpdates(BTreeMap<ComponentId, Option<Vec<u8>>>);
41
#[cfg(feature = "extensions-draft-08")]
/// Consumes the updates, yielding `(component id, pending value)` pairs in
/// ascending component-id order; a `None` value marks a removal.
impl IntoIterator for AppDataUpdates {
    type Item = (ComponentId, Option<Vec<u8>>);

    // Same type as `<BTreeMap<ComponentId, Option<Vec<u8>>> as IntoIterator>::IntoIter`.
    type IntoIter = std::collections::btree_map::IntoIter<ComponentId, Option<Vec<u8>>>;

    fn into_iter(self) -> Self::IntoIter {
        let Self(entries) = self;
        entries.into_iter()
    }
}
52
#[cfg(feature = "extensions-draft-08")]
impl AppDataUpdates {
    /// Returns the number of components that have a pending change
    /// (either a new value or a removal).
    pub fn len(&self) -> usize {
        let Self(entries) = self;
        entries.len()
    }

    /// Returns `true` when no component has a pending change.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
65
#[cfg(feature = "extensions-draft-08")]
impl<'a> AppDataDictionaryUpdater<'a> {
    /// Creates an updater relative to `old_dict` with no changes recorded yet.
    pub fn new(old_dict: Option<&'a AppDataDictionary>) -> Self {
        Self {
            new_entries: None,
            old_dict,
        }
    }

    /// Looks up `component_id` in the dictionary this updater was created
    /// from. Pending changes are NOT consulted.
    ///
    /// Returns `None` when there is no such dictionary or it has no entry for
    /// the component.
    pub fn old_value(&self, component_id: ComponentId) -> Option<&[u8]> {
        self.old_dict.and_then(|dict| dict.get(&component_id))
    }

    /// Returns the pending-change map, creating an empty one on first use.
    fn new_entries_mut(&mut self) -> &mut AppDataUpdates {
        self.new_entries.get_or_insert_with(AppDataUpdates::default)
    }

    /// Records `component_data` as the new value for its component,
    /// replacing any change previously recorded for that component.
    pub fn set(&mut self, component_data: ComponentData) {
        let (component_id, payload) = component_data.into_parts();
        let bytes: Vec<u8> = payload.into();
        self.new_entries_mut().0.insert(component_id, Some(bytes));
    }

    /// Records the removal of component `id`, replacing any change previously
    /// recorded for it.
    pub fn remove(&mut self, id: &ComponentId) {
        let pending = self.new_entries_mut();
        pending.0.insert(*id, None);
    }

    /// Consumes the updater and returns the recorded changes.
    ///
    /// Returns `None` if neither `set` nor `remove` was ever called.
    pub fn changes(self) -> Option<AppDataUpdates> {
        let Self { new_entries, .. } = self;
        new_entries
    }
}
108
109impl MlsGroup {
110 pub fn process_message<Provider: OpenMlsProvider>(
120 &mut self,
121 provider: &Provider,
122 message: impl Into<ProtocolMessage>,
123 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
124 let unverified_message = self.unprotect_message(provider, message)?;
125
126 #[cfg(feature = "extensions-draft-08")]
129 if let Some(proposals) = unverified_message.committed_proposals() {
130 for proposal_or_ref in proposals {
131 if let ProposalOrRefIn::Proposal(proposal) = proposal_or_ref {
132 if matches!(proposal.as_ref(), ProposalIn::AppDataUpdate(_)) {
133 return Err(ProcessMessageError::FoundAppDataUpdateProposal);
134 }
135 }
136 }
137 }
138 self.process_unverified_message(provider, unverified_message)
139 }
140
141 #[cfg(feature = "extensions-draft-08")]
142 pub fn app_data_dictionary_updater<'a>(&'a self) -> AppDataDictionaryUpdater<'a> {
144 AppDataDictionaryUpdater::new(self.context().app_data_dict())
145 }
146
147 pub fn unprotect_message<Provider: OpenMlsProvider>(
150 &mut self,
151 provider: &Provider,
152 message: impl Into<ProtocolMessage>,
153 ) -> Result<UnverifiedMessage, ProcessMessageError<Provider::StorageError>> {
154 if !self.is_active() {
156 return Err(ProcessMessageError::GroupStateError(
157 MlsGroupStateError::UseAfterEviction,
158 ));
159 }
160 let message = message.into();
161
162 if !message.is_external()
164 && message.is_handshake_message()
165 && !self
166 .configuration()
167 .wire_format_policy()
168 .incoming()
169 .is_compatible_with(message.wire_format())
170 {
171 return Err(ProcessMessageError::IncompatibleWireFormat);
172 }
173
174 let sender_ratchet_configuration = *self.configuration().sender_ratchet_configuration();
176
177 let will_modify_secret_tree = matches!(message, ProtocolMessage::PrivateMessage(_));
180
181 let decrypted_message =
187 self.decrypt_message(provider.crypto(), message, &sender_ratchet_configuration)?;
188
189 if will_modify_secret_tree {
191 provider
192 .storage()
193 .write_message_secrets(self.group_id(), &self.message_secrets_store)
194 .map_err(ProcessMessageError::StorageError)?;
195 }
196
197 let unverified_message = self
198 .public_group
199 .parse_message(decrypted_message, &self.message_secrets_store)
200 .map_err(ProcessMessageError::from)?;
201
202 Ok(unverified_message)
203 }
204
205 pub fn store_pending_proposal<Storage: StorageProvider>(
207 &mut self,
208 storage: &Storage,
209 proposal: QueuedProposal,
210 ) -> Result<(), Storage::Error> {
211 storage.queue_proposal(self.group_id(), &proposal.proposal_reference(), &proposal)?;
212 self.proposal_store_mut().add(proposal);
214
215 Ok(())
216 }
217
218 pub fn has_pending_proposals(&self) -> bool {
220 !self.proposal_store().is_empty()
221 }
222
223 #[allow(clippy::type_complexity)]
234 pub fn commit_to_pending_proposals<Provider: OpenMlsProvider>(
235 &mut self,
236 provider: &Provider,
237 signer: &impl Signer,
238 ) -> Result<
239 (MlsMessageOut, Option<MlsMessageOut>, Option<GroupInfo>),
240 CommitToPendingProposalsError<Provider::StorageError>,
241 > {
242 self.is_operational()?;
243
244 let (commit, welcome, group_info) = self
247 .commit_builder()
248 .consume_proposal_store(true)
250 .load_psks(provider.storage())?
251 .build(provider.rand(), provider.crypto(), signer, |_| true)?
252 .stage_commit(provider)?
253 .into_contents();
254
255 Ok((
256 commit,
257 welcome.map(|welcome| MlsMessageOut::from_welcome(welcome, self.version())),
259 group_info,
260 ))
261 }
262
263 pub fn merge_staged_commit<Provider: OpenMlsProvider>(
266 &mut self,
267 provider: &Provider,
268 staged_commit: StagedCommit,
269 ) -> Result<(), MergeCommitError<Provider::StorageError>> {
270 if staged_commit.self_removed() {
272 self.group_state = MlsGroupState::Inactive;
273 }
274 provider
275 .storage()
276 .write_group_state(self.group_id(), &self.group_state)
277 .map_err(MergeCommitError::StorageError)?;
278
279 self.merge_commit(provider, staged_commit)?;
281
282 let resumption_psk = self.group_epoch_secrets().resumption_psk();
284 self.resumption_psk_store
285 .add(self.context().epoch(), resumption_psk.clone());
286 provider
287 .storage()
288 .write_resumption_psk_store(self.group_id(), &self.resumption_psk_store)
289 .map_err(MergeCommitError::StorageError)?;
290
291 self.own_leaf_nodes.clear();
293 provider
294 .storage()
295 .delete_own_leaf_nodes(self.group_id())
296 .map_err(MergeCommitError::StorageError)?;
297
298 self.clear_pending_commit(provider.storage())
300 .map_err(MergeCommitError::StorageError)?;
301
302 Ok(())
303 }
304
305 pub fn merge_pending_commit<Provider: OpenMlsProvider>(
308 &mut self,
309 provider: &Provider,
310 ) -> Result<(), MergePendingCommitError<Provider::StorageError>> {
311 match &self.group_state {
312 MlsGroupState::PendingCommit(_) => {
313 let old_state = mem::replace(&mut self.group_state, MlsGroupState::Operational);
314 if let MlsGroupState::PendingCommit(pending_commit_state) = old_state {
315 self.merge_staged_commit(provider, (*pending_commit_state).into())?;
316 }
317 Ok(())
318 }
319 MlsGroupState::Inactive => Err(MlsGroupStateError::UseAfterEviction)?,
320 MlsGroupState::Operational => Ok(()),
321 }
322 }
323
324 pub(super) fn read_decryption_keypairs(
326 &self,
327 provider: &impl OpenMlsProvider,
328 own_leaf_nodes: &[LeafNode],
329 ) -> Result<(Vec<EncryptionKeyPair>, Vec<EncryptionKeyPair>), StageCommitError> {
330 let old_epoch_keypairs = self.read_epoch_keypairs(provider.storage()).map_err(|e| {
332 log::error!("Error reading epoch keypairs: {e:?}");
333 StageCommitError::MissingDecryptionKey
334 })?;
335
336 let leaf_node_keypairs = own_leaf_nodes
340 .iter()
341 .map(|leaf_node| {
342 EncryptionKeyPair::read(provider, leaf_node.encryption_key())
343 .ok_or(StageCommitError::MissingDecryptionKey)
344 })
345 .collect::<Result<Vec<EncryptionKeyPair>, StageCommitError>>()?;
346
347 Ok((old_epoch_keypairs, leaf_node_keypairs))
348 }
349
350 #[cfg(feature = "extensions-draft-08")]
353 pub fn process_unverified_message_with_app_data_updates<Provider: OpenMlsProvider>(
354 &self,
355 provider: &Provider,
356 unverified_message: UnverifiedMessage,
357 app_data_dict_updates: Option<AppDataUpdates>,
358 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
359 let (content, credential) =
365 unverified_message.verify(self.ciphersuite(), provider.crypto(), self.version())?;
366
367 #[cfg_attr(not(feature = "extensions-draft-08"), allow(unused_mut))]
368 let mut processed = match content.sender() {
369 Sender::Member(_) | Sender::NewMemberProposal | Sender::NewMemberCommit => self
370 .process_internal_authenticated_content_with_app_data_updates(
371 provider,
372 content,
373 credential,
374 app_data_dict_updates,
375 )?,
376 Sender::External(_) => {
377 self.process_external_authenticated_content(provider, content, credential)?
378 }
379 };
380 #[cfg(feature = "extensions-draft-08")]
381 if self.context().safe_aad_required() {
382 processed
383 .try_attach_safe_aad()
384 .map_err(|_| ProcessMessageError::MalformedSafeAad)?;
385 }
386 Ok(processed)
387 }
388
389 pub(crate) fn process_unverified_message<Provider: OpenMlsProvider>(
417 &self,
418 provider: &Provider,
419 unverified_message: UnverifiedMessage,
420 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
421 let (content, credential) =
427 unverified_message.verify(self.ciphersuite(), provider.crypto(), self.version())?;
428
429 #[cfg_attr(not(feature = "extensions-draft-08"), allow(unused_mut))]
430 let mut processed = match content.sender() {
431 Sender::Member(_) | Sender::NewMemberProposal | Sender::NewMemberCommit => {
432 self.process_internal_authenticated_content(provider, content, credential)?
433 }
434 Sender::External(_) => {
435 self.process_external_authenticated_content(provider, content, credential)?
436 }
437 };
438 #[cfg(feature = "extensions-draft-08")]
439 if self.context().safe_aad_required() {
440 processed
441 .try_attach_safe_aad()
442 .map_err(|_| ProcessMessageError::MalformedSafeAad)?;
443 }
444 Ok(processed)
445 }
446
447 #[cfg(feature = "extensions-draft-08")]
471 fn process_internal_authenticated_content_with_app_data_updates<Provider: OpenMlsProvider>(
472 &self,
473 provider: &Provider,
474 content: AuthenticatedContent,
475 credential: Credential,
476 app_data_dict_updates: Option<AppDataUpdates>,
477 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
478 let sender = content.sender().clone();
479 let authenticated_data = content.authenticated_data().to_owned();
480 let epoch = content.epoch();
481
482 let content = match content.content() {
483 FramedContentBody::Application(application_message) => {
484 ProcessedMessageContent::ApplicationMessage(ApplicationMessage::new(
485 application_message.as_slice().to_owned(),
486 ))
487 }
488 FramedContentBody::Proposal(_) => {
489 let proposal = Box::new(QueuedProposal::from_authenticated_content_by_ref(
490 self.ciphersuite(),
491 provider.crypto(),
492 content,
493 )?);
494
495 if matches!(sender, Sender::NewMemberProposal) {
496 ProcessedMessageContent::ExternalJoinProposalMessage(proposal)
497 } else {
498 ProcessedMessageContent::ProposalMessage(proposal)
499 }
500 }
501 FramedContentBody::Commit(_) => {
502 let (old_epoch_keypairs, leaf_node_keypairs) =
504 self.read_decryption_keypairs(provider, &self.own_leaf_nodes)?;
505
506 let staged_commit = self.stage_commit_with_app_data_updates(
507 &content,
508 old_epoch_keypairs,
509 leaf_node_keypairs,
510 app_data_dict_updates,
511 provider,
512 )?;
513
514 ProcessedMessageContent::StagedCommitMessage(Box::new(staged_commit))
515 }
516 };
517
518 Ok(ProcessedMessage::new(
519 self.group_id().clone(),
520 epoch,
521 sender,
522 authenticated_data,
523 content,
524 credential,
525 ))
526 }
527
528 fn process_internal_authenticated_content<Provider: OpenMlsProvider>(
529 &self,
530 provider: &Provider,
531 content: AuthenticatedContent,
532 credential: Credential,
533 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
534 let sender = content.sender().clone();
535 let authenticated_data = content.authenticated_data().to_owned();
536 let epoch = content.epoch();
537
538 let content = match content.content() {
539 FramedContentBody::Application(application_message) => {
540 ProcessedMessageContent::ApplicationMessage(ApplicationMessage::new(
541 application_message.as_slice().to_owned(),
542 ))
543 }
544 FramedContentBody::Proposal(_) => {
545 let proposal = Box::new(QueuedProposal::from_authenticated_content_by_ref(
546 self.ciphersuite(),
547 provider.crypto(),
548 content,
549 )?);
550
551 if matches!(sender, Sender::NewMemberProposal) {
552 ProcessedMessageContent::ExternalJoinProposalMessage(proposal)
553 } else {
554 ProcessedMessageContent::ProposalMessage(proposal)
555 }
556 }
557 FramedContentBody::Commit(_) => {
558 let (old_epoch_keypairs, leaf_node_keypairs) =
560 self.read_decryption_keypairs(provider, &self.own_leaf_nodes)?;
561
562 let staged_commit =
563 self.stage_commit(&content, old_epoch_keypairs, leaf_node_keypairs, provider)?;
564
565 ProcessedMessageContent::StagedCommitMessage(Box::new(staged_commit))
566 }
567 };
568
569 Ok(ProcessedMessage::new(
570 self.group_id().clone(),
571 epoch,
572 sender,
573 authenticated_data,
574 content,
575 credential,
576 ))
577 }
578
579 fn process_external_authenticated_content<Provider: OpenMlsProvider>(
585 &self,
586 provider: &Provider,
587 content: AuthenticatedContent,
588 credential: Credential,
589 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
590 let sender = content.sender().clone();
591 let data = content.authenticated_data().to_owned();
592
593 debug_assert!(matches!(sender, Sender::External(_)));
594
595 match content.content() {
597 FramedContentBody::Application(_) => {
598 Err(ProcessMessageError::UnauthorizedExternalApplicationMessage)
599 }
600 FramedContentBody::Proposal(Proposal::GroupContextExtensions(_)) => {
602 let content = ProcessedMessageContent::ProposalMessage(Box::new(
603 QueuedProposal::from_authenticated_content_by_ref(
604 self.ciphersuite(),
605 provider.crypto(),
606 content,
607 )?,
608 ));
609 Ok(ProcessedMessage::new(
610 self.group_id().clone(),
611 self.context().epoch(),
612 sender,
613 data,
614 content,
615 credential,
616 ))
617 }
618
619 FramedContentBody::Proposal(Proposal::Remove(_)) => {
620 let content = ProcessedMessageContent::ProposalMessage(Box::new(
621 QueuedProposal::from_authenticated_content_by_ref(
622 self.ciphersuite(),
623 provider.crypto(),
624 content,
625 )?,
626 ));
627 Ok(ProcessedMessage::new(
628 self.group_id().clone(),
629 self.context().epoch(),
630 sender,
631 data,
632 content,
633 credential,
634 ))
635 }
636 FramedContentBody::Proposal(Proposal::Add(_)) => {
637 let content = ProcessedMessageContent::ProposalMessage(Box::new(
638 QueuedProposal::from_authenticated_content_by_ref(
639 self.ciphersuite(),
640 provider.crypto(),
641 content,
642 )?,
643 ));
644 Ok(ProcessedMessage::new(
645 self.group_id().clone(),
646 self.context().epoch(),
647 sender,
648 data,
649 content,
650 credential,
651 ))
652 }
653 FramedContentBody::Proposal(_) => Err(ProcessMessageError::UnsupportedProposalType),
655 FramedContentBody::Commit(_) => {
656 Err(ProcessMessageError::UnauthorizedExternalCommitMessage)
657 }
658 }
659 }
660
661 pub(crate) fn decrypt_message(
673 &mut self,
674 crypto: &impl OpenMlsCrypto,
675 message: ProtocolMessage,
676 sender_ratchet_configuration: &SenderRatchetConfiguration,
677 ) -> Result<DecryptedMessage, ValidationError> {
678 self.public_group.validate_framing(&message)?;
682
683 let epoch = message.epoch();
684
685 match message {
689 ProtocolMessage::PublicMessage(public_message) => {
690 let message_secrets =
692 self.message_secrets_for_epoch(epoch).map_err(|e| match e {
693 SecretTreeError::TooDistantInThePast => ValidationError::NoPastEpochData,
694 _ => LibraryError::custom(
695 "Unexpected error while retrieving message secrets for epoch.",
696 )
697 .into(),
698 })?;
699 DecryptedMessage::from_inbound_public_message(
700 *public_message,
701 message_secrets,
702 message_secrets.serialized_context().to_vec(),
703 crypto,
704 self.ciphersuite(),
705 )
706 }
707 ProtocolMessage::PrivateMessage(ciphertext) => {
708 DecryptedMessage::from_inbound_ciphertext(
710 ciphertext,
711 crypto,
712 self,
713 sender_ratchet_configuration,
714 )
715 }
716 }
717 }
718}