1use std::mem;
4
5use errors::{CommitToPendingProposalsError, MergePendingCommitError};
6use openmls_traits::{crypto::OpenMlsCrypto, signatures::Signer, storage::StorageProvider as _};
7
8use crate::{
9 framing::mls_content::FramedContentBody,
10 group::{errors::MergeCommitError, StageCommitError, ValidationError},
11 messages::group_info::GroupInfo,
12 storage::OpenMlsProvider,
13 tree::sender_ratchet::SenderRatchetConfiguration,
14};
15
16#[cfg(feature = "extensions-draft-08")]
17use crate::{
18 component::{ComponentData, ComponentId},
19 extensions::AppDataDictionary,
20 messages::proposals_in::{ProposalIn, ProposalOrRefIn},
21};
22
23#[cfg(feature = "extensions-draft-08")]
24use std::collections::BTreeMap;
25
26use super::{errors::ProcessMessageError, *};
27
/// Collects pending changes to an app data dictionary.
///
/// The updater keeps a reference to the dictionary the changes are relative to
/// (if any), so existing values can be read with `old_value`. Additions and
/// removals are recorded with `set` / `remove` and extracted with `changes`.
#[cfg(feature = "extensions-draft-08")]
pub struct AppDataDictionaryUpdater<'a> {
    /// The dictionary the recorded changes are relative to, if there is one.
    old_dict: Option<&'a AppDataDictionary>,
    /// Recorded changes; stays `None` until the first `set` / `remove` call.
    new_entries: Option<AppDataUpdates>,
}
34
/// A set of changes to an app data dictionary, keyed by component ID.
///
/// For each component, `Some(data)` means its entry is set to `data` and
/// `None` means its entry is removed. Entries are kept in ascending
/// component-ID order.
#[cfg(feature = "extensions-draft-08")]
#[derive(Default, Debug)]
pub struct AppDataUpdates(BTreeMap<ComponentId, Option<Vec<u8>>>);
41
42#[cfg(feature = "extensions-draft-08")]
43impl IntoIterator for AppDataUpdates {
44 type Item = (ComponentId, Option<Vec<u8>>);
45
46 type IntoIter = <BTreeMap<ComponentId, Option<Vec<u8>>> as IntoIterator>::IntoIter;
47
48 fn into_iter(self) -> Self::IntoIter {
49 self.0.into_iter()
50 }
51}
52
#[cfg(feature = "extensions-draft-08")]
impl AppDataUpdates {
    /// Returns the number of components that these updates set or remove.
    pub fn len(&self) -> usize {
        let Self(entries) = self;
        entries.len()
    }

    /// Returns `true` if these updates contain no changes at all.
    pub fn is_empty(&self) -> bool {
        let Self(entries) = self;
        entries.is_empty()
    }
}
65
#[cfg(feature = "extensions-draft-08")]
impl<'a> AppDataDictionaryUpdater<'a> {
    /// Creates an updater relative to `old_dict`, with no changes recorded yet.
    pub fn new(old_dict: Option<&'a AppDataDictionary>) -> Self {
        Self {
            new_entries: None,
            old_dict,
        }
    }

    /// Returns the value stored for `component_id` in the original dictionary.
    ///
    /// Changes recorded through `set` / `remove` are not taken into account.
    /// Returns `None` if there is no original dictionary or it has no entry
    /// for `component_id`.
    pub fn old_value(&self, component_id: ComponentId) -> Option<&[u8]> {
        self.old_dict
            .and_then(|dictionary| dictionary.get(&component_id))
    }

    /// Returns the recorded changes, creating an empty change set on first use.
    fn new_entries_mut(&mut self) -> &mut AppDataUpdates {
        self.new_entries.get_or_insert_with(AppDataUpdates::default)
    }

    /// Records that the component in `component_data` is set to its data.
    ///
    /// A later `set` or `remove` for the same component overrides this change.
    pub fn set(&mut self, component_data: ComponentData) {
        let (component_id, payload) = component_data.into_parts();
        self.new_entries_mut()
            .0
            .insert(component_id, Some(payload.into()));
    }

