Skip to main content

openmls/group/mls_group/
staged_commit.rs

1use core::fmt::Debug;
2
3use openmls_traits::crypto::OpenMlsCrypto;
4use openmls_traits::storage::StorageProvider as _;
5use serde::{Deserialize, Serialize};
6use tls_codec::Serialize as _;
7
8use super::proposal_store::{
9    QueuedAddProposal, QueuedPskProposal, QueuedRemoveProposal, QueuedUpdateProposal,
10};
11
12#[cfg(not(feature = "virtual-clients-draft"))]
13use super::Sender;
14use super::{
15    super::errors::*, load_psks, Credential, Extension, GroupContext, GroupEpochSecrets, GroupId,
16    JoinerSecret, KeySchedule, LeafNode, LibraryError, MessageSecrets, MlsGroup, OpenMlsProvider,
17    Proposal, ProposalQueue, PskSecret, QueuedProposal,
18};
19use crate::group::diff::PublicGroupDiff;
20use crate::group::GroupEpoch;
21use crate::prelude::{Commit, LeafNodeIndex};
22#[cfg(feature = "extensions-draft-08")]
23use crate::{component::ComponentId, schedule::application_export_tree::ApplicationExportTree};
24
25use crate::treesync::errors::TreeSyncFromNodesError;
26use crate::treesync::RatchetTree;
27use crate::{
28    ciphersuite::{hash_ref::ProposalRef, Secret},
29    framing::mls_auth_content::AuthenticatedContent,
30    group::public_group::{
31        diff::{apply_proposals::ApplyProposalsValues, StagedPublicGroupDiff},
32        staged_commit::PublicStagedCommitState,
33    },
34    schedule::{
35        CommitSecret, EpochAuthenticator, EpochSecretsResult, InitSecret, PreSharedKeyId,
36        ResumptionPskSecret,
37    },
38    treesync::node::encryption_keys::EncryptionKeyPair,
39};
40
41#[cfg(feature = "extensions-draft-08")]
42use super::proposal_store::{QueuedAppDataUpdateProposal, QueuedAppEphemeralProposal};
43#[cfg(feature = "extensions-draft-08")]
44use crate::prelude::processing::AppDataUpdates;
45
46impl MlsGroup {
47    fn derive_epoch_secrets(
48        &self,
49        provider: &impl OpenMlsProvider,
50        apply_proposals_values: ApplyProposalsValues,
51        epoch_secrets: &GroupEpochSecrets,
52        commit_secret: CommitSecret,
53        serialized_provisional_group_context: &[u8],
54    ) -> Result<EpochSecretsResult, StageCommitError> {
55        // Check if we need to include the init secret from an external commit
56        // we applied earlier or if we use the one from the previous epoch.
57        let joiner_secret = if let Some(ref external_init_proposal) =
58            apply_proposals_values.external_init_proposal_option
59        {
60            // Decrypt the content and derive the external init secret.
61            let external_priv = epoch_secrets
62                .external_secret()
63                .derive_external_keypair(provider.crypto(), self.ciphersuite())
64                .map_err(LibraryError::unexpected_crypto_error)?
65                .private;
66            let init_secret = InitSecret::from_kem_output(
67                provider.crypto(),
68                self.ciphersuite(),
69                self.version(),
70                &external_priv,
71                external_init_proposal.kem_output(),
72            )?;
73            JoinerSecret::new(
74                provider.crypto(),
75                self.ciphersuite(),
76                commit_secret,
77                &init_secret,
78                serialized_provisional_group_context,
79            )
80            .map_err(LibraryError::unexpected_crypto_error)?
81        } else {
82            JoinerSecret::new(
83                provider.crypto(),
84                self.ciphersuite(),
85                commit_secret,
86                epoch_secrets.init_secret(),
87                serialized_provisional_group_context,
88            )
89            .map_err(LibraryError::unexpected_crypto_error)?
90        };
91
92        // Prepare the PskSecret
93        // Fails if PSKs are missing ([valn1205](https://validation.openmls.tech/#valn1205))
94        let psk_secret = {
95            let psks: Vec<(&PreSharedKeyId, Secret)> = load_psks(
96                provider.storage(),
97                &self.resumption_psk_store,
98                &apply_proposals_values.presharedkeys,
99            )?;
100
101            PskSecret::new(provider.crypto(), self.ciphersuite(), psks)?
102        };
103
104        // Create key schedule
105        let mut key_schedule = KeySchedule::init(
106            self.ciphersuite(),
107            provider.crypto(),
108            &joiner_secret,
109            psk_secret,
110        )?;
111
112        key_schedule
113            .add_context(provider.crypto(), serialized_provisional_group_context)
114            .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?;
115        Ok(key_schedule
116            .epoch_secrets(provider.crypto(), self.ciphersuite())
117            .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?)
118    }
119
120    /// Stages a commit message that was sent by another group member. This
121    /// function does the following:
122    ///  - Applies the proposals covered by the commit to the tree
123    ///  - Applies the (optional) update path to the tree
124    ///  - Decrypts and calculates the path secrets
125    ///  - Initializes the key schedule for epoch rollover
126    ///  - Verifies the confirmation tag
127    ///
128    /// Returns a [StagedCommit] that can be inspected and later merged into the
129    /// group state with [MlsGroup::merge_commit()]. If the member was not
130    /// removed from the group, the function also returns an
131    /// [ApplicationExportSecret].
132    ///
133    /// This function does the following checks:
134    ///  - ValSem101
135    ///  - ValSem102
136    ///  - ValSem104
137    ///  - ValSem105
138    ///  - ValSem106
139    ///  - ValSem107
140    ///  - ValSem108
141    ///  - ValSem110
142    ///  - ValSem111
143    ///  - ValSem112
144    ///  - ValSem113: All Proposals: The proposal type must be supported by all
145    ///    members of the group
146    ///  - ValSem200
147    ///  - ValSem201
148    ///  - ValSem202: Path must be the right length
149    ///  - ValSem203: Path secrets must decrypt correctly
150    ///  - ValSem204: Public keys from Path must be verified and match the
151    ///    private keys from the direct path
152    ///  - ValSem205
153    ///  - ValSem240
154    ///  - ValSem241
155    ///  - ValSem242
156    ///  - ValSem244 Returns an error if the given commit was sent by the owner
157    ///    of this group.
158    pub(crate) fn stage_commit(
159        &self,
160        mls_content: &AuthenticatedContent,
161        old_epoch_keypairs: Vec<EncryptionKeyPair>,
162        leaf_node_keypairs: Vec<EncryptionKeyPair>,
163        provider: &impl OpenMlsProvider,
164        #[cfg(feature = "virtual-clients-draft")] vc_material: Option<
165            crate::components::vc_derivation_info::OperationSecret,
166        >,
167    ) -> Result<StagedCommit, StageCommitError> {
168        // Check that the sender is another member of the group
169        #[cfg(not(feature = "virtual-clients-draft"))]
170        if let Sender::Member(member) = mls_content.sender() {
171            if member == &self.own_leaf_index() {
172                return Err(StageCommitError::OwnCommit);
173            }
174        }
175
176        let (commit, proposal_queue, sender_index) = self
177            .public_group
178            .validate_commit(mls_content, provider.crypto())?;
179
180        // Create the provisional public group state (including the tree and
181        // group context) and apply proposals.
182        let mut diff = self.public_group.empty_diff();
183
184        #[cfg(not(feature = "extensions-draft-08"))]
185        let apply_proposals_values =
186            diff.apply_proposals(&proposal_queue, self.own_leaf_index())?;
187
188        #[cfg(feature = "extensions-draft-08")]
189        let apply_proposals_values = diff.apply_proposals_with_app_data_updates(
190            &proposal_queue,
191            self.own_leaf_index(),
192            None,
193        )?;
194        self.stage_applied_proposal_values(
195            apply_proposals_values,
196            diff,
197            commit,
198            proposal_queue,
199            sender_index,
200            mls_content,
201            old_epoch_keypairs,
202            leaf_node_keypairs,
203            provider,
204            #[cfg(feature = "virtual-clients-draft")]
205            vc_material,
206        )
207    }
208
209    #[cfg(feature = "extensions-draft-08")]
210    pub(crate) fn stage_commit_with_app_data_updates(
211        &self,
212        mls_content: &AuthenticatedContent,
213        old_epoch_keypairs: Vec<EncryptionKeyPair>,
214        leaf_node_keypairs: Vec<EncryptionKeyPair>,
215        app_data_dict_updates: Option<AppDataUpdates>,
216        provider: &impl OpenMlsProvider,
217        #[cfg(feature = "virtual-clients-draft")] vc_material: Option<
218            crate::components::vc_derivation_info::OperationSecret,
219        >,
220    ) -> Result<StagedCommit, StageCommitError> {
221        // Check that the sender is another member of the group
222        #[cfg(not(feature = "virtual-clients-draft"))]
223        if let Sender::Member(member) = mls_content.sender() {
224            if member == &self.own_leaf_index() {
225                return Err(StageCommitError::OwnCommit);
226            }
227        }
228
229        let (commit, proposal_queue, sender_index) = self
230            .public_group
231            .validate_commit(mls_content, provider.crypto())?;
232
233        // Create the provisional public group state (including the tree and
234        // group context) and apply proposals.
235        let mut diff = self.public_group.empty_diff();
236
237        let apply_proposals_values = diff.apply_proposals_with_app_data_updates(
238            &proposal_queue,
239            self.own_leaf_index(),
240            app_data_dict_updates,
241        )?;
242
243        self.stage_applied_proposal_values(
244            apply_proposals_values,
245            diff,
246            commit,
247            proposal_queue,
248            sender_index,
249            mls_content,
250            old_epoch_keypairs,
251            leaf_node_keypairs,
252            provider,
253            #[cfg(feature = "virtual-clients-draft")]
254            vc_material,
255        )
256    }
257
258    #[allow(clippy::too_many_arguments)]
259    fn stage_applied_proposal_values(
260        &self,
261        apply_proposals_values: ApplyProposalsValues,
262        mut diff: PublicGroupDiff,
263        commit: &Commit,
264        proposal_queue: ProposalQueue,
265        sender_index: LeafNodeIndex,
266        mls_content: &AuthenticatedContent,
267        old_epoch_keypairs: Vec<EncryptionKeyPair>,
268        leaf_node_keypairs: Vec<EncryptionKeyPair>,
269        provider: &impl OpenMlsProvider,
270        #[cfg(feature = "virtual-clients-draft")] vc_material: Option<
271            crate::components::vc_derivation_info::OperationSecret,
272        >,
273    ) -> Result<StagedCommit, StageCommitError> {
274        let ciphersuite = self.ciphersuite();
275        // Determine if Commit has a path
276        let (commit_secret, new_keypairs, new_leaf_keypair_option, update_path_leaf_node) =
277            if let Some(path) = commit.path.clone() {
278                // Update the public group
279                // ValSem202: Path must be the right length
280                diff.apply_received_update_path(
281                    provider.crypto(),
282                    ciphersuite,
283                    sender_index,
284                    &path,
285                )?;
286
287                // Update group context
288                diff.update_group_context(
289                    provider.crypto(),
290                    apply_proposals_values.extensions.clone(),
291                )?;
292
293                // Check if we were removed from the group
294                if apply_proposals_values.self_removed {
295                    // If so, we return here, because we can't decrypt the path
296                    let staged_diff = diff.into_staged_diff(provider.crypto(), ciphersuite)?;
297                    let staged_state = PublicStagedCommitState::new(
298                        staged_diff,
299                        commit.path.as_ref().map(|path| path.leaf_node().clone()),
300                    );
301                    let staged_commit = StagedCommit::new(
302                        proposal_queue,
303                        StagedCommitState::PublicState(Box::new(staged_state)),
304                    );
305                    return Ok(staged_commit);
306                }
307
308                // When processing our own commit (sent by a sibling virtual
309                // client), we have no HPKE recipient on the path. Re-derive
310                // the path from the per-commit `OperationSecret` the receiver
311                // computed by evaluating the per-epoch PPRF, and verify the
312                // resulting public keys against the commit. The non-VC
313                // `decrypt_path` is the fallback for everyone else.
314                #[cfg(feature = "virtual-clients-draft")]
315                let vc_path: Option<(Vec<EncryptionKeyPair>, CommitSecret)> =
316                    if sender_index == self.own_leaf_index() {
317                        let operation_secret = vc_material.ok_or(
318                            crate::components::vc_derivation_info::VirtualClientsError::MissingPprf,
319                        )?;
320                        Some(self.recreate_path_for_own_commit(
321                            &diff,
322                            &path,
323                            ciphersuite,
324                            provider.crypto(),
325                            operation_secret,
326                        )?)
327                    } else {
328                        None
329                    };
330                #[cfg(not(feature = "virtual-clients-draft"))]
331                let vc_path: Option<(Vec<EncryptionKeyPair>, CommitSecret)> = None;
332
333                // ValSem203: Path secrets must decrypt correctly
334                // ValSem204: Public keys from Path must be verified and match the private keys from the direct path
335                let (new_keypairs, commit_secret) = if let Some(pair) = vc_path {
336                    pair
337                } else {
338                    let decryption_keypairs: Vec<&EncryptionKeyPair> = old_epoch_keypairs
339                        .iter()
340                        .chain(leaf_node_keypairs.iter())
341                        .collect();
342                    diff.decrypt_path(
343                        provider.crypto(),
344                        &decryption_keypairs,
345                        self.own_leaf_index(),
346                        sender_index,
347                        path.nodes(),
348                        &apply_proposals_values.exclusion_list(),
349                    )?
350                };
351
352                // Check if one of our update proposals was applied. If so, we
353                // need to store that keypair separately, because after merging
354                // it needs to be removed from the key store separately and in
355                // addition to the removal of the keypairs of the previous
356                // epoch.
357                let new_leaf_keypair_option = if let Some(leaf) = diff.leaf(self.own_leaf_index()) {
358                    leaf_node_keypairs.into_iter().find_map(|keypair| {
359                        if leaf.encryption_key() == keypair.public_key() {
360                            Some(keypair)
361                        } else {
362                            None
363                        }
364                    })
365                } else {
366                    // We should have an own leaf at this point.
367                    debug_assert!(false);
368                    None
369                };
370
371                // Return the leaf node in the update path so the credential can be validated.
372                // Since the diff has already been updated, this should be the same as the leaf
373                // at the sender index.
374                let update_path_leaf_node = Some(path.leaf_node().clone());
375                debug_assert_eq!(diff.leaf(sender_index), path.leaf_node().into());
376
377                (
378                    commit_secret,
379                    new_keypairs,
380                    new_leaf_keypair_option,
381                    update_path_leaf_node,
382                )
383            } else {
384                if apply_proposals_values.path_required {
385                    // ValSem201
386                    return Err(StageCommitError::RequiredPathNotFound);
387                }
388
389                // Even if there is no path, we have to update the group context.
390                diff.update_group_context(
391                    provider.crypto(),
392                    apply_proposals_values.extensions.clone(),
393                )?;
394
395                (CommitSecret::zero_secret(ciphersuite), vec![], None, None)
396            };
397
398        // Update the confirmed transcript hash before we compute the confirmation tag.
399        diff.update_confirmed_transcript_hash(provider.crypto(), mls_content)?;
400
401        let received_confirmation_tag = mls_content
402            .confirmation_tag()
403            .ok_or(StageCommitError::ConfirmationTagMissing)?;
404
405        let serialized_provisional_group_context = diff
406            .group_context()
407            .tls_serialize_detached()
408            .map_err(LibraryError::missing_bound_check)?;
409
410        let EpochSecretsResult {
411            epoch_secrets,
412            #[cfg(feature = "extensions-draft-08")]
413            application_exporter,
414        } = self.derive_epoch_secrets(
415            provider,
416            apply_proposals_values,
417            self.group_epoch_secrets(),
418            commit_secret,
419            &serialized_provisional_group_context,
420        )?;
421        let (provisional_group_secrets, provisional_message_secrets) = epoch_secrets.split_secrets(
422            serialized_provisional_group_context,
423            diff.tree_size(),
424            self.own_leaf_index(),
425        );
426
427        // Verify confirmation tag
428        // ValSem205
429        let own_confirmation_tag = provisional_message_secrets
430            .confirmation_key()
431            .tag(
432                provider.crypto(),
433                self.ciphersuite(),
434                diff.group_context().confirmed_transcript_hash(),
435            )
436            .map_err(LibraryError::unexpected_crypto_error)?;
437        if &own_confirmation_tag != received_confirmation_tag {
438            log::error!("Confirmation tag mismatch");
439            log_crypto!(trace, "  Got:      {:x?}", received_confirmation_tag);
440            log_crypto!(trace, "  Expected: {:x?}", own_confirmation_tag);
441            // TODO: We have tests expecting this error.
442            //       They need to be rewritten.
443            // debug_assert!(false, "Confirmation tag mismatch");
444
445            // in some tests we need to be able to proceed despite the tag being wrong,
446            // e.g. to test whether a later validation check is performed correctly.
447            if !crate::skip_validation::is_disabled::confirmation_tag() {
448                return Err(StageCommitError::ConfirmationTagMismatch);
449            }
450        }
451
452        diff.update_interim_transcript_hash(ciphersuite, provider.crypto(), own_confirmation_tag)?;
453
454        let staged_diff = diff.into_staged_diff(provider.crypto(), ciphersuite)?;
455        #[cfg(feature = "extensions-draft-08")]
456        let application_export_tree = ApplicationExportTree::new(application_exporter);
457        let staged_commit_state =
458            StagedCommitState::GroupMember(Box::new(MemberStagedCommitState::new(
459                provisional_group_secrets,
460                provisional_message_secrets,
461                staged_diff,
462                new_keypairs,
463                new_leaf_keypair_option,
464                update_path_leaf_node,
465                #[cfg(feature = "extensions-draft-08")]
466                application_export_tree,
467            )));
468        let staged_commit = StagedCommit::new(proposal_queue, staged_commit_state);
469
470        Ok(staged_commit)
471    }
472
473    /// Re-derive the path of an own commit (i.e. one sent by a sibling
474    /// virtual client) from the per-commit `OperationSecret` resolved by
475    /// the caller (in `process_message`, via PPRF evaluation), and verify
476    /// that the public keys derived from the path secret match the leaf
477    /// node in the path. Returns the derived parent-node keypairs and
478    /// commit secret, ready to be slotted in where `decrypt_path` would
479    /// normally produce them.
480    ///
481    /// Signature key changes are not verified here: not every signature
482    /// scheme has an interoperable seed-to-keypair construction, so the
483    /// application is responsible for supplying any rotated signature key
484    /// pair to the storage provider out-of-band. The new public key on the
485    /// path leaf is already authenticated by the commit's standard
486    /// path-validation against the previous signature key.
487    #[cfg(feature = "virtual-clients-draft")]
488    fn recreate_path_for_own_commit(
489        &self,
490        diff: &PublicGroupDiff,
491        path: &crate::treesync::treekem::UpdatePath,
492        ciphersuite: openmls_traits::types::Ciphersuite,
493        crypto: &impl OpenMlsCrypto,
494        operation_secret: crate::components::vc_derivation_info::OperationSecret,
495    ) -> Result<(Vec<EncryptionKeyPair>, CommitSecret), StageCommitError> {
496        use crate::components::vc_derivation_info::VirtualClientsError;
497
498        let path_secret = operation_secret
499            .derive_path_generation_secret(crypto, ciphersuite)?
500            .into();
501        let (encryption_key_pairs, commit_secret) = diff.recreate_path_from_path_secret(
502            crypto,
503            path_secret,
504            self.own_leaf_index(),
505            path.nodes(),
506        )?;
507
508        // Verify that the leaf encryption key in the path matches the one
509        // derived from the operation secret.
510        let leaf_keypair = operation_secret
511            .derive_encryption_key_secret(crypto, ciphersuite)?
512            .generate_encryption_key_pair(crypto, ciphersuite)?;
513        if leaf_keypair.public_key() != path.leaf_node().encryption_key() {
514            return Err(VirtualClientsError::EncryptionKeyMismatch.into());
515        }
516
517        // Mirror the sender's `apply_own_update_path`: the leaf keypair is
518        // prepended to the parent keypairs so the merge step has private
519        // keys for all of the new epoch's owned encryption keys.
520        let mut keypairs = Vec::with_capacity(1 + encryption_key_pairs.len());
521        keypairs.push(leaf_keypair);
522        keypairs.extend(encryption_key_pairs);
523        Ok((keypairs, commit_secret))
524    }
525
526    /// Merges a [StagedCommit] into the group state and optionally return a [`SecretTree`]
527    /// from the previous epoch. The secret tree is returned if the Commit does not contain a self removal.
528    ///
529    /// This function should not fail and only returns a [`Result`], because it
530    /// might throw a `LibraryError`.
531    pub(crate) fn merge_commit<Provider: OpenMlsProvider>(
532        &mut self,
533        provider: &Provider,
534        staged_commit: StagedCommit,
535    ) -> Result<(), MergeCommitError<Provider::StorageError>> {
536        // Get all keypairs from the old epoch, so we can later store the ones
537        // that are still relevant in the new epoch.
538        let old_epoch_keypairs = self
539            .read_epoch_keypairs(provider.storage())
540            .map_err(MergeCommitError::StorageError)?;
541        match staged_commit.state {
542            StagedCommitState::PublicState(staged_state) => {
543                self.public_group
544                    .merge_diff(staged_state.into_staged_diff());
545                self.store(provider.storage())
546                    .map_err(MergeCommitError::StorageError)?;
547                Ok(())
548            }
549            StagedCommitState::GroupMember(state) => {
550                // Save the past epoch
551                let past_epoch = self.context().epoch();
552                // Get all the full leaves
553                let leaves = self.public_group().members().collect();
554                // Merge the staged commit into the group state and store the secret tree from the
555                // previous epoch in the message secrets store.
556                self.group_epoch_secrets = state.group_epoch_secrets;
557
558                // Replace the previous message secrets with the new ones and return the previous message secrets
559                let old_message_secrets = self
560                    .message_secrets_store
561                    .replace_current_message_secrets(state.message_secrets);
562                self.message_secrets_store.add_past_epoch_tree(
563                    past_epoch,
564                    old_message_secrets,
565                    leaves,
566                );
567
568                // Replace the previous exporter tree with the new one.
569                #[cfg(feature = "extensions-draft-08")]
570                {
571                    // The application exporter is only None if the group was
572                    // stored using an older version of OpenMLS that did not
573                    // support the application exporter.
574                    if let Some(application_export_tree) = state.application_export_tree {
575                        // Overwrite the existing exporter tree in the storage.
576
577                        use openmls_traits::storage::StorageProvider as _;
578                        provider
579                            .storage()
580                            .write_application_export_tree(
581                                self.group_id(),
582                                &application_export_tree,
583                            )
584                            .map_err(MergeCommitError::StorageError)?;
585
586                        self.application_export_tree = Some(application_export_tree);
587                    }
588                }
589
590                self.public_group.merge_diff(state.staged_diff);
591
592                let leaf_keypair = if let Some(keypair) = &state.new_leaf_keypair_option {
593                    vec![keypair.clone()]
594                } else {
595                    vec![]
596                };
597
598                // Figure out which keys we need in the new epoch.
599                let new_owned_encryption_keys = self
600                    .public_group()
601                    .owned_encryption_keys(self.own_leaf_index());
602                // From the old and new keys, keep the ones that are still relevant in the new epoch.
603                let epoch_keypairs: Vec<EncryptionKeyPair> = old_epoch_keypairs
604                    .into_iter()
605                    .chain(state.new_keypairs)
606                    .chain(leaf_keypair)
607                    .filter(|keypair| new_owned_encryption_keys.contains(keypair.public_key()))
608                    .collect();
609
610                // We should have private keys for all owned encryption keys.
611                debug_assert_eq!(new_owned_encryption_keys.len(), epoch_keypairs.len());
612                if new_owned_encryption_keys.len() != epoch_keypairs.len() {
613                    return Err(LibraryError::custom(
614                        "We should have all the private key material we need.",
615                    )
616                    .into());
617                }
618
619                // Store the updated group state
620                let storage = provider.storage();
621                let group_id = self.group_id();
622
623                self.public_group
624                    .store(storage)
625                    .map_err(MergeCommitError::StorageError)?;
626                storage
627                    .write_group_epoch_secrets(group_id, &self.group_epoch_secrets)
628                    .map_err(MergeCommitError::StorageError)?;
629                storage
630                    .write_message_secrets(group_id, &self.message_secrets_store)
631                    .map_err(MergeCommitError::StorageError)?;
632
633                // Store the relevant keys under the new epoch
634                self.store_epoch_keypairs(storage, epoch_keypairs.as_slice())
635                    .map_err(MergeCommitError::StorageError)?;
636
637                // Delete the old keys.
638                self.delete_previous_epoch_keypairs(storage)
639                    .map_err(MergeCommitError::StorageError)?;
640                if let Some(keypair) = state.new_leaf_keypair_option {
641                    keypair
642                        .delete(storage)
643                        .map_err(MergeCommitError::StorageError)?;
644                }
645
646                // Empty the proposal store
647                storage
648                    .clear_proposal_queue::<GroupId, ProposalRef>(group_id)
649                    .map_err(MergeCommitError::StorageError)?;
650                self.proposal_store_mut().empty();
651
652                Ok(())
653            }
654        }
655    }
656}
657
658#[derive(Debug, Serialize, Deserialize)]
659#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
660pub(crate) enum StagedCommitState {
661    PublicState(Box<PublicStagedCommitState>),
662    /// The group member variant of the staged commit state.
663    GroupMember(Box<MemberStagedCommitState>),
664}
665
666/// Contains the changes from a commit to the group state.
667#[derive(Debug, Serialize, Deserialize)]
668#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
669pub struct StagedCommit {
670    /// A queue containing the proposals associated with the commit.
671    pub staged_proposal_queue: ProposalQueue,
672    /// The staged commit state.
673    pub(super) state: StagedCommitState,
674}
675
676impl StagedCommit {
677    /// Create a new [`StagedCommit`] from the provisional group state created
678    /// during the commit process.
679    pub(crate) fn new(staged_proposal_queue: ProposalQueue, state: StagedCommitState) -> Self {
680        StagedCommit {
681            staged_proposal_queue,
682            state,
683        }
684    }
685
686    /// Returns the epoch that this commit moves the group into
687    pub fn epoch(&self) -> GroupEpoch {
688        self.group_context().epoch()
689    }
690
691    /// Returns the ratchet tree of the staged commit state.
692    pub fn export_ratchet_tree(
693        &self,
694        crypto: &impl OpenMlsCrypto,
695        original_tree: RatchetTree,
696    ) -> Result<Option<RatchetTree>, TreeSyncFromNodesError> {
697        match &self.state {
698            StagedCommitState::PublicState(_public_staged_commit_state) => Ok(None),
699            StagedCommitState::GroupMember(member_staged_commit_state) => Ok(Some(
700                member_staged_commit_state.staged_diff.export_ratchet_tree(
701                    crypto,
702                    self.group_context().ciphersuite(),
703                    original_tree,
704                )?,
705            )),
706        }
707    }
708
709    /// Returns the Add proposals that are covered by the Commit message as in iterator over [QueuedAddProposal].
710    pub fn add_proposals(&self) -> impl Iterator<Item = QueuedAddProposal<'_>> {
711        self.staged_proposal_queue.add_proposals()
712    }
713
714    /// Returns the Remove proposals that are covered by the Commit message as in iterator over [QueuedRemoveProposal].
715    pub fn remove_proposals(&self) -> impl Iterator<Item = QueuedRemoveProposal<'_>> {
716        self.staged_proposal_queue.remove_proposals()
717    }
718
719    /// Returns the Update proposals that are covered by the Commit message as in iterator over [QueuedUpdateProposal].
720    pub fn update_proposals(&self) -> impl Iterator<Item = QueuedUpdateProposal<'_>> {
721        self.staged_proposal_queue.update_proposals()
722    }
723
724    /// Returns the PresharedKey proposals that are covered by the Commit message as in iterator over [QueuedPskProposal].
725    pub fn psk_proposals(&self) -> impl Iterator<Item = QueuedPskProposal<'_>> {
726        self.staged_proposal_queue.psk_proposals()
727    }
728
729    #[cfg(feature = "extensions-draft-08")]
730    /// Returns the AppEphemeral proposals that are covered by the Commit message as an iterator
731    /// over [`QueuedAppEphemeralProposal`].
732    pub fn queued_app_ephemeral_proposals(
733        &self,
734    ) -> impl Iterator<Item = QueuedAppEphemeralProposal<'_>> {
735        self.staged_proposal_queue.app_ephemeral_proposals()
736    }
737    // NOTE: this is not a default proposal type
738    #[cfg(feature = "extensions-draft-08")]
739    /// Returns the AppDataUpdate proposals that are covered by the Commit message as an iterator
740    /// over [`QueuedAppDataUpdateProposal`].
741    pub fn app_data_update_proposals(
742        &self,
743    ) -> impl Iterator<Item = QueuedAppDataUpdateProposal<'_>> {
744        self.staged_proposal_queue.app_data_update_proposals()
745    }
746
747    /// Returns an iterator over all [`QueuedProposal`]s.
748    pub fn queued_proposals(&self) -> impl Iterator<Item = &QueuedProposal> {
749        self.staged_proposal_queue.queued_proposals()
750    }
751
752    /// Returns the leaf node of the (optional) update path.
753    pub fn update_path_leaf_node(&self) -> Option<&LeafNode> {
754        match self.state {
755            StagedCommitState::PublicState(ref public_state) => {
756                public_state.update_path_leaf_node()
757            }
758            StagedCommitState::GroupMember(ref group_member_state) => {
759                group_member_state.update_path_leaf_node.as_ref()
760            }
761        }
762    }
763
764    /// Returns the credentials that the caller needs to verify are valid.
765    pub fn credentials_to_verify(&self) -> impl Iterator<Item = &Credential> {
766        let update_path_leaf_node_cred = if let Some(node) = self.update_path_leaf_node() {
767            vec![node.credential()]
768        } else {
769            vec![]
770        };
771
772        update_path_leaf_node_cred
773            .into_iter()
774            .chain(
775                self.queued_proposals()
776                    .flat_map(|proposal: &QueuedProposal| match proposal.proposal() {
777                        Proposal::Update(update_proposal) => {
778                            vec![update_proposal.leaf_node().credential()].into_iter()
779                        }
780                        Proposal::Add(add_proposal) => {
781                            vec![add_proposal.key_package().leaf_node().credential()].into_iter()
782                        }
783                        Proposal::GroupContextExtensions(gce_proposal) => gce_proposal
784                            .extensions()
785                            .iter()
786                            .flat_map(|extension| {
787                                match extension {
788                                    Extension::ExternalSenders(external_senders) => {
789                                        external_senders
790                                            .iter()
791                                            .map(|external_sender| external_sender.credential())
792                                            .collect()
793                                    }
794                                    _ => vec![],
795                                }
796                                .into_iter()
797                            })
798                            // TODO: ideally we wouldn't collect in between here, but the match arms
799                            //       have to all return the same type. We solve this by having them all
800                            //       be vec::IntoIter, but it would be nice if we just didn't have to
801                            //       do this.
802                            //       It might be possible to solve this by letting all match arms
803                            //       evaluate to a dyn Iterator.
804                            .collect::<Vec<_>>()
805                            .into_iter(),
806                        _ => vec![].into_iter(),
807                    }),
808            )
809    }
810
811    /// Returns `true` if the member was removed through a proposal covered by this Commit message
812    /// and `false` otherwise.
813    pub fn self_removed(&self) -> bool {
814        matches!(self.state, StagedCommitState::PublicState(_))
815    }
816
817    /// Returns the [`GroupContext`] of the staged commit state.
818    pub fn group_context(&self) -> &GroupContext {
819        match self.state {
820            StagedCommitState::PublicState(ref ps) => ps.staged_diff().group_context(),
821            StagedCommitState::GroupMember(ref gm) => gm.group_context(),
822        }
823    }
824    /// Consume this [`StagedCommit`] and return the internal [`StagedCommitState`].
825    pub(crate) fn into_state(self) -> StagedCommitState {
826        self.state
827    }
828
829    /// Returns the [`EpochAuthenticator`] of the staged commit state if the
830    /// owner of the originating group state is a member of the group. Returns
831    /// `None` otherwise.
832    pub fn epoch_authenticator(&self) -> Option<&EpochAuthenticator> {
833        if let StagedCommitState::GroupMember(ref gm) = self.state {
834            Some(gm.group_epoch_secrets.epoch_authenticator())
835        } else {
836            None
837        }
838    }
839
840    /// Returns the [`ResumptionPskSecret`] of the staged commit state if the
841    /// owner of the originating group state is a member of the group. Returns
842    /// `None` otherwise.
843    pub fn resumption_psk_secret(&self) -> Option<&ResumptionPskSecret> {
844        if let StagedCommitState::GroupMember(ref gm) = self.state {
845            Some(gm.group_epoch_secrets.resumption_psk())
846        } else {
847            None
848        }
849    }
850
851    #[cfg(feature = "extensions-draft-08")]
852    pub(crate) fn safe_export_secret(
853        &mut self,
854        crypto: &impl OpenMlsCrypto,
855        component_id: ComponentId,
856    ) -> Result<Vec<u8>, StagedSafeExportSecretError> {
857        let ciphersuite = self.group_context().ciphersuite();
858        let StagedCommitState::GroupMember(ref mut staged_commit) = self.state else {
859            return Err(StagedSafeExportSecretError::NotGroupMember);
860        };
861        let Some(application_export_tree) = staged_commit.application_export_tree.as_mut() else {
862            return Err(StagedSafeExportSecretError::Unsupported);
863        };
864        let secret =
865            application_export_tree.safe_export_secret(crypto, ciphersuite, component_id)?;
866        Ok(secret.as_slice().to_vec())
867    }
868
869    /// Exports a secret from the epoch that the staged commit moves to.
870    /// Returns [`ExportSecretError::KeyLengthTooLong`] if the requested
871    /// key length is too long.
872    /// Returns [`ExportSecretError::GroupStateError(MlsGroupStateError::UseAfterEviction)`]
873    /// if the commit removed us from the group.
874    ///
875    /// [`ExportSecretError::GroupStateError(MlsGroupStateError::UseAfterEviction)`]: MlsGroupStateError::UseAfterEviction
876    pub fn export_secret<CryptoProvider: OpenMlsCrypto>(
877        &self,
878        crypto: &CryptoProvider,
879        label: &str,
880        context: &[u8],
881        key_length: usize,
882    ) -> Result<Vec<u8>, ExportSecretError> {
883        if key_length > u16::MAX as usize {
884            log::error!("Got a key that is larger than u16::MAX");
885            return Err(ExportSecretError::KeyLengthTooLong);
886        }
887
888        match &self.state {
889            StagedCommitState::PublicState(_public_staged_commit_state) => Err(
890                ExportSecretError::GroupStateError(MlsGroupStateError::UseAfterEviction),
891            ),
892            StagedCommitState::GroupMember(member_staged_commit_state) => {
893                Ok(member_staged_commit_state
894                    .group_epoch_secrets
895                    .exporter_secret()
896                    .derive_exported_secret(
897                        self.group_context().ciphersuite(),
898                        crypto,
899                        label,
900                        context,
901                        key_length,
902                    )
903                    .map_err(LibraryError::unexpected_crypto_error)?)
904            }
905        }
906    }
907}
908
909/// This struct is used internally by [`StagedCommit`] to encapsulate all the modified group state.
910#[derive(Debug, Serialize, Deserialize)]
911#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
912pub(crate) struct MemberStagedCommitState {
913    group_epoch_secrets: GroupEpochSecrets,
914    message_secrets: MessageSecrets,
915    staged_diff: StagedPublicGroupDiff,
916    new_keypairs: Vec<EncryptionKeyPair>,
917    new_leaf_keypair_option: Option<EncryptionKeyPair>,
918    update_path_leaf_node: Option<LeafNode>,
919    #[cfg(feature = "extensions-draft-08")]
920    #[serde(default)]
921    // This is `None` only if the group was stored using an older version of
922    // OpenMLS that did not support the application exporter.
923    application_export_tree: Option<ApplicationExportTree>,
924}
925
926impl MemberStagedCommitState {
927    pub(crate) fn new(
928        group_epoch_secrets: GroupEpochSecrets,
929        message_secrets: MessageSecrets,
930        staged_diff: StagedPublicGroupDiff,
931        new_keypairs: Vec<EncryptionKeyPair>,
932        new_leaf_keypair_option: Option<EncryptionKeyPair>,
933        update_path_leaf_node: Option<LeafNode>,
934        #[cfg(feature = "extensions-draft-08")] application_export_tree: ApplicationExportTree,
935    ) -> Self {
936        Self {
937            group_epoch_secrets,
938            message_secrets,
939            staged_diff,
940            new_keypairs,
941            new_leaf_keypair_option,
942            update_path_leaf_node,
943            #[cfg(feature = "extensions-draft-08")]
944            application_export_tree: Some(application_export_tree),
945        }
946    }
947
948    /// Get the staged [`GroupContext`].
949    pub(crate) fn group_context(&self) -> &GroupContext {
950        self.staged_diff.group_context()
951    }
952}