1use std::mem;
4
5use errors::{CommitToPendingProposalsError, MergePendingCommitError};
6use openmls_traits::{crypto::OpenMlsCrypto, signatures::Signer, storage::StorageProvider as _};
7
8use crate::{
9 framing::mls_content::FramedContentBody,
10 group::{errors::MergeCommitError, StageCommitError, ValidationError},
11 messages::group_info::GroupInfo,
12 storage::OpenMlsProvider,
13 tree::sender_ratchet::SenderRatchetConfiguration,
14};
15
16use super::{errors::ProcessMessageError, *};
17
18impl MlsGroup {
19 pub fn process_message<Provider: OpenMlsProvider>(
29 &mut self,
30 provider: &Provider,
31 message: impl Into<ProtocolMessage>,
32 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
33 if !self.is_active() {
35 return Err(ProcessMessageError::GroupStateError(
36 MlsGroupStateError::UseAfterEviction,
37 ));
38 }
39 let message = message.into();
40
41 if !message.is_external()
43 && message.is_handshake_message()
44 && !self
45 .configuration()
46 .wire_format_policy()
47 .incoming()
48 .is_compatible_with(message.wire_format())
49 {
50 return Err(ProcessMessageError::IncompatibleWireFormat);
51 }
52
53 let sender_ratchet_configuration = *self.configuration().sender_ratchet_configuration();
55
56 let will_modify_secret_tree = matches!(message, ProtocolMessage::PrivateMessage(_));
59
60 let decrypted_message =
66 self.decrypt_message(provider.crypto(), message, &sender_ratchet_configuration)?;
67
68 let unverified_message = self
69 .public_group
70 .parse_message(decrypted_message, &self.message_secrets_store)
71 .map_err(ProcessMessageError::from)?;
72
73 let (old_epoch_keypairs, leaf_node_keypairs) =
75 if let ContentType::Commit = unverified_message.content_type() {
76 self.read_decryption_keypairs(provider, &self.own_leaf_nodes)?
77 } else {
78 (vec![], vec![])
79 };
80
81 let processed_message = self.process_unverified_message(
82 provider,
83 unverified_message,
84 old_epoch_keypairs,
85 leaf_node_keypairs,
86 )?;
87
88 if will_modify_secret_tree {
90 provider
91 .storage()
92 .write_message_secrets(self.group_id(), &self.message_secrets_store)
93 .map_err(ProcessMessageError::StorageError)?;
94 }
95
96 Ok(processed_message)
97 }
98
99 pub fn store_pending_proposal<Storage: StorageProvider>(
101 &mut self,
102 storage: &Storage,
103 proposal: QueuedProposal,
104 ) -> Result<(), Storage::Error> {
105 storage.queue_proposal(self.group_id(), &proposal.proposal_reference(), &proposal)?;
106 self.proposal_store_mut().add(proposal);
108
109 Ok(())
110 }
111
112 pub fn has_pending_proposals(&self) -> bool {
114 !self.proposal_store().is_empty()
115 }
116
117 #[allow(clippy::type_complexity)]
128 pub fn commit_to_pending_proposals<Provider: OpenMlsProvider>(
129 &mut self,
130 provider: &Provider,
131 signer: &impl Signer,
132 ) -> Result<
133 (MlsMessageOut, Option<MlsMessageOut>, Option<GroupInfo>),
134 CommitToPendingProposalsError<Provider::StorageError>,
135 > {
136 self.is_operational()?;
137
138 let (commit, welcome, group_info) = self
141 .commit_builder()
142 .consume_proposal_store(true)
144 .load_psks(provider.storage())?
145 .build(provider.rand(), provider.crypto(), signer, |_| true)?
146 .stage_commit(provider)?
147 .into_contents();
148
149 Ok((
150 commit,
151 welcome.map(|welcome| MlsMessageOut::from_welcome(welcome, self.version())),
153 group_info,
154 ))
155 }
156
157 pub fn merge_staged_commit<Provider: OpenMlsProvider>(
160 &mut self,
161 provider: &Provider,
162 staged_commit: StagedCommit,
163 ) -> Result<(), MergeCommitError<Provider::StorageError>> {
164 if staged_commit.self_removed() {
166 self.group_state = MlsGroupState::Inactive;
167 }
168 provider
169 .storage()
170 .write_group_state(self.group_id(), &self.group_state)
171 .map_err(MergeCommitError::StorageError)?;
172
173 self.merge_commit(provider, staged_commit)?;
175
176 let resumption_psk = self.group_epoch_secrets().resumption_psk();
178 self.resumption_psk_store
179 .add(self.context().epoch(), resumption_psk.clone());
180 provider
181 .storage()
182 .write_resumption_psk_store(self.group_id(), &self.resumption_psk_store)
183 .map_err(MergeCommitError::StorageError)?;
184
185 self.own_leaf_nodes.clear();
187 provider
188 .storage()
189 .delete_own_leaf_nodes(self.group_id())
190 .map_err(MergeCommitError::StorageError)?;
191
192 self.clear_pending_commit(provider.storage())
194 .map_err(MergeCommitError::StorageError)?;
195
196 Ok(())
197 }
198
199 pub fn merge_pending_commit<Provider: OpenMlsProvider>(
202 &mut self,
203 provider: &Provider,
204 ) -> Result<(), MergePendingCommitError<Provider::StorageError>> {
205 match &self.group_state {
206 MlsGroupState::PendingCommit(_) => {
207 let old_state = mem::replace(&mut self.group_state, MlsGroupState::Operational);
208 if let MlsGroupState::PendingCommit(pending_commit_state) = old_state {
209 self.merge_staged_commit(provider, (*pending_commit_state).into())?;
210 }
211 Ok(())
212 }
213 MlsGroupState::Inactive => Err(MlsGroupStateError::UseAfterEviction)?,
214 MlsGroupState::Operational => Ok(()),
215 }
216 }
217
218 pub(super) fn read_decryption_keypairs(
220 &self,
221 provider: &impl OpenMlsProvider,
222 own_leaf_nodes: &[LeafNode],
223 ) -> Result<(Vec<EncryptionKeyPair>, Vec<EncryptionKeyPair>), StageCommitError> {
224 let old_epoch_keypairs = self.read_epoch_keypairs(provider.storage()).map_err(|e| {
226 log::error!("Error reading epoch keypairs: {e:?}");
227 StageCommitError::MissingDecryptionKey
228 })?;
229
230 let leaf_node_keypairs = own_leaf_nodes
234 .iter()
235 .map(|leaf_node| {
236 EncryptionKeyPair::read(provider, leaf_node.encryption_key())
237 .ok_or(StageCommitError::MissingDecryptionKey)
238 })
239 .collect::<Result<Vec<EncryptionKeyPair>, StageCommitError>>()?;
240
241 Ok((old_epoch_keypairs, leaf_node_keypairs))
242 }
243
244 pub(crate) fn process_unverified_message<Provider: OpenMlsProvider>(
273 &self,
274 provider: &Provider,
275 unverified_message: UnverifiedMessage,
276 old_epoch_keypairs: Vec<EncryptionKeyPair>,
277 leaf_node_keypairs: Vec<EncryptionKeyPair>,
278 ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
279 let (content, credential) =
285 unverified_message.verify(self.ciphersuite(), provider.crypto(), self.version())?;
286
287 match content.sender() {
288 Sender::Member(_) | Sender::NewMemberCommit | Sender::NewMemberProposal => {
289 let sender = content.sender().clone();
290 let authenticated_data = content.authenticated_data().to_owned();
291 let epoch = content.epoch();
292
293 let content = match content.content() {
294 FramedContentBody::Application(application_message) => {
295 ProcessedMessageContent::ApplicationMessage(ApplicationMessage::new(
296 application_message.as_slice().to_owned(),
297 ))
298 }
299 FramedContentBody::Proposal(_) => {
300 let proposal = Box::new(QueuedProposal::from_authenticated_content_by_ref(
301 self.ciphersuite(),
302 provider.crypto(),
303 content,
304 )?);
305
306 if matches!(sender, Sender::NewMemberProposal) {
307 ProcessedMessageContent::ExternalJoinProposalMessage(proposal)
308 } else {
309 ProcessedMessageContent::ProposalMessage(proposal)
310 }
311 }
312 FramedContentBody::Commit(_) => {
313 let staged_commit = self.stage_commit(
314 &content,
315 old_epoch_keypairs,
316 leaf_node_keypairs,
317 provider,
318 )?;
319 ProcessedMessageContent::StagedCommitMessage(Box::new(staged_commit))
320 }
321 };
322
323 Ok(ProcessedMessage::new(
324 self.group_id().clone(),
325 epoch,
326 sender,
327 authenticated_data,
328 content,
329 credential,
330 ))
331 }
332 Sender::External(_) => {
333 let sender = content.sender().clone();
334 let data = content.authenticated_data().to_owned();
335 match content.content() {
337 FramedContentBody::Application(_) => {
338 Err(ProcessMessageError::UnauthorizedExternalApplicationMessage)
339 }
340 FramedContentBody::Proposal(Proposal::GroupContextExtensions(_)) => {
342 let content = ProcessedMessageContent::ProposalMessage(Box::new(
343 QueuedProposal::from_authenticated_content_by_ref(
344 self.ciphersuite(),
345 provider.crypto(),
346 content,
347 )?,
348 ));
349 Ok(ProcessedMessage::new(
350 self.group_id().clone(),
351 self.context().epoch(),
352 sender,
353 data,
354 content,
355 credential,
356 ))
357 }
358
359 FramedContentBody::Proposal(Proposal::Remove(_)) => {
360 let content = ProcessedMessageContent::ProposalMessage(Box::new(
361 QueuedProposal::from_authenticated_content_by_ref(
362 self.ciphersuite(),
363 provider.crypto(),
364 content,
365 )?,
366 ));
367 Ok(ProcessedMessage::new(
368 self.group_id().clone(),
369 self.context().epoch(),
370 sender,
371 data,
372 content,
373 credential,
374 ))
375 }
376 FramedContentBody::Proposal(Proposal::Add(_)) => {
377 let content = ProcessedMessageContent::ProposalMessage(Box::new(
378 QueuedProposal::from_authenticated_content_by_ref(
379 self.ciphersuite(),
380 provider.crypto(),
381 content,
382 )?,
383 ));
384 Ok(ProcessedMessage::new(
385 self.group_id().clone(),
386 self.context().epoch(),
387 sender,
388 data,
389 content,
390 credential,
391 ))
392 }
393 FramedContentBody::Proposal(_) => {
395 Err(ProcessMessageError::UnsupportedProposalType)
396 }
397 FramedContentBody::Commit(_) => {
398 Err(ProcessMessageError::UnauthorizedExternalCommitMessage)
399 }
400 }
401 }
402 }
403 }
404
405 pub(crate) fn decrypt_message(
417 &mut self,
418 crypto: &impl OpenMlsCrypto,
419 message: ProtocolMessage,
420 sender_ratchet_configuration: &SenderRatchetConfiguration,
421 ) -> Result<DecryptedMessage, ValidationError> {
422 self.public_group.validate_framing(&message)?;
426
427 let epoch = message.epoch();
428
429 match message {
433 ProtocolMessage::PublicMessage(public_message) => {
434 let message_secrets =
436 self.message_secrets_for_epoch(epoch).map_err(|e| match e {
437 SecretTreeError::TooDistantInThePast => ValidationError::NoPastEpochData,
438 _ => LibraryError::custom(
439 "Unexpected error while retrieving message secrets for epoch.",
440 )
441 .into(),
442 })?;
443 DecryptedMessage::from_inbound_public_message(
444 *public_message,
445 message_secrets,
446 message_secrets.serialized_context().to_vec(),
447 crypto,
448 self.ciphersuite(),
449 )
450 }
451 ProtocolMessage::PrivateMessage(ciphertext) => {
452 DecryptedMessage::from_inbound_ciphertext(
454 ciphertext,
455 crypto,
456 self,
457 sender_ratchet_configuration,
458 )
459 }
460 }
461 }
462}