Skip to main content

openmls/group/mls_group/
staged_commit.rs

1use core::fmt::Debug;
2use std::mem;
3
4#[cfg(feature = "extensions-draft-08")]
5use openmls_traits::crypto::OpenMlsCrypto;
6use openmls_traits::storage::StorageProvider as _;
7use serde::{Deserialize, Serialize};
8use tls_codec::Serialize as _;
9
10use super::proposal_store::{
11    QueuedAddProposal, QueuedPskProposal, QueuedRemoveProposal, QueuedUpdateProposal,
12};
13
14use super::{
15    super::errors::*, load_psks, Credential, Extension, GroupContext, GroupEpochSecrets, GroupId,
16    JoinerSecret, KeySchedule, LeafNode, LibraryError, MessageSecrets, MlsGroup, OpenMlsProvider,
17    Proposal, ProposalQueue, PskSecret, QueuedProposal, Sender,
18};
19use crate::group::diff::PublicGroupDiff;
20use crate::prelude::{Commit, LeafNodeIndex};
21#[cfg(feature = "extensions-draft-08")]
22use crate::schedule::application_export_tree::ApplicationExportTree;
23
24use crate::{
25    ciphersuite::{hash_ref::ProposalRef, Secret},
26    framing::mls_auth_content::AuthenticatedContent,
27    group::public_group::{
28        diff::{apply_proposals::ApplyProposalsValues, StagedPublicGroupDiff},
29        staged_commit::PublicStagedCommitState,
30    },
31    schedule::{CommitSecret, EpochAuthenticator, EpochSecretsResult, InitSecret, PreSharedKeyId},
32    treesync::node::encryption_keys::EncryptionKeyPair,
33};
34
35#[cfg(feature = "extensions-draft-08")]
36use super::proposal_store::{QueuedAppDataUpdateProposal, QueuedAppEphemeralProposal};
37#[cfg(feature = "extensions-draft-08")]
38use crate::prelude::processing::AppDataUpdates;
39
40impl MlsGroup {
41    fn derive_epoch_secrets(
42        &self,
43        provider: &impl OpenMlsProvider,
44        apply_proposals_values: ApplyProposalsValues,
45        epoch_secrets: &GroupEpochSecrets,
46        commit_secret: CommitSecret,
47        serialized_provisional_group_context: &[u8],
48    ) -> Result<EpochSecretsResult, StageCommitError> {
49        // Check if we need to include the init secret from an external commit
50        // we applied earlier or if we use the one from the previous epoch.
51        let joiner_secret = if let Some(ref external_init_proposal) =
52            apply_proposals_values.external_init_proposal_option
53        {
54            // Decrypt the content and derive the external init secret.
55            let external_priv = epoch_secrets
56                .external_secret()
57                .derive_external_keypair(provider.crypto(), self.ciphersuite())
58                .map_err(LibraryError::unexpected_crypto_error)?
59                .private;
60            let init_secret = InitSecret::from_kem_output(
61                provider.crypto(),
62                self.ciphersuite(),
63                self.version(),
64                &external_priv,
65                external_init_proposal.kem_output(),
66            )?;
67            JoinerSecret::new(
68                provider.crypto(),
69                self.ciphersuite(),
70                commit_secret,
71                &init_secret,
72                serialized_provisional_group_context,
73            )
74            .map_err(LibraryError::unexpected_crypto_error)?
75        } else {
76            JoinerSecret::new(
77                provider.crypto(),
78                self.ciphersuite(),
79                commit_secret,
80                epoch_secrets.init_secret(),
81                serialized_provisional_group_context,
82            )
83            .map_err(LibraryError::unexpected_crypto_error)?
84        };
85
86        // Prepare the PskSecret
87        // Fails if PSKs are missing ([valn1205](https://validation.openmls.tech/#valn1205))
88        let psk_secret = {
89            let psks: Vec<(&PreSharedKeyId, Secret)> = load_psks(
90                provider.storage(),
91                &self.resumption_psk_store,
92                &apply_proposals_values.presharedkeys,
93            )?;
94
95            PskSecret::new(provider.crypto(), self.ciphersuite(), psks)?
96        };
97
98        // Create key schedule
99        let mut key_schedule = KeySchedule::init(
100            self.ciphersuite(),
101            provider.crypto(),
102            &joiner_secret,
103            psk_secret,
104        )?;
105
106        key_schedule
107            .add_context(provider.crypto(), serialized_provisional_group_context)
108            .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?;
109        Ok(key_schedule
110            .epoch_secrets(provider.crypto(), self.ciphersuite())
111            .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?)
112    }
113
114    /// Stages a commit message that was sent by another group member. This
115    /// function does the following:
116    ///  - Applies the proposals covered by the commit to the tree
117    ///  - Applies the (optional) update path to the tree
118    ///  - Decrypts and calculates the path secrets
119    ///  - Initializes the key schedule for epoch rollover
120    ///  - Verifies the confirmation tag
121    ///
122    /// Returns a [StagedCommit] that can be inspected and later merged into the
123    /// group state with [MlsGroup::merge_commit()]. If the member was not
124    /// removed from the group, the function also returns an
125    /// [ApplicationExportSecret].
126    ///
127    /// This function does the following checks:
128    ///  - ValSem101
129    ///  - ValSem102
130    ///  - ValSem104
131    ///  - ValSem105
132    ///  - ValSem106
133    ///  - ValSem107
134    ///  - ValSem108
135    ///  - ValSem110
136    ///  - ValSem111
137    ///  - ValSem112
138    ///  - ValSem113: All Proposals: The proposal type must be supported by all
139    ///    members of the group
140    ///  - ValSem200
141    ///  - ValSem201
142    ///  - ValSem202: Path must be the right length
143    ///  - ValSem203: Path secrets must decrypt correctly
144    ///  - ValSem204: Public keys from Path must be verified and match the
145    ///    private keys from the direct path
146    ///  - ValSem205
147    ///  - ValSem240
148    ///  - ValSem241
149    ///  - ValSem242
150    ///  - ValSem244 Returns an error if the given commit was sent by the owner
151    ///    of this group.
152    pub(crate) fn stage_commit(
153        &self,
154        mls_content: &AuthenticatedContent,
155        old_epoch_keypairs: Vec<EncryptionKeyPair>,
156        leaf_node_keypairs: Vec<EncryptionKeyPair>,
157        provider: &impl OpenMlsProvider,
158    ) -> Result<StagedCommit, StageCommitError> {
159        // Check that the sender is another member of the group
160        if let Sender::Member(member) = mls_content.sender() {
161            if member == &self.own_leaf_index() {
162                return Err(StageCommitError::OwnCommit);
163            }
164        }
165
166        let (commit, proposal_queue, sender_index) = self
167            .public_group
168            .validate_commit(mls_content, provider.crypto())?;
169
170        // Create the provisional public group state (including the tree and
171        // group context) and apply proposals.
172        let mut diff = self.public_group.empty_diff();
173
174        #[cfg(not(feature = "extensions-draft-08"))]
175        let apply_proposals_values =
176            diff.apply_proposals(&proposal_queue, self.own_leaf_index())?;
177
178        #[cfg(feature = "extensions-draft-08")]
179        let apply_proposals_values = diff.apply_proposals_with_app_data_updates(
180            &proposal_queue,
181            self.own_leaf_index(),
182            None,
183        )?;
184        self.stage_applied_proposal_values(
185            apply_proposals_values,
186            diff,
187            commit,
188            proposal_queue,
189            sender_index,
190            mls_content,
191            old_epoch_keypairs,
192            leaf_node_keypairs,
193            provider,
194        )
195    }
196
197    #[cfg(feature = "extensions-draft-08")]
198    pub(crate) fn stage_commit_with_app_data_updates(
199        &self,
200        mls_content: &AuthenticatedContent,
201        old_epoch_keypairs: Vec<EncryptionKeyPair>,
202        leaf_node_keypairs: Vec<EncryptionKeyPair>,
203        app_data_dict_updates: Option<AppDataUpdates>,
204        provider: &impl OpenMlsProvider,
205    ) -> Result<StagedCommit, StageCommitError> {
206        // Check that the sender is another member of the group
207        if let Sender::Member(member) = mls_content.sender() {
208            if member == &self.own_leaf_index() {
209                return Err(StageCommitError::OwnCommit);
210            }
211        }
212
213        let (commit, proposal_queue, sender_index) = self
214            .public_group
215            .validate_commit(mls_content, provider.crypto())?;
216
217        // Create the provisional public group state (including the tree and
218        // group context) and apply proposals.
219        let mut diff = self.public_group.empty_diff();
220
221        let apply_proposals_values = diff.apply_proposals_with_app_data_updates(
222            &proposal_queue,
223            self.own_leaf_index(),
224            app_data_dict_updates,
225        )?;
226
227        self.stage_applied_proposal_values(
228            apply_proposals_values,
229            diff,
230            commit,
231            proposal_queue,
232            sender_index,
233            mls_content,
234            old_epoch_keypairs,
235            leaf_node_keypairs,
236            provider,
237        )
238    }
239
240    #[allow(clippy::too_many_arguments)]
241    fn stage_applied_proposal_values(
242        &self,
243        apply_proposals_values: ApplyProposalsValues,
244        mut diff: PublicGroupDiff,
245        commit: &Commit,
246        proposal_queue: ProposalQueue,
247        sender_index: LeafNodeIndex,
248        mls_content: &AuthenticatedContent,
249        old_epoch_keypairs: Vec<EncryptionKeyPair>,
250        leaf_node_keypairs: Vec<EncryptionKeyPair>,
251        provider: &impl OpenMlsProvider,
252    ) -> Result<StagedCommit, StageCommitError> {
253        let ciphersuite = self.ciphersuite();
254        // Determine if Commit has a path
255        let (commit_secret, new_keypairs, new_leaf_keypair_option, update_path_leaf_node) =
256            if let Some(path) = commit.path.clone() {
257                // Update the public group
258                // ValSem202: Path must be the right length
259                diff.apply_received_update_path(
260                    provider.crypto(),
261                    ciphersuite,
262                    sender_index,
263                    &path,
264                )?;
265
266                // Update group context
267                diff.update_group_context(
268                    provider.crypto(),
269                    apply_proposals_values.extensions.clone(),
270                )?;
271
272                // Check if we were removed from the group
273                if apply_proposals_values.self_removed {
274                    // If so, we return here, because we can't decrypt the path
275                    let staged_diff = diff.into_staged_diff(provider.crypto(), ciphersuite)?;
276                    let staged_state = PublicStagedCommitState::new(
277                        staged_diff,
278                        commit.path.as_ref().map(|path| path.leaf_node().clone()),
279                    );
280                    let staged_commit = StagedCommit::new(
281                        proposal_queue,
282                        StagedCommitState::PublicState(Box::new(staged_state)),
283                    );
284                    return Ok(staged_commit);
285                }
286
287                let decryption_keypairs: Vec<&EncryptionKeyPair> = old_epoch_keypairs
288                    .iter()
289                    .chain(leaf_node_keypairs.iter())
290                    .collect();
291
292                // ValSem203: Path secrets must decrypt correctly
293                // ValSem204: Public keys from Path must be verified and match the private keys from the direct path
294                let (new_keypairs, commit_secret) = diff.decrypt_path(
295                    provider.crypto(),
296                    &decryption_keypairs,
297                    self.own_leaf_index(),
298                    sender_index,
299                    path.nodes(),
300                    &apply_proposals_values.exclusion_list(),
301                )?;
302
303                // Check if one of our update proposals was applied. If so, we
304                // need to store that keypair separately, because after merging
305                // it needs to be removed from the key store separately and in
306                // addition to the removal of the keypairs of the previous
307                // epoch.
308                let new_leaf_keypair_option = if let Some(leaf) = diff.leaf(self.own_leaf_index()) {
309                    leaf_node_keypairs.into_iter().find_map(|keypair| {
310                        if leaf.encryption_key() == keypair.public_key() {
311                            Some(keypair)
312                        } else {
313                            None
314                        }
315                    })
316                } else {
317                    // We should have an own leaf at this point.
318                    debug_assert!(false);
319                    None
320                };
321
322                // Return the leaf node in the update path so the credential can be validated.
323                // Since the diff has already been updated, this should be the same as the leaf
324                // at the sender index.
325                let update_path_leaf_node = Some(path.leaf_node().clone());
326                debug_assert_eq!(diff.leaf(sender_index), path.leaf_node().into());
327
328                (
329                    commit_secret,
330                    new_keypairs,
331                    new_leaf_keypair_option,
332                    update_path_leaf_node,
333                )
334            } else {
335                if apply_proposals_values.path_required {
336                    // ValSem201
337                    return Err(StageCommitError::RequiredPathNotFound);
338                }
339
340                // Even if there is no path, we have to update the group context.
341                diff.update_group_context(
342                    provider.crypto(),
343                    apply_proposals_values.extensions.clone(),
344                )?;
345
346                (CommitSecret::zero_secret(ciphersuite), vec![], None, None)
347            };
348
349        // Update the confirmed transcript hash before we compute the confirmation tag.
350        diff.update_confirmed_transcript_hash(provider.crypto(), mls_content)?;
351
352        let received_confirmation_tag = mls_content
353            .confirmation_tag()
354            .ok_or(StageCommitError::ConfirmationTagMissing)?;
355
356        let serialized_provisional_group_context = diff
357            .group_context()
358            .tls_serialize_detached()
359            .map_err(LibraryError::missing_bound_check)?;
360
361        let EpochSecretsResult {
362            epoch_secrets,
363            #[cfg(feature = "extensions-draft-08")]
364            application_exporter,
365        } = self.derive_epoch_secrets(
366            provider,
367            apply_proposals_values,
368            self.group_epoch_secrets(),
369            commit_secret,
370            &serialized_provisional_group_context,
371        )?;
372        let (provisional_group_secrets, provisional_message_secrets) = epoch_secrets.split_secrets(
373            serialized_provisional_group_context,
374            diff.tree_size(),
375            self.own_leaf_index(),
376        );
377
378        // Verify confirmation tag
379        // ValSem205
380        let own_confirmation_tag = provisional_message_secrets
381            .confirmation_key()
382            .tag(
383                provider.crypto(),
384                self.ciphersuite(),
385                diff.group_context().confirmed_transcript_hash(),
386            )
387            .map_err(LibraryError::unexpected_crypto_error)?;
388        if &own_confirmation_tag != received_confirmation_tag {
389            log::error!("Confirmation tag mismatch");
390            log_crypto!(trace, "  Got:      {:x?}", received_confirmation_tag);
391            log_crypto!(trace, "  Expected: {:x?}", own_confirmation_tag);
392            // TODO: We have tests expecting this error.
393            //       They need to be rewritten.
394            // debug_assert!(false, "Confirmation tag mismatch");
395
396            // in some tests we need to be able to proceed despite the tag being wrong,
397            // e.g. to test whether a later validation check is performed correctly.
398            if !crate::skip_validation::is_disabled::confirmation_tag() {
399                return Err(StageCommitError::ConfirmationTagMismatch);
400            }
401        }
402
403        diff.update_interim_transcript_hash(ciphersuite, provider.crypto(), own_confirmation_tag)?;
404
405        let staged_diff = diff.into_staged_diff(provider.crypto(), ciphersuite)?;
406        #[cfg(feature = "extensions-draft-08")]
407        let application_export_tree = ApplicationExportTree::new(application_exporter);
408        let staged_commit_state =
409            StagedCommitState::GroupMember(Box::new(MemberStagedCommitState::new(
410                provisional_group_secrets,
411                provisional_message_secrets,
412                staged_diff,
413                new_keypairs,
414                new_leaf_keypair_option,
415                update_path_leaf_node,
416                #[cfg(feature = "extensions-draft-08")]
417                application_export_tree,
418            )));
419        let staged_commit = StagedCommit::new(proposal_queue, staged_commit_state);
420
421        Ok(staged_commit)
422    }
423
424    /// Merges a [StagedCommit] into the group state and optionally return a [`SecretTree`]
425    /// from the previous epoch. The secret tree is returned if the Commit does not contain a self removal.
426    ///
427    /// This function should not fail and only returns a [`Result`], because it
428    /// might throw a `LibraryError`.
429    pub(crate) fn merge_commit<Provider: OpenMlsProvider>(
430        &mut self,
431        provider: &Provider,
432        staged_commit: StagedCommit,
433    ) -> Result<(), MergeCommitError<Provider::StorageError>> {
434        // Get all keypairs from the old epoch, so we can later store the ones
435        // that are still relevant in the new epoch.
436        let old_epoch_keypairs = self
437            .read_epoch_keypairs(provider.storage())
438            .map_err(MergeCommitError::StorageError)?;
439        match staged_commit.state {
440            StagedCommitState::PublicState(staged_state) => {
441                self.public_group
442                    .merge_diff(staged_state.into_staged_diff());
443                self.store(provider.storage())
444                    .map_err(MergeCommitError::StorageError)?;
445                Ok(())
446            }
447            StagedCommitState::GroupMember(state) => {
448                // Save the past epoch
449                let past_epoch = self.context().epoch();
450                // Get all the full leaves
451                let leaves = self.public_group().members().collect();
452                // Merge the staged commit into the group state and store the secret tree from the
453                // previous epoch in the message secrets store.
454                self.group_epoch_secrets = state.group_epoch_secrets;
455
456                // Replace the previous message secrets with the new ones and return the previous message secrets
457                let mut message_secrets = state.message_secrets;
458                mem::swap(
459                    &mut message_secrets,
460                    self.message_secrets_store.message_secrets_mut(),
461                );
462                self.message_secrets_store
463                    .add(past_epoch, message_secrets, leaves);
464
465                // Replace the previous exporter tree with the new one.
466                #[cfg(feature = "extensions-draft-08")]
467                {
468                    // The application exporter is only None if the group was
469                    // stored using an older version of OpenMLS that did not
470                    // support the application exporter.
471                    if let Some(application_export_tree) = state.application_export_tree {
472                        // Overwrite the existing exporter tree in the storage.
473
474                        use openmls_traits::storage::StorageProvider as _;
475                        provider
476                            .storage()
477                            .write_application_export_tree(
478                                self.group_id(),
479                                &application_export_tree,
480                            )
481                            .map_err(MergeCommitError::StorageError)?;
482
483                        self.application_export_tree = Some(application_export_tree);
484                    }
485                }
486
487                self.public_group.merge_diff(state.staged_diff);
488
489                let leaf_keypair = if let Some(keypair) = &state.new_leaf_keypair_option {
490                    vec![keypair.clone()]
491                } else {
492                    vec![]
493                };
494
495                // Figure out which keys we need in the new epoch.
496                let new_owned_encryption_keys = self
497                    .public_group()
498                    .owned_encryption_keys(self.own_leaf_index());
499                // From the old and new keys, keep the ones that are still relevant in the new epoch.
500                let epoch_keypairs: Vec<EncryptionKeyPair> = old_epoch_keypairs
501                    .into_iter()
502                    .chain(state.new_keypairs)
503                    .chain(leaf_keypair)
504                    .filter(|keypair| new_owned_encryption_keys.contains(keypair.public_key()))
505                    .collect();
506
507                // We should have private keys for all owned encryption keys.
508                debug_assert_eq!(new_owned_encryption_keys.len(), epoch_keypairs.len());
509                if new_owned_encryption_keys.len() != epoch_keypairs.len() {
510                    return Err(LibraryError::custom(
511                        "We should have all the private key material we need.",
512                    )
513                    .into());
514                }
515
516                // Store the updated group state
517                let storage = provider.storage();
518                let group_id = self.group_id();
519
520                self.public_group
521                    .store(storage)
522                    .map_err(MergeCommitError::StorageError)?;
523                storage
524                    .write_group_epoch_secrets(group_id, &self.group_epoch_secrets)
525                    .map_err(MergeCommitError::StorageError)?;
526                storage
527                    .write_message_secrets(group_id, &self.message_secrets_store)
528                    .map_err(MergeCommitError::StorageError)?;
529
530                // Store the relevant keys under the new epoch
531                self.store_epoch_keypairs(storage, epoch_keypairs.as_slice())
532                    .map_err(MergeCommitError::StorageError)?;
533
534                // Delete the old keys.
535                self.delete_previous_epoch_keypairs(storage)
536                    .map_err(MergeCommitError::StorageError)?;
537                if let Some(keypair) = state.new_leaf_keypair_option {
538                    keypair
539                        .delete(storage)
540                        .map_err(MergeCommitError::StorageError)?;
541                }
542
543                // Empty the proposal store
544                storage
545                    .clear_proposal_queue::<GroupId, ProposalRef>(group_id)
546                    .map_err(MergeCommitError::StorageError)?;
547                self.proposal_store_mut().empty();
548
549                Ok(())
550            }
551        }
552    }
553}
554
555#[derive(Debug, Serialize, Deserialize)]
556#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
557pub(crate) enum StagedCommitState {
558    PublicState(Box<PublicStagedCommitState>),
559    /// The group member variant of the staged commit state.
560    GroupMember(Box<MemberStagedCommitState>),
561}
562
563/// Contains the changes from a commit to the group state.
564#[derive(Debug, Serialize, Deserialize)]
565#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
566pub struct StagedCommit {
567    /// A queue containing the proposals associated with the commit.
568    pub staged_proposal_queue: ProposalQueue,
569    /// The staged commit state.
570    pub(super) state: StagedCommitState,
571}
572
573impl StagedCommit {
574    /// Create a new [`StagedCommit`] from the provisional group state created
575    /// during the commit process.
576    pub(crate) fn new(staged_proposal_queue: ProposalQueue, state: StagedCommitState) -> Self {
577        StagedCommit {
578            staged_proposal_queue,
579            state,
580        }
581    }
582
583    /// Returns the Add proposals that are covered by the Commit message as in iterator over [QueuedAddProposal].
584    pub fn add_proposals(&self) -> impl Iterator<Item = QueuedAddProposal<'_>> {
585        self.staged_proposal_queue.add_proposals()
586    }
587
588    /// Returns the Remove proposals that are covered by the Commit message as in iterator over [QueuedRemoveProposal].
589    pub fn remove_proposals(&self) -> impl Iterator<Item = QueuedRemoveProposal<'_>> {
590        self.staged_proposal_queue.remove_proposals()
591    }
592
593    /// Returns the Update proposals that are covered by the Commit message as in iterator over [QueuedUpdateProposal].
594    pub fn update_proposals(&self) -> impl Iterator<Item = QueuedUpdateProposal<'_>> {
595        self.staged_proposal_queue.update_proposals()
596    }
597
598    /// Returns the PresharedKey proposals that are covered by the Commit message as in iterator over [QueuedPskProposal].
599    pub fn psk_proposals(&self) -> impl Iterator<Item = QueuedPskProposal<'_>> {
600        self.staged_proposal_queue.psk_proposals()
601    }
602
603    #[cfg(feature = "extensions-draft-08")]
604    /// Returns the AppEphemeral proposals that are covered by the Commit message as an iterator
605    /// over [`QueuedAppEphemeralProposal`].
606    pub fn queued_app_ephemeral_proposals(
607        &self,
608    ) -> impl Iterator<Item = QueuedAppEphemeralProposal<'_>> {
609        self.staged_proposal_queue.app_ephemeral_proposals()
610    }
611    // NOTE: this is not a default proposal type
612    #[cfg(feature = "extensions-draft-08")]
613    /// Returns the AppDataUpdate proposals that are covered by the Commit message as an iterator
614    /// over [`QueuedAppDataUpdateProposal`].
615    pub fn app_data_update_proposals(
616        &self,
617    ) -> impl Iterator<Item = QueuedAppDataUpdateProposal<'_>> {
618        self.staged_proposal_queue.app_data_update_proposals()
619    }
620
621    /// Returns an iterator over all [`QueuedProposal`]s.
622    pub fn queued_proposals(&self) -> impl Iterator<Item = &QueuedProposal> {
623        self.staged_proposal_queue.queued_proposals()
624    }
625
626    /// Returns the leaf node of the (optional) update path.
627    pub fn update_path_leaf_node(&self) -> Option<&LeafNode> {
628        match self.state {
629            StagedCommitState::PublicState(ref public_state) => {
630                public_state.update_path_leaf_node()
631            }
632            StagedCommitState::GroupMember(ref group_member_state) => {
633                group_member_state.update_path_leaf_node.as_ref()
634            }
635        }
636    }
637
638    /// Returns the credentials that the caller needs to verify are valid.
639    pub fn credentials_to_verify(&self) -> impl Iterator<Item = &Credential> {
640        let update_path_leaf_node_cred = if let Some(node) = self.update_path_leaf_node() {
641            vec![node.credential()]
642        } else {
643            vec![]
644        };
645
646        update_path_leaf_node_cred
647            .into_iter()
648            .chain(
649                self.queued_proposals()
650                    .flat_map(|proposal: &QueuedProposal| match proposal.proposal() {
651                        Proposal::Update(update_proposal) => {
652                            vec![update_proposal.leaf_node().credential()].into_iter()
653                        }
654                        Proposal::Add(add_proposal) => {
655                            vec![add_proposal.key_package().leaf_node().credential()].into_iter()
656                        }
657                        Proposal::GroupContextExtensions(gce_proposal) => gce_proposal
658                            .extensions()
659                            .iter()
660                            .flat_map(|extension| {
661                                match extension {
662                                    Extension::ExternalSenders(external_senders) => {
663                                        external_senders
664                                            .iter()
665                                            .map(|external_sender| external_sender.credential())
666                                            .collect()
667                                    }
668                                    _ => vec![],
669                                }
670                                .into_iter()
671                            })
672                            // TODO: ideally we wouldn't collect in between here, but the match arms
673                            //       have to all return the same type. We solve this by having them all
674                            //       be vec::IntoIter, but it would be nice if we just didn't have to
675                            //       do this.
676                            //       It might be possible to solve this by letting all match arms
677                            //       evaluate to a dyn Iterator.
678                            .collect::<Vec<_>>()
679                            .into_iter(),
680                        _ => vec![].into_iter(),
681                    }),
682            )
683    }
684
685    /// Returns `true` if the member was removed through a proposal covered by this Commit message
686    /// and `false` otherwise.
687    pub fn self_removed(&self) -> bool {
688        matches!(self.state, StagedCommitState::PublicState(_))
689    }
690
691    /// Returns the [`GroupContext`] of the staged commit state.
692    pub fn group_context(&self) -> &GroupContext {
693        match self.state {
694            StagedCommitState::PublicState(ref ps) => ps.staged_diff().group_context(),
695            StagedCommitState::GroupMember(ref gm) => gm.group_context(),
696        }
697    }
698    /// Consume this [`StagedCommit`] and return the internal [`StagedCommitState`].
699    pub(crate) fn into_state(self) -> StagedCommitState {
700        self.state
701    }
702
703    /// Returns the [`EpochAuthenticator`] of the staged commit state if the
704    /// owner of the originating group state is a member of the group. Returns
705    /// `None` otherwise.
706    pub fn epoch_authenticator(&self) -> Option<&EpochAuthenticator> {
707        if let StagedCommitState::GroupMember(ref gm) = self.state {
708            Some(gm.group_epoch_secrets.epoch_authenticator())
709        } else {
710            None
711        }
712    }
713
714    #[cfg(feature = "extensions-draft-08")]
715    pub(crate) fn safe_export_secret(
716        &mut self,
717        crypto: &impl OpenMlsCrypto,
718        component_id: u16,
719    ) -> Result<Vec<u8>, StagedSafeExportSecretError> {
720        let ciphersuite = self.group_context().ciphersuite();
721        let StagedCommitState::GroupMember(ref mut staged_commit) = self.state else {
722            return Err(StagedSafeExportSecretError::NotGroupMember);
723        };
724        let Some(application_export_tree) = staged_commit.application_export_tree.as_mut() else {
725            return Err(StagedSafeExportSecretError::Unsupported);
726        };
727        let secret =
728            application_export_tree.safe_export_secret(crypto, ciphersuite, component_id)?;
729        Ok(secret.as_slice().to_vec())
730    }
731}
732
733/// This struct is used internally by [`StagedCommit`] to encapsulate all the modified group state.
734#[derive(Debug, Serialize, Deserialize)]
735#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
736pub(crate) struct MemberStagedCommitState {
737    group_epoch_secrets: GroupEpochSecrets,
738    message_secrets: MessageSecrets,
739    staged_diff: StagedPublicGroupDiff,
740    new_keypairs: Vec<EncryptionKeyPair>,
741    new_leaf_keypair_option: Option<EncryptionKeyPair>,
742    update_path_leaf_node: Option<LeafNode>,
743    #[cfg(feature = "extensions-draft-08")]
744    #[serde(default)]
745    // This is `None` only if the group was stored using an older version of
746    // OpenMLS that did not support the application exporter.
747    application_export_tree: Option<ApplicationExportTree>,
748}
749
750impl MemberStagedCommitState {
751    pub(crate) fn new(
752        group_epoch_secrets: GroupEpochSecrets,
753        message_secrets: MessageSecrets,
754        staged_diff: StagedPublicGroupDiff,
755        new_keypairs: Vec<EncryptionKeyPair>,
756        new_leaf_keypair_option: Option<EncryptionKeyPair>,
757        update_path_leaf_node: Option<LeafNode>,
758        #[cfg(feature = "extensions-draft-08")] application_export_tree: ApplicationExportTree,
759    ) -> Self {
760        Self {
761            group_epoch_secrets,
762            message_secrets,
763            staged_diff,
764            new_keypairs,
765            new_leaf_keypair_option,
766            update_path_leaf_node,
767            #[cfg(feature = "extensions-draft-08")]
768            application_export_tree: Some(application_export_tree),
769        }
770    }
771
772    /// Get the staged [`GroupContext`].
773    pub(crate) fn group_context(&self) -> &GroupContext {
774        self.staged_diff.group_context()
775    }
776}