    /// Records that the entry for `id` is removed from the dictionary.
    ///
    /// A later `set` or `remove` for the same component overrides this change.
    pub fn remove(&mut self, id: &ComponentId) {
        let component_id = *id;
        self.new_entries_mut().0.insert(component_id, None);
    }

    /// Consumes the updater and returns the recorded changes.
    ///
    /// Returns `None` if neither `set` nor `remove` was ever called.
    pub fn changes(self) -> Option<AppDataUpdates> {
        let Self { new_entries, .. } = self;
        new_entries
    }
}
108
109impl MlsGroup {
110 pub fn process_message<Provider: OpenMlsProvider>(
120 &mut self,
121 provider: &Provider,
122 message: impl Into<ProtocolMessage>,
123 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
124 let unverified_message = self.unprotect_message(provider, message)?;
125
126 #[cfg(feature = "extensions-draft-08")]
129 if let Some(proposals) = unverified_message.committed_proposals() {
130 for proposal_or_ref in proposals {
131 if let ProposalOrRefIn::Proposal(proposal) = proposal_or_ref {
132 if matches!(proposal.as_ref(), ProposalIn::AppDataUpdate(_)) {
133 return Err(ProcessMessageError::FoundAppDataUpdateProposal);
134 }
135 }
136 }
137 }
138 self.process_unverified_message(provider, unverified_message)
139 }
140
141 #[cfg(feature = "extensions-draft-08")]
142 pub fn app_data_dictionary_updater<'a>(&'a self) -> AppDataDictionaryUpdater<'a> {
144 AppDataDictionaryUpdater::new(self.context().app_data_dict())
145 }
146
147 pub fn unprotect_message<Provider: OpenMlsProvider>(
150 &mut self,
151 provider: &Provider,
152 message: impl Into<ProtocolMessage>,
153 ) -> Result<UnverifiedMessage, ProcessMessageError<Provider::StorageError>> {
154 if !self.is_active() {
156 return Err(ProcessMessageError::GroupStateError(
157 MlsGroupStateError::UseAfterEviction,
158 ));
159 }
160 let message = message.into();
161
162 if !message.is_external()
164 && message.is_handshake_message()
165 && !self
166 .configuration()
167 .wire_format_policy()
168 .incoming()
169 .is_compatible_with(message.wire_format())
170 {
171 return Err(ProcessMessageError::IncompatibleWireFormat);
172 }
173
174 let sender_ratchet_configuration = *self.configuration().sender_ratchet_configuration();
176
177 let will_modify_secret_tree = matches!(message, ProtocolMessage::PrivateMessage(_));
180
181 let decrypted_message =
187 self.decrypt_message(provider.crypto(), message, &sender_ratchet_configuration)?;
188
189 if will_modify_secret_tree {
191 provider
192 .storage()
193 .write_message_secrets(self.group_id(), &self.message_secrets_store)
194 .map_err(ProcessMessageError::StorageError)?;
195 }
196
197 let unverified_message = self
198 .public_group
199 .parse_message(decrypted_message, &self.message_secrets_store)
200 .map_err(ProcessMessageError::from)?;
201
202 Ok(unverified_message)
203 }
204
205 pub fn store_pending_proposal<Storage: StorageProvider>(
207 &mut self,
208 storage: &Storage,
209 proposal: QueuedProposal,
210 ) -> Result<(), Storage::Error> {
211 storage.queue_proposal(self.group_id(), &proposal.proposal_reference(), &proposal)?;
212 self.proposal_store_mut().add(proposal);
214
215 Ok(())
216 }
217
218 pub fn has_pending_proposals(&self) -> bool {
220 !self.proposal_store().is_empty()
221 }
222
223 #[allow(clippy::type_complexity)]
234 pub fn commit_to_pending_proposals<Provider: OpenMlsProvider>(
235 &mut self,
236 provider: &Provider,
237 signer: &impl Signer,
238 ) -> Result<
239 (MlsMessageOut, Option<MlsMessageOut>, Option<GroupInfo>),
240 CommitToPendingProposalsError<Provider::StorageError>,
241 > {
242 self.is_operational()?;
243
244 let (commit, welcome, group_info) = self
247 .commit_builder()
248 .consume_proposal_store(true)
250 .load_psks(provider.storage())?
251 .build(provider.rand(), provider.crypto(), signer, |_| true)?
252 .stage_commit(provider)?
253 .into_contents();
254
255 Ok((
256 commit,
257 welcome.map(|welcome| MlsMessageOut::from_welcome(welcome, self.version())),
259 group_info,
260 ))
261 }
262
263 pub fn merge_staged_commit<Provider: OpenMlsProvider>(
266 &mut self,
267 provider: &Provider,
268 staged_commit: StagedCommit,
269 ) -> Result<(), MergeCommitError<Provider::StorageError>> {
270 if staged_commit.self_removed() {
272 self.group_state = MlsGroupState::Inactive;
273 }
274 provider
275 .storage()
276 .write_group_state(self.group_id(), &self.group_state)
277 .map_err(MergeCommitError::StorageError)?;
278
279 self.merge_commit(provider, staged_commit)?;
281
282 let resumption_psk = self.group_epoch_secrets().resumption_psk();
284 self.resumption_psk_store
285 .add(self.context().epoch(), resumption_psk.clone());
286 provider
287 .storage()
288 .write_resumption_psk_store(self.group_id(), &self.resumption_psk_store)
289 .map_err(MergeCommitError::StorageError)?;
290
291 self.own_leaf_nodes.clear();
293 provider
294 .storage()
295 .delete_own_leaf_nodes(self.group_id())
296 .map_err(MergeCommitError::StorageError)?;
297
298 self.clear_pending_commit(provider.storage())
300 .map_err(MergeCommitError::StorageError)?;
301
302 Ok(())
303 }
304
305 pub fn merge_pending_commit<Provider: OpenMlsProvider>(
308 &mut self,
309 provider: &Provider,
310 ) -> Result<(), MergePendingCommitError<Provider::StorageError>> {
311 match &self.group_state {
312 MlsGroupState::PendingCommit(_) => {
313 let old_state = mem::replace(&mut self.group_state, MlsGroupState::Operational);
314 if let MlsGroupState::PendingCommit(pending_commit_state) = old_state {
315 self.merge_staged_commit(provider, (*pending_commit_state).into())?;
316 }
317 Ok(())
318 }
319 MlsGroupState::Inactive => Err(MlsGroupStateError::UseAfterEviction)?,
320 MlsGroupState::Operational => Ok(()),
321 }
322 }
323
324 pub(super) fn read_decryption_keypairs(
326 &self,
327 provider: &impl OpenMlsProvider,
328 own_leaf_nodes: &[LeafNode],
329 ) -> Result<(Vec<EncryptionKeyPair>, Vec<EncryptionKeyPair>), StageCommitError> {
330 let old_epoch_keypairs = self.read_epoch_keypairs(provider.storage()).map_err(|e| {
332 log::error!("Error reading epoch keypairs: {e:?}");
333 StageCommitError::MissingDecryptionKey
334 })?;
335
336 let leaf_node_keypairs = own_leaf_nodes
340 .iter()
341 .map(|leaf_node| {
342 EncryptionKeyPair::read(provider, leaf_node.encryption_key())
343 .ok_or(StageCommitError::MissingDecryptionKey)
344 })
345 .collect::<Result<Vec<EncryptionKeyPair>, StageCommitError>>()?;
346
347 Ok((old_epoch_keypairs, leaf_node_keypairs))
348 }
349
350 #[cfg(feature = "extensions-draft-08")]
353 pub fn process_unverified_message_with_app_data_updates<Provider: OpenMlsProvider>(
354 &self,
355 provider: &Provider,
356 unverified_message: UnverifiedMessage,
357 app_data_dict_updates: Option<AppDataUpdates>,
358 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
359 let (content, credential) =
365 unverified_message.verify(self.ciphersuite(), provider.crypto(), self.version())?;
366
367 match content.sender() {
368 Sender::Member(_) | Sender::NewMemberProposal | Sender::NewMemberCommit => self
369 .process_internal_authenticated_content_with_app_data_updates(
370 provider,
371 content,
372 credential,
373 app_data_dict_updates,
374 ),
375 Sender::External(_) => {
376 self.process_external_authenticated_content(provider, content, credential)
377 }
378 }
379 }
380
381 pub(crate) fn process_unverified_message<Provider: OpenMlsProvider>(
409 &self,
410 provider: &Provider,
411 unverified_message: UnverifiedMessage,
412 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
413 let (content, credential) =
419 unverified_message.verify(self.ciphersuite(), provider.crypto(), self.version())?;
420
421 match content.sender() {
422 Sender::Member(_) | Sender::NewMemberProposal | Sender::NewMemberCommit => {
423 self.process_internal_authenticated_content(provider, content, credential)
424 }
425 Sender::External(_) => {
426 self.process_external_authenticated_content(provider, content, credential)
427 }
428 }
429 }
430
431 #[cfg(feature = "extensions-draft-08")]
455 fn process_internal_authenticated_content_with_app_data_updates<Provider: OpenMlsProvider>(
456 &self,
457 provider: &Provider,
458 content: AuthenticatedContent,
459 credential: Credential,
460 app_data_dict_updates: Option<AppDataUpdates>,
461 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
462 let sender = content.sender().clone();
463 let authenticated_data = content.authenticated_data().to_owned();
464 let epoch = content.epoch();
465
466 let content = match content.content() {
467 FramedContentBody::Application(application_message) => {
468 ProcessedMessageContent::ApplicationMessage(ApplicationMessage::new(
469 application_message.as_slice().to_owned(),
470 ))
471 }
472 FramedContentBody::Proposal(_) => {
473 let proposal = Box::new(QueuedProposal::from_authenticated_content_by_ref(
474 self.ciphersuite(),
475 provider.crypto(),
476 content,
477 )?);
478
479 if matches!(sender, Sender::NewMemberProposal) {
480 ProcessedMessageContent::ExternalJoinProposalMessage(proposal)
481 } else {
482 ProcessedMessageContent::ProposalMessage(proposal)
483 }
484 }
485 FramedContentBody::Commit(_) => {
486 let (old_epoch_keypairs, leaf_node_keypairs) =
488 self.read_decryption_keypairs(provider, &self.own_leaf_nodes)?;
489
490 let staged_commit = self.stage_commit_with_app_data_updates(
491 &content,
492 old_epoch_keypairs,
493 leaf_node_keypairs,
494 app_data_dict_updates,
495 provider,
496 )?;
497
498 ProcessedMessageContent::StagedCommitMessage(Box::new(staged_commit))
499 }
500 };
501
502 Ok(ProcessedMessage::new(
503 self.group_id().clone(),
504 epoch,
505 sender,
506 authenticated_data,
507 content,
508 credential,
509 ))
510 }
511
512 fn process_internal_authenticated_content<Provider: OpenMlsProvider>(
513 &self,
514 provider: &Provider,
515 content: AuthenticatedContent,
516 credential: Credential,
517 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
518 let sender = content.sender().clone();
519 let authenticated_data = content.authenticated_data().to_owned();
520 let epoch = content.epoch();
521
522 let content = match content.content() {
523 FramedContentBody::Application(application_message) => {
524 ProcessedMessageContent::ApplicationMessage(ApplicationMessage::new(
525 application_message.as_slice().to_owned(),
526 ))
527 }
528 FramedContentBody::Proposal(_) => {
529 let proposal = Box::new(QueuedProposal::from_authenticated_content_by_ref(
530 self.ciphersuite(),
531 provider.crypto(),
532 content,
533 )?);
534
535 if matches!(sender, Sender::NewMemberProposal) {
536 ProcessedMessageContent::ExternalJoinProposalMessage(proposal)
537 } else {
538 ProcessedMessageContent::ProposalMessage(proposal)
539 }
540 }
541 FramedContentBody::Commit(_) => {
542 let (old_epoch_keypairs, leaf_node_keypairs) =
544 self.read_decryption_keypairs(provider, &self.own_leaf_nodes)?;
545
546 let staged_commit =
547 self.stage_commit(&content, old_epoch_keypairs, leaf_node_keypairs, provider)?;
548
549 ProcessedMessageContent::StagedCommitMessage(Box::new(staged_commit))
550 }
551 };
552
553 Ok(ProcessedMessage::new(
554 self.group_id().clone(),
555 epoch,
556 sender,
557 authenticated_data,
558 content,
559 credential,
560 ))
561 }
562
563 fn process_external_authenticated_content<Provider: OpenMlsProvider>(
569 &self,
570 provider: &Provider,
571 content: AuthenticatedContent,
572 credential: Credential,
573 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
574 let sender = content.sender().clone();
575 let data = content.authenticated_data().to_owned();
576
577 debug_assert!(matches!(sender, Sender::External(_)));
578
579 match content.content() {
581 FramedContentBody::Application(_) => {
582 Err(ProcessMessageError::UnauthorizedExternalApplicationMessage)
583 }
584 FramedContentBody::Proposal(Proposal::GroupContextExtensions(_)) => {
586 let content = ProcessedMessageContent::ProposalMessage(Box::new(
587 QueuedProposal::from_authenticated_content_by_ref(
588 self.ciphersuite(),
589 provider.crypto(),
590 content,
591 )?,
592 ));
593 Ok(ProcessedMessage::new(
594 self.group_id().clone(),
595 self.context().epoch(),
596 sender,
597 data,
598 content,
599 credential,
600 ))
601 }
602
603 FramedContentBody::Proposal(Proposal::Remove(_)) => {
604 let content = ProcessedMessageContent::ProposalMessage(Box::new(
605 QueuedProposal::from_authenticated_content_by_ref(
606 self.ciphersuite(),
607 provider.crypto(),
608 content,
609 )?,
610 ));
611 Ok(ProcessedMessage::new(
612 self.group_id().clone(),
613 self.context().epoch(),
614 sender,
615 data,
616 content,
617 credential,
618 ))
619 }
620 FramedContentBody::Proposal(Proposal::Add(_)) => {
621 let content = ProcessedMessageContent::ProposalMessage(Box::new(
622 QueuedProposal::from_authenticated_content_by_ref(
623 self.ciphersuite(),
624 provider.crypto(),
625 content,
626 )?,
627 ));
628 Ok(ProcessedMessage::new(
629 self.group_id().clone(),
630 self.context().epoch(),
631 sender,
632 data,
633 content,
634 credential,
635 ))
636 }
637 FramedContentBody::Proposal(_) => Err(ProcessMessageError::UnsupportedProposalType),
639 FramedContentBody::Commit(_) => {
640 Err(ProcessMessageError::UnauthorizedExternalCommitMessage)
641 }
642 }
643 }
644
645 pub(crate) fn decrypt_message(
657 &mut self,
658 crypto: &impl OpenMlsCrypto,
659 message: ProtocolMessage,
660 sender_ratchet_configuration: &SenderRatchetConfiguration,
661 ) -> Result<DecryptedMessage, ValidationError> {
662 self.public_group.validate_framing(&message)?;
666
667 let epoch = message.epoch();
668
669 match message {
673 ProtocolMessage::PublicMessage(public_message) => {
674 let message_secrets =
676 self.message_secrets_for_epoch(epoch).map_err(|e| match e {
677 SecretTreeError::TooDistantInThePast => ValidationError::NoPastEpochData,
678 _ => LibraryError::custom(
679 "Unexpected error while retrieving message secrets for epoch.",
680 )
681 .into(),
682 })?;
683 DecryptedMessage::from_inbound_public_message(
684 *public_message,
685 message_secrets,
686 message_secrets.serialized_context().to_vec(),
687 crypto,
688 self.ciphersuite(),
689 )
690 }
691 ProtocolMessage::PrivateMessage(ciphertext) => {
692 DecryptedMessage::from_inbound_ciphertext(
694 ciphertext,
695 crypto,
696 self,
697 sender_ratchet_configuration,
698 )
699 }
700 }
701 }
702}