1use openmls_traits::{
5 crypto::OpenMlsCrypto, random::OpenMlsRand, signatures::Signer, storage::StorageProvider as _,
6};
7use tls_codec::Serialize as _;
8
9use crate::{
10 binary_tree::LeafNodeIndex,
11 ciphersuite::{signable::Signable as _, Secret},
12 group::{
13 create_commit::CommitType, diff::compute_path::PathComputationResult,
14 CommitBuilderStageError, CreateCommitError, Extension, Extensions, ExternalPubExtension,
15 ProposalQueue, ProposalQueueError, QueuedProposal, RatchetTreeExtension, StagedCommit,
16 },
17 key_packages::KeyPackage,
18 messages::{
19 group_info::{GroupInfo, GroupInfoTBS},
20 Commit, Welcome,
21 },
22 prelude::{LeafNodeParameters, LibraryError},
23 schedule::{
24 psk::{load_psks, PskSecret},
25 JoinerSecret, KeySchedule, PreSharedKeyId,
26 },
27 storage::{OpenMlsProvider, StorageProvider},
28 versions::ProtocolVersion,
29};
30
31use super::{
32 mls_auth_content::AuthenticatedContent,
33 staged_commit::{MemberStagedCommitState, StagedCommitState},
34 AddProposal, CreateCommitResult, GroupContextExtensionProposal, MlsGroup, MlsGroupState,
35 MlsMessageOut, PendingCommitState, Proposal, RemoveProposal, Sender,
36};
37
/// First stage of a [`CommitBuilder`]: collects the proposals and options for the commit.
pub struct Initial {
    /// Proposals contributed by the caller through the builder methods.
    own_proposals: Vec<Proposal>,
    /// If `true`, an update path is included even when no proposal requires one.
    force_self_update: bool,
    /// Parameters for the committer's own leaf node; non-empty parameters also cause an update
    /// path to be included.
    leaf_node_parameters: LeafNodeParameters,

    /// Whether the proposals already queued in the group's proposal store are considered for
    /// inclusion in the commit.
    consume_proposal_store: bool,
}
48
49impl Default for Initial {
50 fn default() -> Self {
51 Initial {
52 consume_proposal_store: true,
53 force_self_update: false,
54 leaf_node_parameters: LeafNodeParameters::default(),
55 own_proposals: vec![],
56 }
57 }
58}
59
/// Second stage of a [`CommitBuilder`]: the options from [`Initial`] plus the pre-shared keys
/// that were loaded from storage by `load_psks`.
pub struct LoadedPsks {
    /// Proposals contributed by the caller through the builder methods.
    own_proposals: Vec<Proposal>,
    /// If `true`, an update path is included even when no proposal requires one.
    force_self_update: bool,
    /// Parameters for the committer's own leaf node; non-empty parameters also cause an update
    /// path to be included.
    leaf_node_parameters: LeafNodeParameters,

