1use std::mem;
4
5use errors::{CommitToPendingProposalsError, MergePendingCommitError};
6use openmls_traits::{crypto::OpenMlsCrypto, signatures::Signer, storage::StorageProvider as _};
7
8use crate::{
9    framing::mls_content::FramedContentBody,
10    group::{errors::MergeCommitError, StageCommitError, ValidationError},
11    messages::group_info::GroupInfo,
12    storage::OpenMlsProvider,
13    tree::sender_ratchet::SenderRatchetConfiguration,
14};
15
16use super::{errors::ProcessMessageError, *};
17
18impl MlsGroup {
19    pub fn process_message<Provider: OpenMlsProvider>(
29        &mut self,
30        provider: &Provider,
31        message: impl Into<ProtocolMessage>,
32    ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
33        if !self.is_active() {
35            return Err(ProcessMessageError::GroupStateError(
36                MlsGroupStateError::UseAfterEviction,
37            ));
38        }
39        let message = message.into();
40
41        if !message.is_external()
43            && message.is_handshake_message()
44            && !self
45                .configuration()
46                .wire_format_policy()
47                .incoming()
48                .is_compatible_with(message.wire_format())
49        {
50            return Err(ProcessMessageError::IncompatibleWireFormat);
51        }
52
53        let sender_ratchet_configuration = *self.configuration().sender_ratchet_configuration();
55
56        let will_modify_secret_tree = matches!(message, ProtocolMessage::PrivateMessage(_));
59
60        let decrypted_message =
66            self.decrypt_message(provider.crypto(), message, &sender_ratchet_configuration)?;
67
68        let unverified_message = self
69            .public_group
70            .parse_message(decrypted_message, &self.message_secrets_store)
71            .map_err(ProcessMessageError::from)?;
72
73        let (old_epoch_keypairs, leaf_node_keypairs) =
75            if let ContentType::Commit = unverified_message.content_type() {
76                self.read_decryption_keypairs(provider, &self.own_leaf_nodes)?
77            } else {
78                (vec![], vec![])
79            };
80
81        let processed_message = self.process_unverified_message(
82            provider,
83            unverified_message,
84            old_epoch_keypairs,
85            leaf_node_keypairs,
86        )?;
87
88        if will_modify_secret_tree {
90            provider
91                .storage()
92                .write_message_secrets(self.group_id(), &self.message_secrets_store)
93                .map_err(ProcessMessageError::StorageError)?;
94        }
95
96        Ok(processed_message)
97    }
98
99    pub fn store_pending_proposal<Storage: StorageProvider>(
101        &mut self,
102        storage: &Storage,
103        proposal: QueuedProposal,
104    ) -> Result<(), Storage::Error> {
105        storage.queue_proposal(self.group_id(), &proposal.proposal_reference(), &proposal)?;
106        self.proposal_store_mut().add(proposal);
108
109        Ok(())
110    }
111
112    pub fn has_pending_proposals(&self) -> bool {
114        !self.proposal_store().is_empty()
115    }
116
117    #[allow(clippy::type_complexity)]
128    pub fn commit_to_pending_proposals<Provider: OpenMlsProvider>(
129        &mut self,
130        provider: &Provider,
131        signer: &impl Signer,
132    ) -> Result<
133        (MlsMessageOut, Option<MlsMessageOut>, Option<GroupInfo>),
134        CommitToPendingProposalsError<Provider::StorageError>,
135    > {
136        self.is_operational()?;
137
138        let (commit, welcome, group_info) = self
141            .commit_builder()
142            .consume_proposal_store(true)
144            .load_psks(provider.storage())?
145            .build(provider.rand(), provider.crypto(), signer, |_| true)?
146            .stage_commit(provider)?
147            .into_contents();
148
149        Ok((
150            commit,
151            welcome.map(|welcome| MlsMessageOut::from_welcome(welcome, self.version())),
153            group_info,
154        ))
155    }
156
157    pub fn merge_staged_commit<Provider: OpenMlsProvider>(
160        &mut self,
161        provider: &Provider,
162        staged_commit: StagedCommit,
163    ) -> Result<(), MergeCommitError<Provider::StorageError>> {
164        if staged_commit.self_removed() {
166            self.group_state = MlsGroupState::Inactive;
167        }
168        provider
169            .storage()
170            .write_group_state(self.group_id(), &self.group_state)
171            .map_err(MergeCommitError::StorageError)?;
172
173        self.merge_commit(provider, staged_commit)?;
175
176        let resumption_psk = self.group_epoch_secrets().resumption_psk();
178        self.resumption_psk_store
179            .add(self.context().epoch(), resumption_psk.clone());
180        provider
181            .storage()
182            .write_resumption_psk_store(self.group_id(), &self.resumption_psk_store)
183            .map_err(MergeCommitError::StorageError)?;
184
185        self.own_leaf_nodes.clear();
187        provider
188            .storage()
189            .delete_own_leaf_nodes(self.group_id())
190            .map_err(MergeCommitError::StorageError)?;
191
192        self.clear_pending_commit(provider.storage())
194            .map_err(MergeCommitError::StorageError)?;
195
196        Ok(())
197    }
198
199    pub fn merge_pending_commit<Provider: OpenMlsProvider>(
202        &mut self,
203        provider: &Provider,
204    ) -> Result<(), MergePendingCommitError<Provider::StorageError>> {
205        match &self.group_state {
206            MlsGroupState::PendingCommit(_) => {
207                let old_state = mem::replace(&mut self.group_state, MlsGroupState::Operational);
208                if let MlsGroupState::PendingCommit(pending_commit_state) = old_state {
209                    self.merge_staged_commit(provider, (*pending_commit_state).into())?;
210                }
211                Ok(())
212            }
213            MlsGroupState::Inactive => Err(MlsGroupStateError::UseAfterEviction)?,
214            MlsGroupState::Operational => Ok(()),
215        }
216    }
217
218    pub(super) fn read_decryption_keypairs(
220        &self,
221        provider: &impl OpenMlsProvider,
222        own_leaf_nodes: &[LeafNode],
223    ) -> Result<(Vec<EncryptionKeyPair>, Vec<EncryptionKeyPair>), StageCommitError> {
224        let old_epoch_keypairs = self.read_epoch_keypairs(provider.storage()).map_err(|e| {
226            log::error!("Error reading epoch keypairs: {e:?}");
227            StageCommitError::MissingDecryptionKey
228        })?;
229
230        let leaf_node_keypairs = own_leaf_nodes
234            .iter()
235            .map(|leaf_node| {
236                EncryptionKeyPair::read(provider, leaf_node.encryption_key())
237                    .ok_or(StageCommitError::MissingDecryptionKey)
238            })
239            .collect::<Result<Vec<EncryptionKeyPair>, StageCommitError>>()?;
240
241        Ok((old_epoch_keypairs, leaf_node_keypairs))
242    }
243
244    pub(crate) fn process_unverified_message<Provider: OpenMlsProvider>(
273        &self,
274        provider: &Provider,
275        unverified_message: UnverifiedMessage,
276        old_epoch_keypairs: Vec<EncryptionKeyPair>,
277        leaf_node_keypairs: Vec<EncryptionKeyPair>,
278    ) -> Result<ProcessedMessage, ProcessMessageError<Provider::StorageError>> {
279        let (content, credential) =
285            unverified_message.verify(self.ciphersuite(), provider.crypto(), self.version())?;
286
287        match content.sender() {
288            Sender::Member(_) | Sender::NewMemberCommit | Sender::NewMemberProposal => {
289                let sender = content.sender().clone();
290                let authenticated_data = content.authenticated_data().to_owned();
291                let epoch = content.epoch();
292
293                let content = match content.content() {
294                    FramedContentBody::Application(application_message) => {
295                        ProcessedMessageContent::ApplicationMessage(ApplicationMessage::new(
296                            application_message.as_slice().to_owned(),
297                        ))
298                    }
299                    FramedContentBody::Proposal(_) => {
300                        let proposal = Box::new(QueuedProposal::from_authenticated_content_by_ref(
301                            self.ciphersuite(),
302                            provider.crypto(),
303                            content,
304                        )?);
305
306                        if matches!(sender, Sender::NewMemberProposal) {
307                            ProcessedMessageContent::ExternalJoinProposalMessage(proposal)
308                        } else {
309                            ProcessedMessageContent::ProposalMessage(proposal)
310                        }
311                    }
312                    FramedContentBody::Commit(_) => {
313                        let staged_commit = self.stage_commit(
314                            &content,
315                            old_epoch_keypairs,
316                            leaf_node_keypairs,
317                            provider,
318                        )?;
319                        ProcessedMessageContent::StagedCommitMessage(Box::new(staged_commit))
320                    }
321                };
322
323                Ok(ProcessedMessage::new(
324                    self.group_id().clone(),
325                    epoch,
326                    sender,
327                    authenticated_data,
328                    content,
329                    credential,
330                ))
331            }
332            Sender::External(_) => {
333                let sender = content.sender().clone();
334                let data = content.authenticated_data().to_owned();
335                match content.content() {
337                    FramedContentBody::Application(_) => {
338                        Err(ProcessMessageError::UnauthorizedExternalApplicationMessage)
339                    }
340                    FramedContentBody::Proposal(Proposal::GroupContextExtensions(_)) => {
342                        let content = ProcessedMessageContent::ProposalMessage(Box::new(
343                            QueuedProposal::from_authenticated_content_by_ref(
344                                self.ciphersuite(),
345                                provider.crypto(),
346                                content,
347                            )?,
348                        ));
349                        Ok(ProcessedMessage::new(
350                            self.group_id().clone(),
351                            self.context().epoch(),
352                            sender,
353                            data,
354                            content,
355                            credential,
356                        ))
357                    }
358
359                    FramedContentBody::Proposal(Proposal::Remove(_)) => {
360                        let content = ProcessedMessageContent::ProposalMessage(Box::new(
361                            QueuedProposal::from_authenticated_content_by_ref(
362                                self.ciphersuite(),
363                                provider.crypto(),
364                                content,
365                            )?,
366                        ));
367                        Ok(ProcessedMessage::new(
368                            self.group_id().clone(),
369                            self.context().epoch(),
370                            sender,
371                            data,
372                            content,
373                            credential,
374                        ))
375                    }
376                    FramedContentBody::Proposal(Proposal::Add(_)) => {
377                        let content = ProcessedMessageContent::ProposalMessage(Box::new(
378                            QueuedProposal::from_authenticated_content_by_ref(
379                                self.ciphersuite(),
380                                provider.crypto(),
381                                content,
382                            )?,
383                        ));
384                        Ok(ProcessedMessage::new(
385                            self.group_id().clone(),
386                            self.context().epoch(),
387                            sender,
388                            data,
389                            content,
390                            credential,
391                        ))
392                    }
393                    FramedContentBody::Proposal(_) => {
395                        Err(ProcessMessageError::UnsupportedProposalType)
396                    }
397                    FramedContentBody::Commit(_) => {
398                        Err(ProcessMessageError::UnauthorizedExternalCommitMessage)
399                    }
400                }
401            }
402        }
403    }
404
405    pub(crate) fn decrypt_message(
417        &mut self,
418        crypto: &impl OpenMlsCrypto,
419        message: ProtocolMessage,
420        sender_ratchet_configuration: &SenderRatchetConfiguration,
421    ) -> Result<DecryptedMessage, ValidationError> {
422        self.public_group.validate_framing(&message)?;
426
427        let epoch = message.epoch();
428
429        match message {
433            ProtocolMessage::PublicMessage(public_message) => {
434                let message_secrets =
436                    self.message_secrets_for_epoch(epoch).map_err(|e| match e {
437                        SecretTreeError::TooDistantInThePast => ValidationError::NoPastEpochData,
438                        _ => LibraryError::custom(
439                            "Unexpected error while retrieving message secrets for epoch.",
440                        )
441                        .into(),
442                    })?;
443                DecryptedMessage::from_inbound_public_message(
444                    *public_message,
445                    message_secrets,
446                    message_secrets.serialized_context().to_vec(),
447                    crypto,
448                    self.ciphersuite(),
449                )
450            }
451            ProtocolMessage::PrivateMessage(ciphertext) => {
452                DecryptedMessage::from_inbound_ciphertext(
454                    ciphertext,
455                    crypto,
456                    self,
457                    sender_ratchet_configuration,
458                )
459            }
460        }
461    }
462}