openmls/group/mls_group/staged_commit.rs

1use core::fmt::Debug;
2use std::mem;
3
4use openmls_traits::storage::StorageProvider;
5use serde::{Deserialize, Serialize};
6use tls_codec::Serialize as _;
7
8use super::proposal_store::{
9    QueuedAddProposal, QueuedPskProposal, QueuedRemoveProposal, QueuedUpdateProposal,
10};
11
12use super::{
13    super::errors::*, load_psks, Credential, Extension, GroupContext, GroupEpochSecrets, GroupId,
14    JoinerSecret, KeySchedule, LeafNode, LibraryError, MessageSecrets, MlsGroup, OpenMlsProvider,
15    Proposal, ProposalQueue, PskSecret, QueuedProposal, Sender,
16};
17use crate::{
18    ciphersuite::{hash_ref::ProposalRef, Secret},
19    framing::mls_auth_content::AuthenticatedContent,
20    group::public_group::{
21        diff::{apply_proposals::ApplyProposalsValues, StagedPublicGroupDiff},
22        staged_commit::PublicStagedCommitState,
23    },
24    schedule::{CommitSecret, EpochAuthenticator, EpochSecrets, InitSecret, PreSharedKeyId},
25    treesync::node::encryption_keys::EncryptionKeyPair,
26};
27
28impl MlsGroup {
29    fn derive_epoch_secrets(
30        &self,
31        provider: &impl OpenMlsProvider,
32        apply_proposals_values: ApplyProposalsValues,
33        epoch_secrets: &GroupEpochSecrets,
34        commit_secret: CommitSecret,
35        serialized_provisional_group_context: &[u8],
36    ) -> Result<EpochSecrets, StageCommitError> {
37        // Check if we need to include the init secret from an external commit
38        // we applied earlier or if we use the one from the previous epoch.
39        let joiner_secret = if let Some(ref external_init_proposal) =
40            apply_proposals_values.external_init_proposal_option
41        {
42            // Decrypt the content and derive the external init secret.
43            let external_priv = epoch_secrets
44                .external_secret()
45                .derive_external_keypair(provider.crypto(), self.ciphersuite())
46                .map_err(LibraryError::unexpected_crypto_error)?
47                .private;
48            let init_secret = InitSecret::from_kem_output(
49                provider.crypto(),
50                self.ciphersuite(),
51                self.version(),
52                &external_priv,
53                external_init_proposal.kem_output(),
54            )?;
55            JoinerSecret::new(
56                provider.crypto(),
57                self.ciphersuite(),
58                commit_secret,
59                &init_secret,
60                serialized_provisional_group_context,
61            )
62            .map_err(LibraryError::unexpected_crypto_error)?
63        } else {
64            JoinerSecret::new(
65                provider.crypto(),
66                self.ciphersuite(),
67                commit_secret,
68                epoch_secrets.init_secret(),
69                serialized_provisional_group_context,
70            )
71            .map_err(LibraryError::unexpected_crypto_error)?
72        };
73
74        // Prepare the PskSecret
75        // Fails if PSKs are missing ([valn1205](https://validation.openmls.tech/#valn1205))
76        let psk_secret = {
77            let psks: Vec<(&PreSharedKeyId, Secret)> = load_psks(
78                provider.storage(),
79                &self.resumption_psk_store,
80                &apply_proposals_values.presharedkeys,
81            )?;
82
83            PskSecret::new(provider.crypto(), self.ciphersuite(), psks)?
84        };
85
86        // Create key schedule
87        let mut key_schedule = KeySchedule::init(
88            self.ciphersuite(),
89            provider.crypto(),
90            &joiner_secret,
91            psk_secret,
92        )?;
93
94        key_schedule
95            .add_context(provider.crypto(), serialized_provisional_group_context)
96            .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?;
97        Ok(key_schedule
98            .epoch_secrets(provider.crypto(), self.ciphersuite())
99            .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?)
100    }
101
102    /// Stages a commit message that was sent by another group member. This
103    /// function does the following:
104    ///  - Applies the proposals covered by the commit to the tree
105    ///  - Applies the (optional) update path to the tree
106    ///  - Decrypts and calculates the path secrets
107    ///  - Initializes the key schedule for epoch rollover
108    ///  - Verifies the confirmation tag
109    ///
110    /// Returns a [StagedCommit] that can be inspected and later merged into the
111    /// group state with [MlsGroup::merge_commit()] This function does the
112    /// following checks:
113    ///  - ValSem101
114    ///  - ValSem102
115    ///  - ValSem104
116    ///  - ValSem105
117    ///  - ValSem106
118    ///  - ValSem107
119    ///  - ValSem108
120    ///  - ValSem110
121    ///  - ValSem111
122    ///  - ValSem112
123    ///  - ValSem113: All Proposals: The proposal type must be supported by all
124    ///    members of the group
125    ///  - ValSem200
126    ///  - ValSem201
127    ///  - ValSem202: Path must be the right length
128    ///  - ValSem203: Path secrets must decrypt correctly
129    ///  - ValSem204: Public keys from Path must be verified and match the
130    ///    private keys from the direct path
131    ///  - ValSem205
132    ///  - ValSem240
133    ///  - ValSem241
134    ///  - ValSem242
135    ///  - ValSem244 Returns an error if the given commit was sent by the owner
136    ///    of this group.
137    pub(crate) fn stage_commit(
138        &self,
139        mls_content: &AuthenticatedContent,
140        old_epoch_keypairs: Vec<EncryptionKeyPair>,
141        leaf_node_keypairs: Vec<EncryptionKeyPair>,
142        provider: &impl OpenMlsProvider,
143    ) -> Result<StagedCommit, StageCommitError> {
144        // Check that the sender is another member of the group
145        if let Sender::Member(member) = mls_content.sender() {
146            if member == &self.own_leaf_index() {
147                return Err(StageCommitError::OwnCommit);
148            }
149        }
150
151        let ciphersuite = self.ciphersuite();
152
153        let (commit, proposal_queue, sender_index) = self
154            .public_group
155            .validate_commit(mls_content, provider.crypto())?;
156
157        // Create the provisional public group state (including the tree and
158        // group context) and apply proposals.
159        let mut diff = self.public_group.empty_diff();
160
161        let apply_proposals_values =
162            diff.apply_proposals(&proposal_queue, self.own_leaf_index())?;
163
164        // Determine if Commit has a path
165        let (commit_secret, new_keypairs, new_leaf_keypair_option, update_path_leaf_node) =
166            if let Some(path) = commit.path.clone() {
167                // Update the public group
168                // ValSem202: Path must be the right length
169                diff.apply_received_update_path(
170                    provider.crypto(),
171                    ciphersuite,
172                    sender_index,
173                    &path,
174                )?;
175
176                // Update group context
177                diff.update_group_context(
178                    provider.crypto(),
179                    apply_proposals_values.extensions.clone(),
180                )?;
181
182                // Check if we were removed from the group
183                if apply_proposals_values.self_removed {
184                    // If so, we return here, because we can't decrypt the path
185                    let staged_diff = diff.into_staged_diff(provider.crypto(), ciphersuite)?;
186                    let staged_state = PublicStagedCommitState::new(
187                        staged_diff,
188                        commit.path.as_ref().map(|path| path.leaf_node().clone()),
189                    );
190                    return Ok(StagedCommit::new(
191                        proposal_queue,
192                        StagedCommitState::PublicState(Box::new(staged_state)),
193                    ));
194                }
195
196                let decryption_keypairs: Vec<&EncryptionKeyPair> = old_epoch_keypairs
197                    .iter()
198                    .chain(leaf_node_keypairs.iter())
199                    .collect();
200
201                // ValSem203: Path secrets must decrypt correctly
202                // ValSem204: Public keys from Path must be verified and match the private keys from the direct path
203                let (new_keypairs, commit_secret) = diff.decrypt_path(
204                    provider.crypto(),
205                    &decryption_keypairs,
206                    self.own_leaf_index(),
207                    sender_index,
208                    path.nodes(),
209                    &apply_proposals_values.exclusion_list(),
210                )?;
211
212                // Check if one of our update proposals was applied. If so, we
213                // need to store that keypair separately, because after merging
214                // it needs to be removed from the key store separately and in
215                // addition to the removal of the keypairs of the previous
216                // epoch.
217                let new_leaf_keypair_option = if let Some(leaf) = diff.leaf(self.own_leaf_index()) {
218                    leaf_node_keypairs.into_iter().find_map(|keypair| {
219                        if leaf.encryption_key() == keypair.public_key() {
220                            Some(keypair)
221                        } else {
222                            None
223                        }
224                    })
225                } else {
226                    // We should have an own leaf at this point.
227                    debug_assert!(false);
228                    None
229                };
230
231                // Return the leaf node in the update path so the credential can be validated.
232                // Since the diff has already been updated, this should be the same as the leaf
233                // at the sender index.
234                let update_path_leaf_node = Some(path.leaf_node().clone());
235                debug_assert_eq!(diff.leaf(sender_index), path.leaf_node().into());
236
237                (
238                    commit_secret,
239                    new_keypairs,
240                    new_leaf_keypair_option,
241                    update_path_leaf_node,
242                )
243            } else {
244                if apply_proposals_values.path_required {
245                    // ValSem201
246                    return Err(StageCommitError::RequiredPathNotFound);
247                }
248
249                // Even if there is no path, we have to update the group context.
250                diff.update_group_context(
251                    provider.crypto(),
252                    apply_proposals_values.extensions.clone(),
253                )?;
254
255                (CommitSecret::zero_secret(ciphersuite), vec![], None, None)
256            };
257
258        // Update the confirmed transcript hash before we compute the confirmation tag.
259        diff.update_confirmed_transcript_hash(provider.crypto(), mls_content)?;
260
261        let received_confirmation_tag = mls_content
262            .confirmation_tag()
263            .ok_or(StageCommitError::ConfirmationTagMissing)?;
264
265        let serialized_provisional_group_context = diff
266            .group_context()
267            .tls_serialize_detached()
268            .map_err(LibraryError::missing_bound_check)?;
269
270        let (provisional_group_secrets, provisional_message_secrets) = self
271            .derive_epoch_secrets(
272                provider,
273                apply_proposals_values,
274                self.group_epoch_secrets(),
275                commit_secret,
276                &serialized_provisional_group_context,
277            )?
278            .split_secrets(
279                serialized_provisional_group_context,
280                diff.tree_size(),
281                self.own_leaf_index(),
282            );
283
284        // Verify confirmation tag
285        // ValSem205
286        let own_confirmation_tag = provisional_message_secrets
287            .confirmation_key()
288            .tag(
289                provider.crypto(),
290                self.ciphersuite(),
291                diff.group_context().confirmed_transcript_hash(),
292            )
293            .map_err(LibraryError::unexpected_crypto_error)?;
294        if &own_confirmation_tag != received_confirmation_tag {
295            log::error!("Confirmation tag mismatch");
296            log_crypto!(trace, "  Got:      {:x?}", received_confirmation_tag);
297            log_crypto!(trace, "  Expected: {:x?}", own_confirmation_tag);
298            // TODO: We have tests expecting this error.
299            //       They need to be rewritten.
300            // debug_assert!(false, "Confirmation tag mismatch");
301
302            // in some tests we need to be able to proceed despite the tag being wrong,
303            // e.g. to test whether a later validation check is performed correctly.
304            if !crate::skip_validation::is_disabled::confirmation_tag() {
305                return Err(StageCommitError::ConfirmationTagMismatch);
306            }
307        }
308
309        diff.update_interim_transcript_hash(ciphersuite, provider.crypto(), own_confirmation_tag)?;
310
311        let staged_diff = diff.into_staged_diff(provider.crypto(), ciphersuite)?;
312        let staged_commit_state =
313            StagedCommitState::GroupMember(Box::new(MemberStagedCommitState::new(
314                provisional_group_secrets,
315                provisional_message_secrets,
316                staged_diff,
317                new_keypairs,
318                new_leaf_keypair_option,
319                update_path_leaf_node,
320            )));
321
322        Ok(StagedCommit::new(proposal_queue, staged_commit_state))
323    }
324
325    /// Merges a [StagedCommit] into the group state and optionally return a [`SecretTree`]
326    /// from the previous epoch. The secret tree is returned if the Commit does not contain a self removal.
327    ///
328    /// This function should not fail and only returns a [`Result`], because it
329    /// might throw a `LibraryError`.
330    pub(crate) fn merge_commit<Provider: OpenMlsProvider>(
331        &mut self,
332        provider: &Provider,
333        staged_commit: StagedCommit,
334    ) -> Result<(), MergeCommitError<Provider::StorageError>> {
335        // Get all keypairs from the old epoch, so we can later store the ones
336        // that are still relevant in the new epoch.
337        let old_epoch_keypairs = self
338            .read_epoch_keypairs(provider.storage())
339            .map_err(MergeCommitError::StorageError)?;
340        match staged_commit.state {
341            StagedCommitState::PublicState(staged_state) => {
342                self.public_group
343                    .merge_diff(staged_state.into_staged_diff());
344                self.store(provider.storage())
345                    .map_err(MergeCommitError::StorageError)?;
346                Ok(())
347            }
348            StagedCommitState::GroupMember(state) => {
349                // Save the past epoch
350                let past_epoch = self.context().epoch();
351                // Get all the full leaves
352                let leaves = self.public_group().members().collect();
353                // Merge the staged commit into the group state and store the secret tree from the
354                // previous epoch in the message secrets store.
355                self.group_epoch_secrets = state.group_epoch_secrets;
356
357                // Replace the previous message secrets with the new ones and return the previous message secrets
358                let mut message_secrets = state.message_secrets;
359                mem::swap(
360                    &mut message_secrets,
361                    self.message_secrets_store.message_secrets_mut(),
362                );
363                self.message_secrets_store
364                    .add(past_epoch, message_secrets, leaves);
365
366                self.public_group.merge_diff(state.staged_diff);
367
368                let leaf_keypair = if let Some(keypair) = &state.new_leaf_keypair_option {
369                    vec![keypair.clone()]
370                } else {
371                    vec![]
372                };
373
374                // Figure out which keys we need in the new epoch.
375                let new_owned_encryption_keys = self
376                    .public_group()
377                    .owned_encryption_keys(self.own_leaf_index());
378                // From the old and new keys, keep the ones that are still relevant in the new epoch.
379                let epoch_keypairs: Vec<EncryptionKeyPair> = old_epoch_keypairs
380                    .into_iter()
381                    .chain(state.new_keypairs)
382                    .chain(leaf_keypair)
383                    .filter(|keypair| new_owned_encryption_keys.contains(keypair.public_key()))
384                    .collect();
385
386                // We should have private keys for all owned encryption keys.
387                debug_assert_eq!(new_owned_encryption_keys.len(), epoch_keypairs.len());
388                if new_owned_encryption_keys.len() != epoch_keypairs.len() {
389                    return Err(LibraryError::custom(
390                        "We should have all the private key material we need.",
391                    )
392                    .into());
393                }
394
395                // Store the updated group state
396                let storage = provider.storage();
397                let group_id = self.group_id();
398
399                self.public_group
400                    .store(storage)
401                    .map_err(MergeCommitError::StorageError)?;
402                storage
403                    .write_group_epoch_secrets(group_id, &self.group_epoch_secrets)
404                    .map_err(MergeCommitError::StorageError)?;
405                storage
406                    .write_message_secrets(group_id, &self.message_secrets_store)
407                    .map_err(MergeCommitError::StorageError)?;
408
409                // Store the relevant keys under the new epoch
410                self.store_epoch_keypairs(storage, epoch_keypairs.as_slice())
411                    .map_err(MergeCommitError::StorageError)?;
412
413                // Delete the old keys.
414                self.delete_previous_epoch_keypairs(storage)
415                    .map_err(MergeCommitError::StorageError)?;
416                if let Some(keypair) = state.new_leaf_keypair_option {
417                    keypair
418                        .delete(storage)
419                        .map_err(MergeCommitError::StorageError)?;
420                }
421
422                // Empty the proposal store
423                storage
424                    .clear_proposal_queue::<GroupId, ProposalRef>(group_id)
425                    .map_err(MergeCommitError::StorageError)?;
426                self.proposal_store_mut().empty();
427
428                Ok(())
429            }
430        }
431    }
432}
433
/// The staged state of a commit, distinguished by whether the member that
/// staged it is still part of the group afterwards.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
pub(crate) enum StagedCommitState {
    /// The commit removed this member; only the staged public group state
    /// (and the update path leaf node, if any) is available.
    PublicState(Box<PublicStagedCommitState>),
    /// This member remains in the group; the state also carries the
    /// provisional secrets and key material of the new epoch.
    GroupMember(Box<MemberStagedCommitState>),
}
440
441/// Contains the changes from a commit to the group state.
442#[derive(Debug, Serialize, Deserialize)]
443#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
444pub struct StagedCommit {
445    staged_proposal_queue: ProposalQueue,
446    state: StagedCommitState,
447}
448
449impl StagedCommit {
450    /// Create a new [`StagedCommit`] from the provisional group state created
451    /// during the commit process.
452    pub(crate) fn new(staged_proposal_queue: ProposalQueue, state: StagedCommitState) -> Self {
453        StagedCommit {
454            staged_proposal_queue,
455            state,
456        }
457    }
458
459    /// Returns the Add proposals that are covered by the Commit message as in iterator over [QueuedAddProposal].
460    pub fn add_proposals(&self) -> impl Iterator<Item = QueuedAddProposal> {
461        self.staged_proposal_queue.add_proposals()
462    }
463
464    /// Returns the Remove proposals that are covered by the Commit message as in iterator over [QueuedRemoveProposal].
465    pub fn remove_proposals(&self) -> impl Iterator<Item = QueuedRemoveProposal> {
466        self.staged_proposal_queue.remove_proposals()
467    }
468
469    /// Returns the Update proposals that are covered by the Commit message as in iterator over [QueuedUpdateProposal].
470    pub fn update_proposals(&self) -> impl Iterator<Item = QueuedUpdateProposal> {
471        self.staged_proposal_queue.update_proposals()
472    }
473
474    /// Returns the PresharedKey proposals that are covered by the Commit message as in iterator over [QueuedPskProposal].
475    pub fn psk_proposals(&self) -> impl Iterator<Item = QueuedPskProposal> {
476        self.staged_proposal_queue.psk_proposals()
477    }
478
479    /// Returns an iterator over all [`QueuedProposal`]s.
480    pub fn queued_proposals(&self) -> impl Iterator<Item = &QueuedProposal> {
481        self.staged_proposal_queue.queued_proposals()
482    }
483
484    /// Returns the leaf node of the (optional) update path.
485    pub fn update_path_leaf_node(&self) -> Option<&LeafNode> {
486        match self.state {
487            StagedCommitState::PublicState(ref public_state) => {
488                public_state.update_path_leaf_node()
489            }
490            StagedCommitState::GroupMember(ref group_member_state) => {
491                group_member_state.update_path_leaf_node.as_ref()
492            }
493        }
494    }
495
496    /// Returns the credentials that the caller needs to verify are valid.
497    pub fn credentials_to_verify(&self) -> impl Iterator<Item = &Credential> {
498        let update_path_leaf_node_cred = if let Some(node) = self.update_path_leaf_node() {
499            vec![node.credential()]
500        } else {
501            vec![]
502        };
503
504        update_path_leaf_node_cred
505            .into_iter()
506            .chain(
507                self.queued_proposals()
508                    .flat_map(|proposal: &QueuedProposal| match proposal.proposal() {
509                        Proposal::Update(update_proposal) => {
510                            vec![update_proposal.leaf_node().credential()].into_iter()
511                        }
512                        Proposal::Add(add_proposal) => {
513                            vec![add_proposal.key_package().leaf_node().credential()].into_iter()
514                        }
515                        Proposal::GroupContextExtensions(gce_proposal) => gce_proposal
516                            .extensions()
517                            .iter()
518                            .flat_map(|extension| {
519                                match extension {
520                                    Extension::ExternalSenders(external_senders) => {
521                                        external_senders
522                                            .iter()
523                                            .map(|external_sender| external_sender.credential())
524                                            .collect()
525                                    }
526                                    _ => vec![],
527                                }
528                                .into_iter()
529                            })
530                            // TODO: ideally we wouldn't collect in between here, but the match arms
531                            //       have to all return the same type. We solve this by having them all
532                            //       be vec::IntoIter, but it would be nice if we just didn't have to
533                            //       do this.
534                            //       It might be possible to solve this by letting all match arms
535                            //       evaluate to a dyn Iterator.
536                            .collect::<Vec<_>>()
537                            .into_iter(),
538                        _ => vec![].into_iter(),
539                    }),
540            )
541    }
542
543    /// Returns `true` if the member was removed through a proposal covered by this Commit message
544    /// and `false` otherwise.
545    pub fn self_removed(&self) -> bool {
546        matches!(self.state, StagedCommitState::PublicState(_))
547    }
548
549    /// Returns the [`GroupContext`] of the staged commit state.
550    pub fn group_context(&self) -> &GroupContext {
551        match self.state {
552            StagedCommitState::PublicState(ref ps) => ps.staged_diff().group_context(),
553            StagedCommitState::GroupMember(ref gm) => gm.group_context(),
554        }
555    }
556
557    /// Consume this [`StagedCommit`] and return the internal [`StagedCommitState`].
558    pub(crate) fn into_state(self) -> StagedCommitState {
559        self.state
560    }
561
562    /// Returns the [`EpochAuthenticator`] of the staged commit state if the
563    /// owner of the originating group state is a member of the group. Returns
564    /// `None` otherwise.
565    pub fn epoch_authenticator(&self) -> Option<&EpochAuthenticator> {
566        if let StagedCommitState::GroupMember(ref gm) = self.state {
567            Some(gm.group_epoch_secrets.epoch_authenticator())
568        } else {
569            None
570        }
571    }
572}
573
574/// This struct is used internally by [StagedCommit] to encapsulate all the modified group state.
575#[derive(Debug, Serialize, Deserialize)]
576#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
577pub(crate) struct MemberStagedCommitState {
578    group_epoch_secrets: GroupEpochSecrets,
579    message_secrets: MessageSecrets,
580    staged_diff: StagedPublicGroupDiff,
581    new_keypairs: Vec<EncryptionKeyPair>,
582    new_leaf_keypair_option: Option<EncryptionKeyPair>,
583    update_path_leaf_node: Option<LeafNode>,
584}
585
586impl MemberStagedCommitState {
587    pub(crate) fn new(
588        group_epoch_secrets: GroupEpochSecrets,
589        message_secrets: MessageSecrets,
590        staged_diff: StagedPublicGroupDiff,
591        new_keypairs: Vec<EncryptionKeyPair>,
592        new_leaf_keypair_option: Option<EncryptionKeyPair>,
593        update_path_leaf_node: Option<LeafNode>,
594    ) -> Self {
595        Self {
596            group_epoch_secrets,
597            message_secrets,
598            staged_diff,
599            new_keypairs,
600            new_leaf_keypair_option,
601            update_path_leaf_node,
602        }
603    }
604
605    /// Get the staged [`GroupContext`].
606    pub(crate) fn group_context(&self) -> &GroupContext {
607        self.staged_diff.group_context()
608    }
609}