1use core::fmt::Debug;
2use std::mem;
3
4use openmls_traits::storage::StorageProvider;
5use serde::{Deserialize, Serialize};
6use tls_codec::Serialize as _;
7
8use super::proposal_store::{
9 QueuedAddProposal, QueuedPskProposal, QueuedRemoveProposal, QueuedUpdateProposal,
10};
11
12use super::{
13 super::errors::*, load_psks, Credential, Extension, GroupContext, GroupEpochSecrets, GroupId,
14 JoinerSecret, KeySchedule, LeafNode, LibraryError, MessageSecrets, MlsGroup, OpenMlsProvider,
15 Proposal, ProposalQueue, PskSecret, QueuedProposal, Sender,
16};
17use crate::{
18 ciphersuite::{hash_ref::ProposalRef, Secret},
19 framing::mls_auth_content::AuthenticatedContent,
20 group::public_group::{
21 diff::{apply_proposals::ApplyProposalsValues, StagedPublicGroupDiff},
22 staged_commit::PublicStagedCommitState,
23 },
24 schedule::{CommitSecret, EpochAuthenticator, EpochSecrets, InitSecret, PreSharedKeyId},
25 treesync::node::encryption_keys::EncryptionKeyPair,
26};
27
28impl MlsGroup {
29 fn derive_epoch_secrets(
30 &self,
31 provider: &impl OpenMlsProvider,
32 apply_proposals_values: ApplyProposalsValues,
33 epoch_secrets: &GroupEpochSecrets,
34 commit_secret: CommitSecret,
35 serialized_provisional_group_context: &[u8],
36 ) -> Result<EpochSecrets, StageCommitError> {
37 let joiner_secret = if let Some(ref external_init_proposal) =
40 apply_proposals_values.external_init_proposal_option
41 {
42 let external_priv = epoch_secrets
44 .external_secret()
45 .derive_external_keypair(provider.crypto(), self.ciphersuite())
46 .map_err(LibraryError::unexpected_crypto_error)?
47 .private;
48 let init_secret = InitSecret::from_kem_output(
49 provider.crypto(),
50 self.ciphersuite(),
51 self.version(),
52 &external_priv,
53 external_init_proposal.kem_output(),
54 )?;
55 JoinerSecret::new(
56 provider.crypto(),
57 self.ciphersuite(),
58 commit_secret,
59 &init_secret,
60 serialized_provisional_group_context,
61 )
62 .map_err(LibraryError::unexpected_crypto_error)?
63 } else {
64 JoinerSecret::new(
65 provider.crypto(),
66 self.ciphersuite(),
67 commit_secret,
68 epoch_secrets.init_secret(),
69 serialized_provisional_group_context,
70 )
71 .map_err(LibraryError::unexpected_crypto_error)?
72 };
73
74 let psk_secret = {
77 let psks: Vec<(&PreSharedKeyId, Secret)> = load_psks(
78 provider.storage(),
79 &self.resumption_psk_store,
80 &apply_proposals_values.presharedkeys,
81 )?;
82
83 PskSecret::new(provider.crypto(), self.ciphersuite(), psks)?
84 };
85
86 let mut key_schedule = KeySchedule::init(
88 self.ciphersuite(),
89 provider.crypto(),
90 &joiner_secret,
91 psk_secret,
92 )?;
93
94 key_schedule
95 .add_context(provider.crypto(), serialized_provisional_group_context)
96 .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?;
97 Ok(key_schedule
98 .epoch_secrets(provider.crypto(), self.ciphersuite())
99 .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?)
100 }
101
102 pub(crate) fn stage_commit(
138 &self,
139 mls_content: &AuthenticatedContent,
140 old_epoch_keypairs: Vec<EncryptionKeyPair>,
141 leaf_node_keypairs: Vec<EncryptionKeyPair>,
142 provider: &impl OpenMlsProvider,
143 ) -> Result<StagedCommit, StageCommitError> {
144 if let Sender::Member(member) = mls_content.sender() {
146 if member == &self.own_leaf_index() {
147 return Err(StageCommitError::OwnCommit);
148 }
149 }
150
151 let ciphersuite = self.ciphersuite();
152
153 let (commit, proposal_queue, sender_index) = self
154 .public_group
155 .validate_commit(mls_content, provider.crypto())?;
156
157 let mut diff = self.public_group.empty_diff();
160
161 let apply_proposals_values =
162 diff.apply_proposals(&proposal_queue, self.own_leaf_index())?;
163
164 let (commit_secret, new_keypairs, new_leaf_keypair_option, update_path_leaf_node) =
166 if let Some(path) = commit.path.clone() {
167 diff.apply_received_update_path(
170 provider.crypto(),
171 ciphersuite,
172 sender_index,
173 &path,
174 )?;
175
176 diff.update_group_context(
178 provider.crypto(),
179 apply_proposals_values.extensions.clone(),
180 )?;
181
182 if apply_proposals_values.self_removed {
184 let staged_diff = diff.into_staged_diff(provider.crypto(), ciphersuite)?;
186 let staged_state = PublicStagedCommitState::new(
187 staged_diff,
188 commit.path.as_ref().map(|path| path.leaf_node().clone()),
189 );
190 return Ok(StagedCommit::new(
191 proposal_queue,
192 StagedCommitState::PublicState(Box::new(staged_state)),
193 ));
194 }
195
196 let decryption_keypairs: Vec<&EncryptionKeyPair> = old_epoch_keypairs
197 .iter()
198 .chain(leaf_node_keypairs.iter())
199 .collect();
200
201 let (new_keypairs, commit_secret) = diff.decrypt_path(
204 provider.crypto(),
205 &decryption_keypairs,
206 self.own_leaf_index(),
207 sender_index,
208 path.nodes(),
209 &apply_proposals_values.exclusion_list(),
210 )?;
211
212 let new_leaf_keypair_option = if let Some(leaf) = diff.leaf(self.own_leaf_index()) {
218 leaf_node_keypairs.into_iter().find_map(|keypair| {
219 if leaf.encryption_key() == keypair.public_key() {
220 Some(keypair)
221 } else {
222 None
223 }
224 })
225 } else {
226 debug_assert!(false);
228 None
229 };
230
231 let update_path_leaf_node = Some(path.leaf_node().clone());
235 debug_assert_eq!(diff.leaf(sender_index), path.leaf_node().into());
236
237 (
238 commit_secret,
239 new_keypairs,
240 new_leaf_keypair_option,
241 update_path_leaf_node,
242 )
243 } else {
244 if apply_proposals_values.path_required {
245 return Err(StageCommitError::RequiredPathNotFound);
247 }
248
249 diff.update_group_context(
251 provider.crypto(),
252 apply_proposals_values.extensions.clone(),
253 )?;
254
255 (CommitSecret::zero_secret(ciphersuite), vec![], None, None)
256 };
257
258 diff.update_confirmed_transcript_hash(provider.crypto(), mls_content)?;
260
261 let received_confirmation_tag = mls_content
262 .confirmation_tag()
263 .ok_or(StageCommitError::ConfirmationTagMissing)?;
264
265 let serialized_provisional_group_context = diff
266 .group_context()
267 .tls_serialize_detached()
268 .map_err(LibraryError::missing_bound_check)?;
269
270 let (provisional_group_secrets, provisional_message_secrets) = self
271 .derive_epoch_secrets(
272 provider,
273 apply_proposals_values,
274 self.group_epoch_secrets(),
275 commit_secret,
276 &serialized_provisional_group_context,
277 )?
278 .split_secrets(
279 serialized_provisional_group_context,
280 diff.tree_size(),
281 self.own_leaf_index(),
282 );
283
284 let own_confirmation_tag = provisional_message_secrets
287 .confirmation_key()
288 .tag(
289 provider.crypto(),
290 self.ciphersuite(),
291 diff.group_context().confirmed_transcript_hash(),
292 )
293 .map_err(LibraryError::unexpected_crypto_error)?;
294 if &own_confirmation_tag != received_confirmation_tag {
295 log::error!("Confirmation tag mismatch");
296 log_crypto!(trace, " Got: {:x?}", received_confirmation_tag);
297 log_crypto!(trace, " Expected: {:x?}", own_confirmation_tag);
298 if !crate::skip_validation::is_disabled::confirmation_tag() {
305 return Err(StageCommitError::ConfirmationTagMismatch);
306 }
307 }
308
309 diff.update_interim_transcript_hash(ciphersuite, provider.crypto(), own_confirmation_tag)?;
310
311 let staged_diff = diff.into_staged_diff(provider.crypto(), ciphersuite)?;
312 let staged_commit_state =
313 StagedCommitState::GroupMember(Box::new(MemberStagedCommitState::new(
314 provisional_group_secrets,
315 provisional_message_secrets,
316 staged_diff,
317 new_keypairs,
318 new_leaf_keypair_option,
319 update_path_leaf_node,
320 )));
321
322 Ok(StagedCommit::new(proposal_queue, staged_commit_state))
323 }
324
325 pub(crate) fn merge_commit<Provider: OpenMlsProvider>(
331 &mut self,
332 provider: &Provider,
333 staged_commit: StagedCommit,
334 ) -> Result<(), MergeCommitError<Provider::StorageError>> {
335 let old_epoch_keypairs = self
338 .read_epoch_keypairs(provider.storage())
339 .map_err(MergeCommitError::StorageError)?;
340 match staged_commit.state {
341 StagedCommitState::PublicState(staged_state) => {
342 self.public_group
343 .merge_diff(staged_state.into_staged_diff());
344 self.store(provider.storage())
345 .map_err(MergeCommitError::StorageError)?;
346 Ok(())
347 }
348 StagedCommitState::GroupMember(state) => {
349 let past_epoch = self.context().epoch();
351 let leaves = self.public_group().members().collect();
353 self.group_epoch_secrets = state.group_epoch_secrets;
356
357 let mut message_secrets = state.message_secrets;
359 mem::swap(
360 &mut message_secrets,
361 self.message_secrets_store.message_secrets_mut(),
362 );
363 self.message_secrets_store
364 .add(past_epoch, message_secrets, leaves);
365
366 self.public_group.merge_diff(state.staged_diff);
367
368 let leaf_keypair = if let Some(keypair) = &state.new_leaf_keypair_option {
369 vec![keypair.clone()]
370 } else {
371 vec![]
372 };
373
374 let new_owned_encryption_keys = self
376 .public_group()
377 .owned_encryption_keys(self.own_leaf_index());
378 let epoch_keypairs: Vec<EncryptionKeyPair> = old_epoch_keypairs
380 .into_iter()
381 .chain(state.new_keypairs)
382 .chain(leaf_keypair)
383 .filter(|keypair| new_owned_encryption_keys.contains(keypair.public_key()))
384 .collect();
385
386 debug_assert_eq!(new_owned_encryption_keys.len(), epoch_keypairs.len());
388 if new_owned_encryption_keys.len() != epoch_keypairs.len() {
389 return Err(LibraryError::custom(
390 "We should have all the private key material we need.",
391 )
392 .into());
393 }
394
395 let storage = provider.storage();
397 let group_id = self.group_id();
398
399 self.public_group
400 .store(storage)
401 .map_err(MergeCommitError::StorageError)?;
402 storage
403 .write_group_epoch_secrets(group_id, &self.group_epoch_secrets)
404 .map_err(MergeCommitError::StorageError)?;
405 storage
406 .write_message_secrets(group_id, &self.message_secrets_store)
407 .map_err(MergeCommitError::StorageError)?;
408
409 self.store_epoch_keypairs(storage, epoch_keypairs.as_slice())
411 .map_err(MergeCommitError::StorageError)?;
412
413 self.delete_previous_epoch_keypairs(storage)
415 .map_err(MergeCommitError::StorageError)?;
416 if let Some(keypair) = state.new_leaf_keypair_option {
417 keypair
418 .delete(storage)
419 .map_err(MergeCommitError::StorageError)?;
420 }
421
422 storage
424 .clear_proposal_queue::<GroupId, ProposalRef>(group_id)
425 .map_err(MergeCommitError::StorageError)?;
426 self.proposal_store_mut().empty();
427
428 Ok(())
429 }
430 }
431 }
432}
433
434#[derive(Debug, Serialize, Deserialize)]
435#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
436pub(crate) enum StagedCommitState {
437 PublicState(Box<PublicStagedCommitState>),
438 GroupMember(Box<MemberStagedCommitState>),
439}
440
441#[derive(Debug, Serialize, Deserialize)]
443#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
444pub struct StagedCommit {
445 staged_proposal_queue: ProposalQueue,
446 state: StagedCommitState,
447}
448
449impl StagedCommit {
450 pub(crate) fn new(staged_proposal_queue: ProposalQueue, state: StagedCommitState) -> Self {
453 StagedCommit {
454 staged_proposal_queue,
455 state,
456 }
457 }
458
459 pub fn add_proposals(&self) -> impl Iterator<Item = QueuedAddProposal> {
461 self.staged_proposal_queue.add_proposals()
462 }
463
464 pub fn remove_proposals(&self) -> impl Iterator<Item = QueuedRemoveProposal> {
466 self.staged_proposal_queue.remove_proposals()
467 }
468
469 pub fn update_proposals(&self) -> impl Iterator<Item = QueuedUpdateProposal> {
471 self.staged_proposal_queue.update_proposals()
472 }
473
474 pub fn psk_proposals(&self) -> impl Iterator<Item = QueuedPskProposal> {
476 self.staged_proposal_queue.psk_proposals()
477 }
478
479 pub fn queued_proposals(&self) -> impl Iterator<Item = &QueuedProposal> {
481 self.staged_proposal_queue.queued_proposals()
482 }
483
484 pub fn update_path_leaf_node(&self) -> Option<&LeafNode> {
486 match self.state {
487 StagedCommitState::PublicState(ref public_state) => {
488 public_state.update_path_leaf_node()
489 }
490 StagedCommitState::GroupMember(ref group_member_state) => {
491 group_member_state.update_path_leaf_node.as_ref()
492 }
493 }
494 }
495
496 pub fn credentials_to_verify(&self) -> impl Iterator<Item = &Credential> {
498 let update_path_leaf_node_cred = if let Some(node) = self.update_path_leaf_node() {
499 vec![node.credential()]
500 } else {
501 vec![]
502 };
503
504 update_path_leaf_node_cred
505 .into_iter()
506 .chain(
507 self.queued_proposals()
508 .flat_map(|proposal: &QueuedProposal| match proposal.proposal() {
509 Proposal::Update(update_proposal) => {
510 vec![update_proposal.leaf_node().credential()].into_iter()
511 }
512 Proposal::Add(add_proposal) => {
513 vec![add_proposal.key_package().leaf_node().credential()].into_iter()
514 }
515 Proposal::GroupContextExtensions(gce_proposal) => gce_proposal
516 .extensions()
517 .iter()
518 .flat_map(|extension| {
519 match extension {
520 Extension::ExternalSenders(external_senders) => {
521 external_senders
522 .iter()
523 .map(|external_sender| external_sender.credential())
524 .collect()
525 }
526 _ => vec![],
527 }
528 .into_iter()
529 })
530 .collect::<Vec<_>>()
537 .into_iter(),
538 _ => vec![].into_iter(),
539 }),
540 )
541 }
542
543 pub fn self_removed(&self) -> bool {
546 matches!(self.state, StagedCommitState::PublicState(_))
547 }
548
549 pub fn group_context(&self) -> &GroupContext {
551 match self.state {
552 StagedCommitState::PublicState(ref ps) => ps.staged_diff().group_context(),
553 StagedCommitState::GroupMember(ref gm) => gm.group_context(),
554 }
555 }
556
557 pub(crate) fn into_state(self) -> StagedCommitState {
559 self.state
560 }
561
562 pub fn epoch_authenticator(&self) -> Option<&EpochAuthenticator> {
566 if let StagedCommitState::GroupMember(ref gm) = self.state {
567 Some(gm.group_epoch_secrets.epoch_authenticator())
568 } else {
569 None
570 }
571 }
572}
573
574#[derive(Debug, Serialize, Deserialize)]
576#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
577pub(crate) struct MemberStagedCommitState {
578 group_epoch_secrets: GroupEpochSecrets,
579 message_secrets: MessageSecrets,
580 staged_diff: StagedPublicGroupDiff,
581 new_keypairs: Vec<EncryptionKeyPair>,
582 new_leaf_keypair_option: Option<EncryptionKeyPair>,
583 update_path_leaf_node: Option<LeafNode>,
584}
585
586impl MemberStagedCommitState {
587 pub(crate) fn new(
588 group_epoch_secrets: GroupEpochSecrets,
589 message_secrets: MessageSecrets,
590 staged_diff: StagedPublicGroupDiff,
591 new_keypairs: Vec<EncryptionKeyPair>,
592 new_leaf_keypair_option: Option<EncryptionKeyPair>,
593 update_path_leaf_node: Option<LeafNode>,
594 ) -> Self {
595 Self {
596 group_epoch_secrets,
597 message_secrets,
598 staged_diff,
599 new_keypairs,
600 new_leaf_keypair_option,
601 update_path_leaf_node,
602 }
603 }
604
605 pub(crate) fn group_context(&self) -> &GroupContext {
607 self.staged_diff.group_context()
608 }
609}