    /// Whether the proposals already queued in the group's proposal store are considered for
    /// inclusion in the commit.
    consume_proposal_store: bool,
    /// The PSK ids and secrets loaded for PreSharedKey proposals that may end up in the commit.
    psks: Vec<(PreSharedKeyId, Secret)>,
}
71
/// Final stage of a [`CommitBuilder`]: holds the built commit, ready to be staged with
/// `stage_commit`.
pub struct Complete {
    /// The commit, the optional welcome and group info, and the staged commit.
    result: CreateCommitResult,
}
76
/// Builder for commits, parameterized by its current stage `T`.
///
/// The builder moves through the stages [`Initial`] -> [`LoadedPsks`] -> [`Complete`]:
/// configure the commit in [`Initial`], call `load_psks` to load the required pre-shared keys,
/// call `build` (or `build_with_new_signer`) to create the commit, and finally `stage_commit`
/// to make it the group's pending commit.
#[derive(Debug)]
pub struct CommitBuilder<'a, T> {
    /// Exclusive borrow of the group the commit is built for; held for the builder's lifetime.
    group: &'a mut MlsGroup,

    /// The data of the current stage.
    stage: T,
}
123
124impl<'a, T> CommitBuilder<'a, T> {
125 pub(crate) fn replace_stage<NextStage>(
126 self,
127 next_stage: NextStage,
128 ) -> (T, CommitBuilder<'a, NextStage>) {
129 self.map_stage(|prev_stage| (prev_stage, next_stage))
130 }
131
132 pub(crate) fn into_stage<NextStage>(
133 self,
134 next_stage: NextStage,
135 ) -> CommitBuilder<'a, NextStage> {
136 self.replace_stage(next_stage).1
137 }
138
139 pub(crate) fn take_stage(self) -> (T, CommitBuilder<'a, ()>) {
140 self.replace_stage(())
141 }
142
143 pub(crate) fn map_stage<NextStage, Aux, F: FnOnce(T) -> (Aux, NextStage)>(
144 self,
145 f: F,
146 ) -> (Aux, CommitBuilder<'a, NextStage>) {
147 let Self { group, stage } = self;
148
149 let (aux, stage) = f(stage);
150
151 (aux, CommitBuilder { group, stage })
152 }
153
154 #[cfg(feature = "fork-resolution")]
155 pub(crate) fn stage(&self) -> &T {
156 &self.stage
157 }
158}
159
160impl MlsGroup {
161 pub fn commit_builder(&mut self) -> CommitBuilder<Initial> {
163 CommitBuilder::new(self)
164 }
165}
166
167impl<'a> CommitBuilder<'a, Initial> {
168 pub fn new(group: &'a mut MlsGroup) -> Self {
170 Self {
171 group,
172 stage: Initial::default(),
173 }
174 }
175
176 pub fn consume_proposal_store(mut self, consume_proposal_store: bool) -> Self {
179 self.stage.consume_proposal_store = consume_proposal_store;
180 self
181 }
182
183 pub fn force_self_update(mut self, force_self_update: bool) -> Self {
185 self.stage.force_self_update = force_self_update;
186 self
187 }
188
189 pub fn add_proposal(mut self, proposal: Proposal) -> Self {
191 self.stage.own_proposals.push(proposal);
192 self
193 }
194
195 pub fn add_proposals(mut self, proposals: impl IntoIterator<Item = Proposal>) -> Self {
197 self.stage.own_proposals.extend(proposals);
198 self
199 }
200
201 pub fn leaf_node_parameters(mut self, leaf_node_parameters: LeafNodeParameters) -> Self {
204 self.stage.leaf_node_parameters = leaf_node_parameters;
205 self
206 }
207
208 pub fn propose_adds(mut self, key_packages: impl IntoIterator<Item = KeyPackage>) -> Self {
211 self.stage.own_proposals.extend(
212 key_packages
213 .into_iter()
214 .map(|key_package| Proposal::Add(AddProposal { key_package })),
215 );
216 self
217 }
218
219 pub fn propose_removals(mut self, removed: impl IntoIterator<Item = LeafNodeIndex>) -> Self {
220 self.stage.own_proposals.extend(
221 removed
222 .into_iter()
223 .map(|removed| Proposal::Remove(RemoveProposal { removed })),
224 );
225 self
226 }
227
228 pub fn propose_group_context_extensions(mut self, extensions: Extensions) -> Self {
229 self.stage
230 .own_proposals
231 .push(Proposal::GroupContextExtensions(
232 GroupContextExtensionProposal::new(extensions),
233 ));
234 self
235 }
236
237 pub fn load_psks<Storage: StorageProvider>(
239 self,
240 storage: &'a Storage,
241 ) -> Result<CommitBuilder<'a, LoadedPsks>, CreateCommitError> {
242 let psk_ids: Vec<_> = self
243 .stage
244 .own_proposals
245 .iter()
246 .chain(
247 self.group
248 .proposal_store()
249 .proposals()
250 .map(|queued_proposal| queued_proposal.proposal()),
251 )
252 .filter_map(|proposal| match proposal {
253 Proposal::PreSharedKey(psk_proposal) => Some(psk_proposal.clone().into_psk_id()),
254 _ => None,
255 })
256 .collect();
257
258 let psks = load_psks(storage, &self.group.resumption_psk_store, &psk_ids)?
260 .into_iter()
261 .map(|(psk_id_ref, key)| (psk_id_ref.clone(), key))
262 .collect();
263
264 Ok(self
265 .map_stage(|stage| {
266 (
267 (),
268 LoadedPsks {
269 own_proposals: stage.own_proposals,
270 psks,
271 force_self_update: stage.force_self_update,
272 leaf_node_parameters: stage.leaf_node_parameters,
273 consume_proposal_store: stage.consume_proposal_store,
274 },
275 )
276 })
277 .1)
278 }
279}
280
281impl<'a> CommitBuilder<'a, LoadedPsks> {
282 pub fn build<S: Signer>(
286 self,
287 rand: &impl OpenMlsRand,
288 crypto: &impl OpenMlsCrypto,
289 signer: &S,
290 f: impl FnMut(&QueuedProposal) -> bool,
291 ) -> Result<CommitBuilder<'a, Complete>, CreateCommitError> {
292 self.build_internal(rand, crypto, signer, None::<&S>, f)
293 }
294
295 pub fn build_with_new_signer(
303 self,
304 rand: &impl OpenMlsRand,
305 crypto: &impl OpenMlsCrypto,
306 old_signer: &impl Signer,
307 new_signer: &impl Signer,
308 f: impl FnMut(&QueuedProposal) -> bool,
309 ) -> Result<CommitBuilder<'a, Complete>, CreateCommitError> {
310 self.build_internal(rand, crypto, old_signer, Some(new_signer), f)
311 }
312
313 fn build_internal(
314 self,
315 rand: &impl OpenMlsRand,
316 crypto: &impl OpenMlsCrypto,
317 old_signer: &impl Signer,
318 new_signer: Option<&impl Signer>,
319 f: impl FnMut(&QueuedProposal) -> bool,
320 ) -> Result<CommitBuilder<'a, Complete>, CreateCommitError> {
321 let ciphersuite = self.group.ciphersuite();
322 let sender = Sender::build_member(self.group.own_leaf_index());
323 let (cur_stage, builder) = self.take_stage();
324 let psks = cur_stage.psks;
325
326 let own_proposals: Vec<_> = cur_stage
329 .own_proposals
330 .into_iter()
331 .map(|proposal| {
332 QueuedProposal::from_proposal_and_sender(ciphersuite, crypto, proposal, &sender)
333 })
334 .collect::<Result<_, _>>()?;
335
336 let group_proposal_store_queue = builder
339 .group
340 .pending_proposals()
341 .filter(|_| cur_stage.consume_proposal_store)
342 .cloned();
343
344 let proposal_queue = group_proposal_store_queue.chain(own_proposals).filter(f);
348
349 let (proposal_queue, contains_own_updates) =
350 ProposalQueue::filter_proposals_without_inline(
351 proposal_queue,
352 builder.group.own_leaf_index,
353 )
354 .map_err(|e| match e {
355 ProposalQueueError::LibraryError(e) => e.into(),
356 ProposalQueueError::ProposalNotFound => CreateCommitError::MissingProposal,
357 ProposalQueueError::UpdateFromExternalSender
358 | ProposalQueueError::SelfRemoveFromNonMember => {
359 CreateCommitError::WrongProposalSenderType
360 }
361 })?;
362
363 builder
368 .group
369 .public_group
370 .validate_proposal_type_support(&proposal_queue)?;
371 builder
376 .group
377 .public_group
378 .validate_key_uniqueness(&proposal_queue, None)?;
379 builder
381 .group
382 .public_group
383 .validate_add_proposals(&proposal_queue)?;
384 builder
387 .group
388 .public_group
389 .validate_capabilities(&proposal_queue)?;
390 builder
393 .group
394 .public_group
395 .validate_remove_proposals(&proposal_queue)?;
396 builder
397 .group
398 .public_group
399 .validate_pre_shared_key_proposals(&proposal_queue)?;
400 builder
405 .group
406 .public_group
407 .validate_update_proposals(&proposal_queue, builder.group.own_leaf_index())?;
408
409 builder
412 .group
413 .public_group
414 .validate_group_context_extensions_proposal(&proposal_queue)?;
415
416 let ciphersuite = builder.group.ciphersuite();
417 let sender = Sender::build_member(builder.group.own_leaf_index());
418 let proposal_reference_list = proposal_queue.commit_list();
419
420 let mut diff = builder.group.public_group.empty_diff();
422
423 let apply_proposals_values =
425 diff.apply_proposals(&proposal_queue, builder.group.own_leaf_index())?;
426 if apply_proposals_values.self_removed {
427 return Err(CreateCommitError::CannotRemoveSelf);
428 }
429
430 let path_computation_result =
431 if apply_proposals_values.path_required
433 || contains_own_updates
434 || cur_stage.force_self_update
435 || !cur_stage.leaf_node_parameters.is_empty()
436 {
437 if let Some(new_signer) = new_signer {
441 diff.compute_path(
442 rand,
443 crypto,
444 builder.group.own_leaf_index(),
445 apply_proposals_values.exclusion_list(),
446 &CommitType::Member,
447 &cur_stage.leaf_node_parameters,
448 new_signer,
449 apply_proposals_values.extensions.clone()
450 )?
451 } else {
452 diff.compute_path(
453 rand,
454 crypto,
455 builder.group.own_leaf_index(),
456 apply_proposals_values.exclusion_list(),
457 &CommitType::Member,
458 &cur_stage.leaf_node_parameters,
459 old_signer,
460 apply_proposals_values.extensions.clone()
461 )?
462 }
463 } else {
464 diff.update_group_context(crypto, apply_proposals_values.extensions.clone())?;
467 PathComputationResult::default()
468 };
469
470 let update_path_leaf_node = path_computation_result
471 .encrypted_path
472 .as_ref()
473 .map(|path| path.leaf_node().clone());
474
475 let commit = Commit {
477 proposals: proposal_reference_list,
478 path: path_computation_result.encrypted_path,
479 };
480
481 let mut authenticated_content = AuthenticatedContent::commit(
483 builder.group.framing_parameters(),
484 sender,
485 commit,
486 builder.group.public_group.group_context(),
487 old_signer,
488 )?;
489
490 diff.update_confirmed_transcript_hash(crypto, &authenticated_content)?;
492
493 let serialized_provisional_group_context = diff
494 .group_context()
495 .tls_serialize_detached()
496 .map_err(LibraryError::missing_bound_check)?;
497
498 let joiner_secret = JoinerSecret::new(
499 crypto,
500 ciphersuite,
501 path_computation_result.commit_secret,
502 builder.group.group_epoch_secrets().init_secret(),
503 &serialized_provisional_group_context,
504 )
505 .map_err(LibraryError::unexpected_crypto_error)?;
506
507 let psk_secret = { PskSecret::new(crypto, ciphersuite, psks)? };
509
510 let mut key_schedule = KeySchedule::init(ciphersuite, crypto, &joiner_secret, psk_secret)?;
512
513 let serialized_provisional_group_context = diff
514 .group_context()
515 .tls_serialize_detached()
516 .map_err(LibraryError::missing_bound_check)?;
517
518 let welcome_secret = key_schedule
519 .welcome(crypto, builder.group.ciphersuite())
520 .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?;
521 key_schedule
522 .add_context(crypto, &serialized_provisional_group_context)
523 .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?;
524 let provisional_epoch_secrets = key_schedule
525 .epoch_secrets(crypto, builder.group.ciphersuite())
526 .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?;
527
528 let confirmation_tag = provisional_epoch_secrets
530 .confirmation_key()
531 .tag(
532 crypto,
533 builder.group.ciphersuite(),
534 diff.group_context().confirmed_transcript_hash(),
535 )
536 .map_err(LibraryError::unexpected_crypto_error)?;
537
538 authenticated_content.set_confirmation_tag(confirmation_tag.clone());
540
541 diff.update_interim_transcript_hash(ciphersuite, crypto, confirmation_tag.clone())?;
542
543 let needs_welcome = !apply_proposals_values.invitation_list.is_empty();
545
546 let needs_group_info =
549 needs_welcome || builder.group.configuration().use_ratchet_tree_extension;
550
551 let group_info = if !needs_group_info {
552 None
553 } else {
554 let external_pub = provisional_epoch_secrets
556 .external_secret()
557 .derive_external_keypair(crypto, ciphersuite)
558 .map_err(LibraryError::unexpected_crypto_error)?
559 .public;
560 let external_pub_extension =
561 Extension::ExternalPub(ExternalPubExtension::new(external_pub.into()));
562
563 let extensions: Extensions = if builder.group.configuration().use_ratchet_tree_extension
565 {
566 Extensions::from_vec(vec![
567 Extension::RatchetTree(RatchetTreeExtension::new(diff.export_ratchet_tree())),
568 external_pub_extension,
569 ])?
570 } else {
571 Extensions::single(external_pub_extension)
572 };
573
574 let group_info_tbs = {
576 GroupInfoTBS::new(
577 diff.group_context().clone(),
578 extensions,
579 confirmation_tag,
580 builder.group.own_leaf_index(),
581 )
582 };
583 Some(group_info_tbs.sign(old_signer)?)
585 };
586
587 let welcome_option = if !needs_welcome {
588 None
589 } else {
590 let (welcome_key, welcome_nonce) = welcome_secret
592 .derive_welcome_key_nonce(crypto, builder.group.ciphersuite())
593 .map_err(LibraryError::unexpected_crypto_error)?;
594 let encrypted_group_info = welcome_key
595 .aead_seal(
596 crypto,
597 group_info
598 .as_ref()
599 .ok_or_else(|| LibraryError::custom("GroupInfo was not computed"))?
600 .tls_serialize_detached()
601 .map_err(LibraryError::missing_bound_check)?
602 .as_slice(),
603 &[],
604 &welcome_nonce,
605 )
606 .map_err(LibraryError::unexpected_crypto_error)?;
607
608 let encrypted_secrets = diff.encrypt_group_secrets(
611 &joiner_secret,
612 apply_proposals_values.invitation_list,
613 path_computation_result.plain_path.as_deref(),
614 &apply_proposals_values.presharedkeys,
615 &encrypted_group_info,
616 crypto,
617 builder.group.own_leaf_index(),
618 )?;
619
620 let welcome = Welcome::new(ciphersuite, encrypted_secrets, encrypted_group_info);
622 Some(welcome)
623 };
624
625 let (provisional_group_epoch_secrets, provisional_message_secrets) =
626 provisional_epoch_secrets.split_secrets(
627 serialized_provisional_group_context,
628 diff.tree_size(),
629 builder.group.own_leaf_index(),
630 );
631
632 let staged_commit_state = MemberStagedCommitState::new(
633 provisional_group_epoch_secrets,
634 provisional_message_secrets,
635 diff.into_staged_diff(crypto, ciphersuite)?,
636 path_computation_result.new_keypairs,
637 None,
640 update_path_leaf_node,
641 );
642 let staged_commit = StagedCommit::new(
643 proposal_queue,
644 StagedCommitState::GroupMember(Box::new(staged_commit_state)),
645 );
646
647 let use_ratchet_tree_extension = builder.group.configuration().use_ratchet_tree_extension;
648
649 Ok(builder.into_stage(Complete {
650 result: CreateCommitResult {
651 commit: authenticated_content,
652 welcome_option,
653 staged_commit,
654 group_info: group_info.filter(|_| use_ratchet_tree_extension),
655 },
656 }))
657 }
658}
659
660impl CommitBuilder<'_, Complete> {
661 #[cfg(test)]
662 pub(crate) fn commit_result(self) -> CreateCommitResult {
663 self.stage.result
664 }
665
666 pub fn stage_commit<Provider: OpenMlsProvider>(
668 self,
669 provider: &Provider,
670 ) -> Result<CommitMessageBundle, CommitBuilderStageError<Provider::StorageError>> {
671 let Self {
672 group,
673 stage: Complete {
674 result: create_commit_result,
675 },
676 ..
677 } = self;
678
679 group.group_state = MlsGroupState::PendingCommit(Box::new(PendingCommitState::Member(
682 create_commit_result.staged_commit,
683 )));
684
685 provider
686 .storage()
687 .write_group_state(group.group_id(), &group.group_state)
688 .map_err(CommitBuilderStageError::KeyStoreError)?;
689
690 group.reset_aad();
691
692 let mls_message = group.content_to_mls_message(create_commit_result.commit, provider)?;
698
699 Ok(CommitMessageBundle {
700 version: group.version(),
701 commit: mls_message,
702 welcome: create_commit_result.welcome_option,
703 group_info: create_commit_result.group_info,
704 })
705 }
706}
707
/// The messages produced by staging a commit: the commit itself, plus an optional welcome and
/// an optional group info.
#[derive(Debug, Clone)]
pub struct CommitMessageBundle {
    /// Protocol version of the group; used when wrapping the welcome into an [`MlsMessageOut`].
    version: ProtocolVersion,
    /// The commit message.
    commit: MlsMessageOut,
    /// The welcome, if one was created.
    welcome: Option<Welcome>,
    /// The group info, if one was handed out.
    group_info: Option<GroupInfo>,
}
717
#[cfg(test)]
impl CommitMessageBundle {
    /// Assembles a bundle directly from its parts (test-only).
    pub fn new(
        version: ProtocolVersion,
        commit: MlsMessageOut,
        welcome: Option<Welcome>,
        group_info: Option<GroupInfo>,
    ) -> Self {
        CommitMessageBundle {
            commit,
            welcome,
            group_info,
            version,
        }
    }
}
734
735impl CommitMessageBundle {
736 pub fn commit(&self) -> &MlsMessageOut {
740 &self.commit
741 }
742
743 pub fn welcome(&self) -> Option<&Welcome> {
746 self.welcome.as_ref()
747 }
748
749 pub fn to_welcome_msg(&self) -> Option<MlsMessageOut> {
752 self.welcome
753 .as_ref()
754 .map(|welcome| MlsMessageOut::from_welcome(welcome.clone(), self.version))
755 }
756
757 pub fn group_info(&self) -> Option<&GroupInfo> {
761 self.group_info.as_ref()
762 }
763
764 pub fn contents(&self) -> (&MlsMessageOut, Option<&Welcome>, Option<&GroupInfo>) {
767 (
768 &self.commit,
769 self.welcome.as_ref(),
770 self.group_info.as_ref(),
771 )
772 }
773
774 pub fn into_commit(self) -> MlsMessageOut {
778 self.commit
779 }
780
781 pub fn into_welcome(self) -> Option<Welcome> {
785 self.welcome
786 }
787
788 pub fn into_welcome_msg(self) -> Option<MlsMessageOut> {
791 self.welcome
792 .map(|welcome| MlsMessageOut::from_welcome(welcome, self.version))
793 }
794
795 pub fn into_group_info(self) -> Option<GroupInfo> {
800 self.group_info
801 }
802
803 pub fn into_group_info_msg(self) -> Option<MlsMessageOut> {
805 self.group_info.map(|group_info| group_info.into())
806 }
807
808 pub fn into_contents(self) -> (MlsMessageOut, Option<Welcome>, Option<GroupInfo>) {
811 (self.commit, self.welcome, self.group_info)
812 }
813
814 pub fn into_messages(self) -> (MlsMessageOut, Option<MlsMessageOut>, Option<MlsMessageOut>) {
817 (
818 self.commit,
819 self.welcome
820 .map(|welcome| MlsMessageOut::from_welcome(welcome, self.version)),
821 self.group_info.map(|group_info| group_info.into()),
822 )
823 }
824}
825
826impl IntoIterator for CommitMessageBundle {
827 type Item = MlsMessageOut;
828
829 type IntoIter = core::iter::Chain<
830 core::iter::Chain<
831 core::option::IntoIter<MlsMessageOut>,
832 core::option::IntoIter<MlsMessageOut>,
833 >,
834 core::option::IntoIter<MlsMessageOut>,
835 >;
836
837 fn into_iter(self) -> Self::IntoIter {
838 let welcome = self.to_welcome_msg();
839 let group_info = self.group_info.map(|group_info| group_info.into());
840
841 Some(self.commit)
842 .into_iter()
843 .chain(welcome)
844 .chain(group_info)
845 }
846}