openmls/group/fork_resolution/
reboot.rs1use openmls_traits::signatures::Signer;
6
7use crate::{
8 credentials::CredentialWithKey,
9 extensions::errors::InvalidExtensionError,
10 group::{
11 commit_builder::{CommitBuilder, CommitMessageBundle, Initial},
12 mls_group::builder::MlsGroupBuilder,
13 CommitBuilderStageError, CreateCommitError, Extensions, GroupId, Member, MlsGroup,
14 NewGroupError,
15 },
16 prelude::GroupContext,
17 prelude::KeyPackage,
18 storage::OpenMlsProvider,
19};
20
21impl MlsGroup {
22 pub fn reboot(&'_ self, group_id: GroupId) -> RebootBuilder<'_> {
25 let group_builder = MlsGroup::builder()
26 .with_wire_format_policy(self.configuration().wire_format_policy)
27 .padding_size(self.configuration().padding_size)
28 .max_past_epochs(self.configuration().max_past_epochs)
29 .number_of_resumption_psks(self.configuration().number_of_resumption_psks)
30 .use_ratchet_tree_extension(self.configuration().use_ratchet_tree_extension)
31 .sender_ratchet_configuration(self.configuration().sender_ratchet_configuration)
32 .ciphersuite(self.ciphersuite())
33 .with_group_id(group_id);
34
35 RebootBuilder {
36 group: self,
37 group_builder,
38 }
39 }
40}
41
42pub struct RebootBuilder<'a> {
44 group: &'a MlsGroup,
45 group_builder: MlsGroupBuilder,
46}
47
48impl<'a> RebootBuilder<'a> {
49 pub fn old_group_context_extensions(&self) -> &Extensions<GroupContext> {
52 self.group.context().extensions()
53 }
54
55 pub fn old_members(&self) -> impl Iterator<Item = Member> + 'a {
57 self.group
58 .members()
59 .filter(|member| member.index != self.group.own_leaf_index())
60 }
61
62 pub fn refine_group_builder(
64 self,
65 mut f: impl FnMut(MlsGroupBuilder) -> MlsGroupBuilder,
66 ) -> Self {
67 Self {
68 group_builder: f(self.group_builder),
69 ..self
70 }
71 }
72
73 pub fn finish<Provider: OpenMlsProvider>(
77 self,
78 extensions: Extensions<GroupContext>,
79 new_members: Vec<KeyPackage>,
80 refine_commit_builder: impl FnMut(CommitBuilder<Initial>) -> CommitBuilder<Initial>,
81 provider: &Provider,
82 signer: &impl Signer,
83 credential_with_key: CredentialWithKey,
84 ) -> Result<(MlsGroup, CommitMessageBundle), RebootError<Provider::StorageError>> {
85 let group_builder = self.group_builder.with_group_context_extensions(extensions);
86
87 let mut new_group = group_builder.build(provider, signer, credential_with_key)?;
88
89 new_group
90 .commit_builder()
91 .propose_adds(new_members)
92 .pipe_through(refine_commit_builder)
93 .load_psks(provider.storage())?
94 .build(provider.rand(), provider.crypto(), signer, |_| true)?
95 .stage_commit(provider)
96 .map_err(RebootError::CommitBuilderStage)
97 .map(|message_bundle| (new_group, message_bundle))
98 }
99}
100
101#[derive(Debug, thiserror::Error)]
103pub enum RebootError<StorageError> {
104 #[error(transparent)]
106 InvalidExtension(#[from] InvalidExtensionError),
107 #[error(transparent)]
109 NewGroup(#[from] NewGroupError<StorageError>),
110 #[error(transparent)]
112 CreateCommit(#[from] CreateCommitError),
113 #[error(transparent)]
115 CommitBuilderStage(#[from] CommitBuilderStageError<StorageError>),
116}
117
/// Extension trait that lets any sized value be handed to a closure in the
/// middle of a method chain.
///
/// `value.pipe_through(f)` is the same as `f(value)`; it only exists so the
/// call can be written in postfix position.
trait PipeThrough: Sized {
    /// Consumes `self`, passes it to `f`, and returns whatever `f` returns.
    ///
    /// `f` is called exactly once, so `FnOnce` is the weakest sufficient
    /// bound; closures that are `FnMut` or `Fn` are accepted as well.
    fn pipe_through<T, F: FnOnce(Self) -> T>(self, f: F) -> T {
        f(self)
    }
}

/// Makes [`PipeThrough::pipe_through`] available on every sized type.
impl<T> PipeThrough for T {}
127
#[cfg(test)]
mod test {
    use crate::{
        framing::MlsMessageIn,
        group::{
            mls_group::tests_and_kats::utils::{setup_alice_bob_group, setup_client},
            tests_and_kats::utils::{generate_key_package, CredentialWithKeyAndSigner},
            Extensions, GroupId, StagedWelcome,
        },
    };

    /// Walks through a complete reboot.
    ///
    /// Alice and Bob start in the same group. Each of them then commits and
    /// merges a different add (Alice adds Charlie, Bob adds Dave) without
    /// processing the other's commit. Alice then reboots into a new group
    /// containing Bob and Charlie, and all three must derive the same
    /// exported secret from it.
    #[openmls_test::openmls_test]
    fn example_reboot() {
        let alice_provider = &Provider::default();
        let bob_provider = &Provider::default();
        let charlie_provider = &Provider::default();
        let dave_provider = &Provider::default();

        // Alice and Bob share a group to begin with.
        let (mut alice_group, alice_signer, mut bob_group, bob_signer, alice_cwk, bob_cwk) =
            setup_alice_bob_group(ciphersuite, alice_provider, bob_provider);

        let (charlie_cwk, charlie_kpb, charlie_signer, _charlie_sig_pk) =
            setup_client("Charlie", ciphersuite, charlie_provider);

        let (_dave_cwk, dave_kpb, _dave_signer, _dave_sig_pk) =
            setup_client("Dave", ciphersuite, dave_provider);

        let bob_cwkas = CredentialWithKeyAndSigner {
            credential_with_key: bob_cwk.clone(),
            signer: bob_signer.clone(),
        };

        let charlie_cwkas = CredentialWithKeyAndSigner {
            credential_with_key: charlie_cwk.clone(),
            signer: charlie_signer.clone(),
        };

        // Alice commits to adding Charlie ...
        alice_group
            .commit_builder()
            .propose_adds(Some(charlie_kpb.key_package().clone()))
            .load_psks(alice_provider.storage())
            .unwrap()
            .build(
                alice_provider.rand(),
                alice_provider.crypto(),
                &alice_signer,
                |_| true,
            )
            .unwrap()
            .stage_commit(alice_provider)
            .unwrap();

        // ... while Bob, starting from the same state, commits to adding Dave.
        bob_group
            .commit_builder()
            .propose_adds(Some(dave_kpb.key_package().clone()))
            .load_psks(bob_provider.storage())
            .unwrap()
            .build(
                bob_provider.rand(),
                bob_provider.crypto(),
                &bob_signer,
                |_| true,
            )
            .unwrap()
            .stage_commit(bob_provider)
            .unwrap();

        // Each side merges only its own commit.
        alice_group.merge_pending_commit(alice_provider).unwrap();
        bob_group.merge_pending_commit(bob_provider).unwrap();

        // Fresh key packages for the members Alice adds to the new group.
        let bob_new_kpb =
            generate_key_package(ciphersuite, Extensions::empty(), bob_provider, bob_cwkas);

        let charlie_new_kpb = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            charlie_provider,
            charlie_cwkas,
        );

        // Alice reboots into a new group containing Bob and Charlie.
        let (mut new_alice_group, message_bundle) = alice_group
            .reboot(GroupId::from_slice(b"new group id"))
            .finish(
                Extensions::empty(),
                vec![
                    bob_new_kpb.key_package().clone(),
                    charlie_new_kpb.key_package().clone(),
                ],
                |builder| builder,
                alice_provider,
                &alice_signer,
                alice_cwk.clone(),
            )
            .unwrap();

        // `finish` only stages the commit; Alice still has to merge it.
        let (_commit, welcome, _group_info) = message_bundle.into_messages();
        new_alice_group
            .merge_pending_commit(alice_provider)
            .unwrap();

        // Convert the outgoing welcome into the form a receiver processes.
        let welcome = welcome.unwrap();
        let welcome: MlsMessageIn = welcome.into();
        let welcome = welcome.into_welcome().unwrap();
        let ratchet_tree = new_alice_group.export_ratchet_tree();

        let new_bob_group = StagedWelcome::new_from_welcome(
            bob_provider,
            alice_group.configuration(),
            welcome.clone(),
            Some(ratchet_tree.clone().into()),
        )
        .unwrap()
        .into_group(bob_provider)
        .unwrap();

        // Charlie joins through his own provider for both steps, so his group
        // state is not written into Bob's storage.
        let new_group_charlie = StagedWelcome::new_from_welcome(
            charlie_provider,
            alice_group.configuration(),
            welcome.clone(),
            Some(ratchet_tree.clone().into()),
        )
        .unwrap()
        .into_group(charlie_provider)
        .unwrap();

        // All three members must export the same secret from the new group.
        let alice_comparison = new_alice_group
            .export_secret(alice_provider.crypto(), "comparison", b"", 32)
            .unwrap();

        let bob_comparison = new_bob_group
            .export_secret(bob_provider.crypto(), "comparison", b"", 32)
            .unwrap();

        let charlie_comparison = new_group_charlie
            .export_secret(charlie_provider.crypto(), "comparison", b"", 32)
            .unwrap();

        assert_eq!(alice_comparison, bob_comparison);
        assert_eq!(alice_comparison, charlie_comparison);
    }
}