1use core::fmt::Debug;
2use std::mem;
3
4#[cfg(feature = "extensions-draft-08")]
5use openmls_traits::crypto::OpenMlsCrypto;
6use openmls_traits::storage::StorageProvider as _;
7use serde::{Deserialize, Serialize};
8use tls_codec::Serialize as _;
9
10use super::proposal_store::{
11 QueuedAddProposal, QueuedPskProposal, QueuedRemoveProposal, QueuedUpdateProposal,
12};
13
14use super::{
15 super::errors::*, load_psks, Credential, Extension, GroupContext, GroupEpochSecrets, GroupId,
16 JoinerSecret, KeySchedule, LeafNode, LibraryError, MessageSecrets, MlsGroup, OpenMlsProvider,
17 Proposal, ProposalQueue, PskSecret, QueuedProposal, Sender,
18};
19#[cfg(feature = "extensions-draft-08")]
20use crate::schedule::application_export_tree::ApplicationExportTree;
21
22use crate::{
23 ciphersuite::{hash_ref::ProposalRef, Secret},
24 framing::mls_auth_content::AuthenticatedContent,
25 group::public_group::{
26 diff::{apply_proposals::ApplyProposalsValues, StagedPublicGroupDiff},
27 staged_commit::PublicStagedCommitState,
28 },
29 schedule::{CommitSecret, EpochAuthenticator, EpochSecretsResult, InitSecret, PreSharedKeyId},
30 treesync::node::encryption_keys::EncryptionKeyPair,
31};
32
33impl MlsGroup {
34 fn derive_epoch_secrets(
35 &self,
36 provider: &impl OpenMlsProvider,
37 apply_proposals_values: ApplyProposalsValues,
38 epoch_secrets: &GroupEpochSecrets,
39 commit_secret: CommitSecret,
40 serialized_provisional_group_context: &[u8],
41 ) -> Result<EpochSecretsResult, StageCommitError> {
42 let joiner_secret = if let Some(ref external_init_proposal) =
45 apply_proposals_values.external_init_proposal_option
46 {
47 let external_priv = epoch_secrets
49 .external_secret()
50 .derive_external_keypair(provider.crypto(), self.ciphersuite())
51 .map_err(LibraryError::unexpected_crypto_error)?
52 .private;
53 let init_secret = InitSecret::from_kem_output(
54 provider.crypto(),
55 self.ciphersuite(),
56 self.version(),
57 &external_priv,
58 external_init_proposal.kem_output(),
59 )?;
60 JoinerSecret::new(
61 provider.crypto(),
62 self.ciphersuite(),
63 commit_secret,
64 &init_secret,
65 serialized_provisional_group_context,
66 )
67 .map_err(LibraryError::unexpected_crypto_error)?
68 } else {
69 JoinerSecret::new(
70 provider.crypto(),
71 self.ciphersuite(),
72 commit_secret,
73 epoch_secrets.init_secret(),
74 serialized_provisional_group_context,
75 )
76 .map_err(LibraryError::unexpected_crypto_error)?
77 };
78
79 let psk_secret = {
82 let psks: Vec<(&PreSharedKeyId, Secret)> = load_psks(
83 provider.storage(),
84 &self.resumption_psk_store,
85 &apply_proposals_values.presharedkeys,
86 )?;
87
88 PskSecret::new(provider.crypto(), self.ciphersuite(), psks)?
89 };
90
91 let mut key_schedule = KeySchedule::init(
93 self.ciphersuite(),
94 provider.crypto(),
95 &joiner_secret,
96 psk_secret,
97 )?;
98
99 key_schedule
100 .add_context(provider.crypto(), serialized_provisional_group_context)
101 .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?;
102 Ok(key_schedule
103 .epoch_secrets(provider.crypto(), self.ciphersuite())
104 .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?)
105 }
106
107 pub(crate) fn stage_commit(
146 &self,
147 mls_content: &AuthenticatedContent,
148 old_epoch_keypairs: Vec<EncryptionKeyPair>,
149 leaf_node_keypairs: Vec<EncryptionKeyPair>,
150 provider: &impl OpenMlsProvider,
151 ) -> Result<StagedCommit, StageCommitError> {
152 if let Sender::Member(member) = mls_content.sender() {
154 if member == &self.own_leaf_index() {
155 return Err(StageCommitError::OwnCommit);
156 }
157 }
158
159 let ciphersuite = self.ciphersuite();
160
161 let (commit, proposal_queue, sender_index) = self
162 .public_group
163 .validate_commit(mls_content, provider.crypto())?;
164
165 let mut diff = self.public_group.empty_diff();
168
169 let apply_proposals_values =
170 diff.apply_proposals(&proposal_queue, self.own_leaf_index())?;
171
172 let (commit_secret, new_keypairs, new_leaf_keypair_option, update_path_leaf_node) =
174 if let Some(path) = commit.path.clone() {
175 diff.apply_received_update_path(
178 provider.crypto(),
179 ciphersuite,
180 sender_index,
181 &path,
182 )?;
183
184 diff.update_group_context(
186 provider.crypto(),
187 apply_proposals_values.extensions.clone(),
188 )?;
189
190 if apply_proposals_values.self_removed {
192 let staged_diff = diff.into_staged_diff(provider.crypto(), ciphersuite)?;
194 let staged_state = PublicStagedCommitState::new(
195 staged_diff,
196 commit.path.as_ref().map(|path| path.leaf_node().clone()),
197 );
198 let staged_commit = StagedCommit::new(
199 proposal_queue,
200 StagedCommitState::PublicState(Box::new(staged_state)),
201 );
202 return Ok(staged_commit);
203 }
204
205 let decryption_keypairs: Vec<&EncryptionKeyPair> = old_epoch_keypairs
206 .iter()
207 .chain(leaf_node_keypairs.iter())
208 .collect();
209
210 let (new_keypairs, commit_secret) = diff.decrypt_path(
213 provider.crypto(),
214 &decryption_keypairs,
215 self.own_leaf_index(),
216 sender_index,
217 path.nodes(),
218 &apply_proposals_values.exclusion_list(),
219 )?;
220
221 let new_leaf_keypair_option = if let Some(leaf) = diff.leaf(self.own_leaf_index()) {
227 leaf_node_keypairs.into_iter().find_map(|keypair| {
228 if leaf.encryption_key() == keypair.public_key() {
229 Some(keypair)
230 } else {
231 None
232 }
233 })
234 } else {
235 debug_assert!(false);
237 None
238 };
239
240 let update_path_leaf_node = Some(path.leaf_node().clone());
244 debug_assert_eq!(diff.leaf(sender_index), path.leaf_node().into());
245
246 (
247 commit_secret,
248 new_keypairs,
249 new_leaf_keypair_option,
250 update_path_leaf_node,
251 )
252 } else {
253 if apply_proposals_values.path_required {
254 return Err(StageCommitError::RequiredPathNotFound);
256 }
257
258 diff.update_group_context(
260 provider.crypto(),
261 apply_proposals_values.extensions.clone(),
262 )?;
263
264 (CommitSecret::zero_secret(ciphersuite), vec![], None, None)
265 };
266
267 diff.update_confirmed_transcript_hash(provider.crypto(), mls_content)?;
269
270 let received_confirmation_tag = mls_content
271 .confirmation_tag()
272 .ok_or(StageCommitError::ConfirmationTagMissing)?;
273
274 let serialized_provisional_group_context = diff
275 .group_context()
276 .tls_serialize_detached()
277 .map_err(LibraryError::missing_bound_check)?;
278
279 let EpochSecretsResult {
280 epoch_secrets,
281 #[cfg(feature = "extensions-draft-08")]
282 application_exporter,
283 } = self.derive_epoch_secrets(
284 provider,
285 apply_proposals_values,
286 self.group_epoch_secrets(),
287 commit_secret,
288 &serialized_provisional_group_context,
289 )?;
290 let (provisional_group_secrets, provisional_message_secrets) = epoch_secrets.split_secrets(
291 serialized_provisional_group_context,
292 diff.tree_size(),
293 self.own_leaf_index(),
294 );
295
296 let own_confirmation_tag = provisional_message_secrets
299 .confirmation_key()
300 .tag(
301 provider.crypto(),
302 self.ciphersuite(),
303 diff.group_context().confirmed_transcript_hash(),
304 )
305 .map_err(LibraryError::unexpected_crypto_error)?;
306 if &own_confirmation_tag != received_confirmation_tag {
307 log::error!("Confirmation tag mismatch");
308 log_crypto!(trace, " Got: {:x?}", received_confirmation_tag);
309 log_crypto!(trace, " Expected: {:x?}", own_confirmation_tag);
310 if !crate::skip_validation::is_disabled::confirmation_tag() {
317 return Err(StageCommitError::ConfirmationTagMismatch);
318 }
319 }
320
321 diff.update_interim_transcript_hash(ciphersuite, provider.crypto(), own_confirmation_tag)?;
322
323 let staged_diff = diff.into_staged_diff(provider.crypto(), ciphersuite)?;
324 #[cfg(feature = "extensions-draft-08")]
325 let application_export_tree = ApplicationExportTree::new(application_exporter);
326 let staged_commit_state =
327 StagedCommitState::GroupMember(Box::new(MemberStagedCommitState::new(
328 provisional_group_secrets,
329 provisional_message_secrets,
330 staged_diff,
331 new_keypairs,
332 new_leaf_keypair_option,
333 update_path_leaf_node,
334 #[cfg(feature = "extensions-draft-08")]
335 application_export_tree,
336 )));
337 let staged_commit = StagedCommit::new(proposal_queue, staged_commit_state);
338
339 Ok(staged_commit)
340 }
341
342 pub(crate) fn merge_commit<Provider: OpenMlsProvider>(
348 &mut self,
349 provider: &Provider,
350 staged_commit: StagedCommit,
351 ) -> Result<(), MergeCommitError<Provider::StorageError>> {
352 let old_epoch_keypairs = self
355 .read_epoch_keypairs(provider.storage())
356 .map_err(MergeCommitError::StorageError)?;
357 match staged_commit.state {
358 StagedCommitState::PublicState(staged_state) => {
359 self.public_group
360 .merge_diff(staged_state.into_staged_diff());
361 self.store(provider.storage())
362 .map_err(MergeCommitError::StorageError)?;
363 Ok(())
364 }
365 StagedCommitState::GroupMember(state) => {
366 let past_epoch = self.context().epoch();
368 let leaves = self.public_group().members().collect();
370 self.group_epoch_secrets = state.group_epoch_secrets;
373
374 let mut message_secrets = state.message_secrets;
376 mem::swap(
377 &mut message_secrets,
378 self.message_secrets_store.message_secrets_mut(),
379 );
380 self.message_secrets_store
381 .add(past_epoch, message_secrets, leaves);
382
383 #[cfg(feature = "extensions-draft-08")]
385 {
386 if let Some(application_export_tree) = state.application_export_tree {
390 use openmls_traits::storage::StorageProvider as _;
393 provider
394 .storage()
395 .write_application_export_tree(
396 self.group_id(),
397 &application_export_tree,
398 )
399 .map_err(MergeCommitError::StorageError)?;
400
401 self.application_export_tree = Some(application_export_tree);
402 }
403 }
404
405 self.public_group.merge_diff(state.staged_diff);
406
407 let leaf_keypair = if let Some(keypair) = &state.new_leaf_keypair_option {
408 vec![keypair.clone()]
409 } else {
410 vec![]
411 };
412
413 let new_owned_encryption_keys = self
415 .public_group()
416 .owned_encryption_keys(self.own_leaf_index());
417 let epoch_keypairs: Vec<EncryptionKeyPair> = old_epoch_keypairs
419 .into_iter()
420 .chain(state.new_keypairs)
421 .chain(leaf_keypair)
422 .filter(|keypair| new_owned_encryption_keys.contains(keypair.public_key()))
423 .collect();
424
425 debug_assert_eq!(new_owned_encryption_keys.len(), epoch_keypairs.len());
427 if new_owned_encryption_keys.len() != epoch_keypairs.len() {
428 return Err(LibraryError::custom(
429 "We should have all the private key material we need.",
430 )
431 .into());
432 }
433
434 let storage = provider.storage();
436 let group_id = self.group_id();
437
438 self.public_group
439 .store(storage)
440 .map_err(MergeCommitError::StorageError)?;
441 storage
442 .write_group_epoch_secrets(group_id, &self.group_epoch_secrets)
443 .map_err(MergeCommitError::StorageError)?;
444 storage
445 .write_message_secrets(group_id, &self.message_secrets_store)
446 .map_err(MergeCommitError::StorageError)?;
447
448 self.store_epoch_keypairs(storage, epoch_keypairs.as_slice())
450 .map_err(MergeCommitError::StorageError)?;
451
452 self.delete_previous_epoch_keypairs(storage)
454 .map_err(MergeCommitError::StorageError)?;
455 if let Some(keypair) = state.new_leaf_keypair_option {
456 keypair
457 .delete(storage)
458 .map_err(MergeCommitError::StorageError)?;
459 }
460
461 storage
463 .clear_proposal_queue::<GroupId, ProposalRef>(group_id)
464 .map_err(MergeCommitError::StorageError)?;
465 self.proposal_store_mut().empty();
466
467 Ok(())
468 }
469 }
470 }
471}
472
/// The state produced by `MlsGroup::stage_commit` and consumed by `MlsGroup::merge_commit`.
///
/// NOTE(review): this enum derives `Serialize`/`Deserialize`. If staged commits are persisted,
/// renaming or reordering variants may break previously serialized data — confirm before
/// changing.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
pub(crate) enum StagedCommitState {
    /// Produced when the commit removes this member. Only the public group diff (and the
    /// update-path leaf node, if any) is kept; no new-epoch secrets are derived.
    PublicState(Box<PublicStagedCommitState>),
    /// Produced when this member stays in the group. Carries the new epoch secrets, the public
    /// group diff and the key material needed to merge the commit.
    GroupMember(Box<MemberStagedCommitState>),
}
479
/// A commit that has been validated and applied to a provisional copy of the group state,
/// but not yet merged.
///
/// It bundles the proposals covered by the commit with the resulting [`StagedCommitState`].
/// `MlsGroup::merge_commit` consumes it to advance the group to the next epoch.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
pub struct StagedCommit {
    // Proposals included in the commit (by value or by reference).
    staged_proposal_queue: ProposalQueue,
    // Public-only or full member state, depending on whether we were removed.
    state: StagedCommitState,
}
487
488impl StagedCommit {
489 pub(crate) fn new(staged_proposal_queue: ProposalQueue, state: StagedCommitState) -> Self {
492 StagedCommit {
493 staged_proposal_queue,
494 state,
495 }
496 }
497
498 pub fn add_proposals(&self) -> impl Iterator<Item = QueuedAddProposal<'_>> {
500 self.staged_proposal_queue.add_proposals()
501 }
502
503 pub fn remove_proposals(&self) -> impl Iterator<Item = QueuedRemoveProposal<'_>> {
505 self.staged_proposal_queue.remove_proposals()
506 }
507
508 pub fn update_proposals(&self) -> impl Iterator<Item = QueuedUpdateProposal<'_>> {
510 self.staged_proposal_queue.update_proposals()
511 }
512
513 pub fn psk_proposals(&self) -> impl Iterator<Item = QueuedPskProposal<'_>> {
515 self.staged_proposal_queue.psk_proposals()
516 }
517
518 pub fn queued_proposals(&self) -> impl Iterator<Item = &QueuedProposal> {
520 self.staged_proposal_queue.queued_proposals()
521 }
522
523 pub fn update_path_leaf_node(&self) -> Option<&LeafNode> {
525 match self.state {
526 StagedCommitState::PublicState(ref public_state) => {
527 public_state.update_path_leaf_node()
528 }
529 StagedCommitState::GroupMember(ref group_member_state) => {
530 group_member_state.update_path_leaf_node.as_ref()
531 }
532 }
533 }
534
535 pub fn credentials_to_verify(&self) -> impl Iterator<Item = &Credential> {
537 let update_path_leaf_node_cred = if let Some(node) = self.update_path_leaf_node() {
538 vec![node.credential()]
539 } else {
540 vec![]
541 };
542
543 update_path_leaf_node_cred
544 .into_iter()
545 .chain(
546 self.queued_proposals()
547 .flat_map(|proposal: &QueuedProposal| match proposal.proposal() {
548 Proposal::Update(update_proposal) => {
549 vec![update_proposal.leaf_node().credential()].into_iter()
550 }
551 Proposal::Add(add_proposal) => {
552 vec![add_proposal.key_package().leaf_node().credential()].into_iter()
553 }
554 Proposal::GroupContextExtensions(gce_proposal) => gce_proposal
555 .extensions()
556 .iter()
557 .flat_map(|extension| {
558 match extension {
559 Extension::ExternalSenders(external_senders) => {
560 external_senders
561 .iter()
562 .map(|external_sender| external_sender.credential())
563 .collect()
564 }
565 _ => vec![],
566 }
567 .into_iter()
568 })
569 .collect::<Vec<_>>()
576 .into_iter(),
577 _ => vec![].into_iter(),
578 }),
579 )
580 }
581
582 pub fn self_removed(&self) -> bool {
585 matches!(self.state, StagedCommitState::PublicState(_))
586 }
587
588 pub fn group_context(&self) -> &GroupContext {
590 match self.state {
591 StagedCommitState::PublicState(ref ps) => ps.staged_diff().group_context(),
592 StagedCommitState::GroupMember(ref gm) => gm.group_context(),
593 }
594 }
595
596 pub(crate) fn into_state(self) -> StagedCommitState {
598 self.state
599 }
600
601 pub fn epoch_authenticator(&self) -> Option<&EpochAuthenticator> {
605 if let StagedCommitState::GroupMember(ref gm) = self.state {
606 Some(gm.group_epoch_secrets.epoch_authenticator())
607 } else {
608 None
609 }
610 }
611
612 #[cfg(feature = "extensions-draft-08")]
613 pub(crate) fn safe_export_secret(
614 &mut self,
615 crypto: &impl OpenMlsCrypto,
616 component_id: u16,
617 ) -> Result<Vec<u8>, StagedSafeExportSecretError> {
618 let ciphersuite = self.group_context().ciphersuite();
619 let StagedCommitState::GroupMember(ref mut staged_commit) = self.state else {
620 return Err(StagedSafeExportSecretError::NotGroupMember);
621 };
622 let Some(application_export_tree) = staged_commit.application_export_tree.as_mut() else {
623 return Err(StagedSafeExportSecretError::Unsupported);
624 };
625 let secret =
626 application_export_tree.safe_export_secret(crypto, ciphersuite, component_id)?;
627 Ok(secret.as_slice().to_vec())
628 }
629}
630
/// Everything a member that stays in the group needs in order to merge a staged commit.
///
/// NOTE(review): this struct derives `Serialize`/`Deserialize`. If it is persisted, field
/// order and names may be part of the stored format — confirm before changing.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
pub(crate) struct MemberStagedCommitState {
    // Group secrets of the new epoch; installed as the group's epoch secrets on merge.
    group_epoch_secrets: GroupEpochSecrets,
    // Message secrets of the new epoch; on merge they replace the current ones, which are
    // archived under the previous epoch.
    message_secrets: MessageSecrets,
    // Public group diff with proposals, update path and transcript hashes applied.
    staged_diff: StagedPublicGroupDiff,
    // Keypairs obtained while decrypting the update path (empty if the commit had no path).
    new_keypairs: Vec<EncryptionKeyPair>,
    // The leaf-node keypair matching our own leaf's encryption key in the new tree, if any.
    new_leaf_keypair_option: Option<EncryptionKeyPair>,
    // The sender's leaf node from the update path, if the commit had one.
    update_path_leaf_node: Option<LeafNode>,
    // With `#[serde(default)]`, data serialized without this field deserializes as `None`.
    // It is always `Some` when built through `MemberStagedCommitState::new`.
    #[cfg(feature = "extensions-draft-08")]
    #[serde(default)]
    application_export_tree: Option<ApplicationExportTree>,
}
647
648impl MemberStagedCommitState {
649 pub(crate) fn new(
650 group_epoch_secrets: GroupEpochSecrets,
651 message_secrets: MessageSecrets,
652 staged_diff: StagedPublicGroupDiff,
653 new_keypairs: Vec<EncryptionKeyPair>,
654 new_leaf_keypair_option: Option<EncryptionKeyPair>,
655 update_path_leaf_node: Option<LeafNode>,
656 #[cfg(feature = "extensions-draft-08")] application_export_tree: ApplicationExportTree,
657 ) -> Self {
658 Self {
659 group_epoch_secrets,
660 message_secrets,
661 staged_diff,
662 new_keypairs,
663 new_leaf_keypair_option,
664 update_path_leaf_node,
665 #[cfg(feature = "extensions-draft-08")]
666 application_export_tree: Some(application_export_tree),
667 }
668 }
669
670 pub(crate) fn group_context(&self) -> &GroupContext {
672 self.staged_diff.group_context()
673 }
674}