Skip to main content

openmls/group/fork_resolution/
reboot.rs

//! This module contains helpers (see [`MlsGroup::reboot`]) to set up a new group and add all
//! members of the current group to it. The application needs to determine who should set that
//! new group up and how to migrate the group context extensions from the old group. This is the
//! more expensive fork-resolution mechanism, because a completely new group is created and every
//! member has to join it using a fresh key package.
4
5use openmls_traits::signatures::Signer;
6
7use crate::{
8    credentials::CredentialWithKey,
9    extensions::errors::InvalidExtensionError,
10    group::{
11        commit_builder::{CommitBuilder, CommitMessageBundle, Initial},
12        mls_group::builder::MlsGroupBuilder,
13        CommitBuilderStageError, CreateCommitError, Extensions, GroupId, Member, MlsGroup,
14        NewGroupError,
15    },
16    prelude::GroupContext,
17    prelude::KeyPackage,
18    storage::OpenMlsProvider,
19};
20
21impl MlsGroup {
22    /// The first step towards creating a new group based on the parameters and membership list of
23    /// the current one.
24    pub fn reboot(&'_ self, group_id: GroupId) -> RebootBuilder<'_> {
25        let group_builder = MlsGroup::builder()
26            .with_wire_format_policy(self.configuration().wire_format_policy)
27            .padding_size(self.configuration().padding_size)
28            .max_past_epochs(self.configuration().max_past_epochs)
29            .number_of_resumption_psks(self.configuration().number_of_resumption_psks)
30            .use_ratchet_tree_extension(self.configuration().use_ratchet_tree_extension)
31            .sender_ratchet_configuration(self.configuration().sender_ratchet_configuration)
32            .ciphersuite(self.ciphersuite())
33            .with_group_id(group_id);
34
35        RebootBuilder {
36            group: self,
37            group_builder,
38        }
39    }
40}
41
42/// The builder type for a group reboot.
43pub struct RebootBuilder<'a> {
44    group: &'a MlsGroup,
45    group_builder: MlsGroupBuilder,
46}
47
48impl<'a> RebootBuilder<'a> {
49    /// Returns the group context extensions of the old group, so they can be updated and passed
50    /// into the new group.
51    pub fn old_group_context_extensions(&self) -> &Extensions<GroupContext> {
52        self.group.context().extensions()
53    }
54
55    /// The members of the old group, so new key packages for other members can be retrieved.
56    pub fn old_members(&self) -> impl Iterator<Item = Member> + 'a {
57        self.group
58            .members()
59            .filter(|member| member.index != self.group.own_leaf_index())
60    }
61
62    /// Lets the caller make changes to the [`MlsGroupBuilder`] before the group is created.
63    pub fn refine_group_builder(
64        self,
65        mut f: impl FnMut(MlsGroupBuilder) -> MlsGroupBuilder,
66    ) -> Self {
67        Self {
68            group_builder: f(self.group_builder),
69            ..self
70        }
71    }
72
73    /// Creates the group and commit using the provided `extensions` and `new_members`. The caller
74    /// can also make further changes to the [`CommitBuilder`] using the `refine_commit_builder`
75    /// argument. If that is not desired, provide the identity function (`|b| b`).
76    pub fn finish<Provider: OpenMlsProvider>(
77        self,
78        extensions: Extensions<GroupContext>,
79        new_members: Vec<KeyPackage>,
80        refine_commit_builder: impl FnMut(CommitBuilder<Initial>) -> CommitBuilder<Initial>,
81        provider: &Provider,
82        signer: &impl Signer,
83        credential_with_key: CredentialWithKey,
84    ) -> Result<(MlsGroup, CommitMessageBundle), RebootError<Provider::StorageError>> {
85        let group_builder = self.group_builder.with_group_context_extensions(extensions);
86
87        let mut new_group = group_builder.build(provider, signer, credential_with_key)?;
88
89        new_group
90            .commit_builder()
91            .propose_adds(new_members)
92            .pipe_through(refine_commit_builder)
93            .load_psks(provider.storage())?
94            .build(provider.rand(), provider.crypto(), signer, |_| true)?
95            .stage_commit(provider)
96            .map_err(RebootError::CommitBuilderStage)
97            .map(|message_bundle| (new_group, message_bundle))
98    }
99}
100
/// Indicates an error occurred during reboot, i.e. while running [`RebootBuilder::finish`].
///
/// `StorageError` is the storage error type of the provider passed to `finish`
/// (`Provider::StorageError`).
#[derive(Debug, thiserror::Error)]
pub enum RebootError<StorageError> {
    /// An invalid extension was provided.
    #[error(transparent)]
    InvalidExtension(#[from] InvalidExtensionError),
    /// An error occurred while building the new group.
    #[error(transparent)]
    NewGroup(#[from] NewGroupError<StorageError>),
    /// An error occurred while creating the commit that adds the new members.
    #[error(transparent)]
    CreateCommit(#[from] CreateCommitError),
    /// An error occurred while staging the commit in the new group.
    #[error(transparent)]
    CommitBuilderStage(#[from] CommitBuilderStageError<StorageError>),
}
117
/// Defines a method that consumes self, passes it into a closure and returns the output of the
/// closure. Comes in handy in long builder chains.
///
/// The closure is invoked exactly once, so any `FnOnce` is accepted. This includes closures that
/// move captured values out, as well as every `FnMut`/`Fn` closure.
trait PipeThrough: Sized {
    fn pipe_through<T, F: FnOnce(Self) -> T>(self, f: F) -> T {
        f(self)
    }
}

// Blanket implementation: every sized type can be piped through a closure.
impl<T> PipeThrough for T {}
127
#[cfg(test)]
mod test {
    use crate::{
        framing::MlsMessageIn,
        group::{
            mls_group::tests_and_kats::utils::{setup_alice_bob_group, setup_client},
            tests_and_kats::utils::{generate_key_package, CredentialWithKeyAndSigner},
            Extensions, GroupId, StagedWelcome,
        },
    };

    /// Forks a group by merging two concurrent commits, then recovers by rebooting it.
    /// Finally checks that all members of the new group agree on an exported secret.
    #[openmls_test::openmls_test]
    fn example_reboot() {
        let alice_provider = &Provider::default();
        let bob_provider = &Provider::default();
        let charlie_provider = &Provider::default();
        let dave_provider = &Provider::default();

        // Create group with alice and bob
        let (mut alice_group, alice_signer, mut bob_group, bob_signer, alice_cwk, bob_cwk) =
            setup_alice_bob_group(ciphersuite, alice_provider, bob_provider);

        let (charlie_cwk, charlie_kpb, charlie_signer, _charlie_sig_pk) =
            setup_client("Charlie", ciphersuite, charlie_provider);

        let (_dave_cwk, dave_kpb, _dave_signer, _dave_sig_pk) =
            setup_client("Dave", ciphersuite, dave_provider);

        let bob_cwkas = CredentialWithKeyAndSigner {
            credential_with_key: bob_cwk.clone(),
            signer: bob_signer.clone(),
        };

        let charlie_cwkas = CredentialWithKeyAndSigner {
            credential_with_key: charlie_cwk.clone(),
            signer: charlie_signer.clone(),
        };

        // Alice and Bob concurrently invite someone and merge for whatever reason
        alice_group
            .commit_builder()
            .propose_adds(Some(charlie_kpb.key_package().clone()))
            .load_psks(alice_provider.storage())
            .unwrap()
            .build(
                alice_provider.rand(),
                alice_provider.crypto(),
                &alice_signer,
                |_| true,
            )
            .unwrap()
            .stage_commit(alice_provider)
            .unwrap();

        bob_group
            .commit_builder()
            .propose_adds(Some(dave_kpb.key_package().clone()))
            .load_psks(bob_provider.storage())
            .unwrap()
            .build(
                bob_provider.rand(),
                bob_provider.crypto(),
                &bob_signer,
                |_| true,
            )
            .unwrap()
            .stage_commit(bob_provider)
            .unwrap();

        alice_group.merge_pending_commit(alice_provider).unwrap();
        bob_group.merge_pending_commit(bob_provider).unwrap();

        // We are forked now! Let's try to recover by rebooting. First, get new key packages.
        let bob_new_kpb =
            generate_key_package(ciphersuite, Extensions::empty(), bob_provider, bob_cwkas);

        let charlie_new_kpb = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            charlie_provider,
            charlie_cwkas,
        );

        // Now, reboot the group
        let (mut new_alice_group, message_bundle) = alice_group
            .reboot(GroupId::from_slice(b"new group id"))
            .finish(
                Extensions::empty(),
                vec![
                    bob_new_kpb.key_package().clone(),
                    charlie_new_kpb.key_package().clone(),
                ],
                |builder| builder,
                alice_provider,
                &alice_signer,
                alice_cwk.clone(),
            )
            .unwrap();

        let (_commit, welcome, _group_info) = message_bundle.into_messages();
        new_alice_group
            .merge_pending_commit(alice_provider)
            .unwrap();

        // Alice, Bob and Charlie are the members of the rebooted group.
        assert_eq!(new_alice_group.members().count(), 3);

        // Invite everyone
        let welcome = welcome.unwrap();
        let welcome: MlsMessageIn = welcome.into();
        let welcome = welcome.into_welcome().unwrap();
        let ratchet_tree = new_alice_group.export_ratchet_tree();

        let new_bob_group = StagedWelcome::new_from_welcome(
            bob_provider,
            alice_group.configuration(),
            welcome.clone(),
            Some(ratchet_tree.clone().into()),
        )
        .unwrap()
        .into_group(bob_provider)
        .unwrap();

        // Charlie must stage and finalize the welcome with their own provider, so that the
        // new group state is persisted in Charlie's storage rather than Bob's.
        let new_charlie_group = StagedWelcome::new_from_welcome(
            charlie_provider,
            alice_group.configuration(),
            welcome,
            Some(ratchet_tree.into()),
        )
        .unwrap()
        .into_group(charlie_provider)
        .unwrap();

        let alice_comparison = new_alice_group
            .export_secret(alice_provider.crypto(), "comparison", b"", 32)
            .unwrap();

        let bob_comparison = new_bob_group
            .export_secret(bob_provider.crypto(), "comparison", b"", 32)
            .unwrap();

        let charlie_comparison = new_charlie_group
            .export_secret(charlie_provider.crypto(), "comparison", b"", 32)
            .unwrap();

        assert_eq!(alice_comparison, bob_comparison);
        assert_eq!(alice_comparison, charlie_comparison);
    }
}