openmls/group/fork_resolution/
readd.rs1use crate::binary_tree::LeafNodeIndex;
7
8use crate::{
9 group::{
10 commit_builder::{CommitBuilder, Initial},
11 Member, MlsGroup,
12 },
13 prelude::KeyPackage,
14};
15
16pub struct ReAddExpectKeyPackages {
18 complement_partition: Vec<Member>,
19}
20
21impl MlsGroup {
22 pub fn recover_fork_by_readding(
27 &'_ mut self,
28 own_partition: &[LeafNodeIndex],
29 ) -> Result<CommitBuilder<'_, ReAddExpectKeyPackages>, ReAddError> {
30 let own_partition: Option<Vec<_>> = own_partition
32 .iter()
33 .cloned()
34 .map(|leaf_index| self.member_at(leaf_index))
35 .collect();
36
37 let own_partition = own_partition.ok_or(ReAddError::InvalidLeafNodeIndex)?;
39
40 let complement_partition = complement(&own_partition, self.members()).collect();
42
43 let stage = ReAddExpectKeyPackages {
44 complement_partition,
45 };
46
47 Ok(self.commit_builder().into_stage(stage))
48 }
49}
50
51impl<'a> CommitBuilder<'a, ReAddExpectKeyPackages> {
52 pub fn complement_partition(&self) -> &[Member] {
54 self.stage().complement_partition.as_slice()
55 }
56
57 pub fn provide_key_packages(
60 self,
61 new_key_packages: Vec<KeyPackage>,
62 ) -> CommitBuilder<'a, Initial> {
63 let (stage, builder) = self.replace_stage(Initial::default());
64
65 builder
66 .propose_removals(stage.complement_partition.iter().map(|member| member.index))
67 .propose_adds(new_key_packages)
68 }
69}
70
71#[derive(Debug, thiserror::Error)]
72pub enum ReAddError {
74 #[error("An invalid leaf node index was provided.")]
76 InvalidLeafNodeIndex,
77}
78
79fn complement<'a, MembersIter>(
83 partition: &'a [Member],
84 members: MembersIter,
85) -> impl Iterator<Item = Member> + 'a
86where
87 MembersIter: IntoIterator<Item = Member> + 'a,
88{
89 members.into_iter().filter(|member| {
90 partition
91 .iter()
92 .all(|own_member| member.index != own_member.index)
93 })
94}
95
#[cfg(test)]
mod test {
    use crate::{
        framing::MlsMessageIn,
        group::{
            mls_group::tests_and_kats::utils::{setup_alice_bob_group, setup_client},
            tests_and_kats::utils::{generate_key_package, CredentialWithKeyAndSigner},
            Extensions, StagedWelcome,
        },
    };

    /// Walks through fork recovery by re-adding: Alice and Bob each commit a
    /// different add on top of the same group state, then Alice re-adds Bob
    /// and Charlie, and all three must end up with the same exported secret.
    #[openmls_test::openmls_test]
    fn example_readd() {
        let alice_provider = &Provider::default();
        let bob_provider = &Provider::default();
        let charlie_provider = &Provider::default();
        let dave_provider = &Provider::default();

        let (mut alice_group, alice_signer, mut bob_group, bob_signer, _alice_cwk, bob_cwk) =
            setup_alice_bob_group(ciphersuite, alice_provider, bob_provider);

        let (charlie_cwk, charlie_kpb, charlie_signer, _charlie_sig_pk) =
            setup_client("Charlie", ciphersuite, charlie_provider);

        let (_dave_cwk, dave_kpb, _dave_signer, _dave_sig_pk) =
            setup_client("Dave", ciphersuite, dave_provider);

        let bob_cwkas = CredentialWithKeyAndSigner {
            credential_with_key: bob_cwk.clone(),
            signer: bob_signer.clone(),
        };

        let charlie_cwkas = CredentialWithKeyAndSigner {
            credential_with_key: charlie_cwk.clone(),
            signer: charlie_signer.clone(),
        };

        // Alice commits the addition of Charlie ...
        alice_group
            .commit_builder()
            .propose_adds(Some(charlie_kpb.key_package().clone()))
            .load_psks(alice_provider.storage())
            .unwrap()
            .build(
                alice_provider.rand(),
                alice_provider.crypto(),
                &alice_signer,
                |_| true,
            )
            .unwrap()
            .stage_commit(alice_provider)
            .unwrap();

        // ... while Bob, starting from the same state, commits the addition
        // of Dave.
        bob_group
            .commit_builder()
            .propose_adds(Some(dave_kpb.key_package().clone()))
            .load_psks(bob_provider.storage())
            .unwrap()
            .build(
                bob_provider.rand(),
                bob_provider.crypto(),
                &bob_signer,
                |_| true,
            )
            .unwrap()
            .stage_commit(bob_provider)
            .unwrap();

        // Each side merges only its own commit, so the two groups have now
        // diverged.
        alice_group.merge_pending_commit(alice_provider).unwrap();
        bob_group.merge_pending_commit(bob_provider).unwrap();

        // Key packages used to re-add Bob and Charlie.
        let bob_new_kpb =
            generate_key_package(ciphersuite, Extensions::empty(), bob_provider, bob_cwkas);

        let charlie_new_kpb = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            charlie_provider,
            charlie_cwkas,
        );

        // Alice's partition consists of herself only, so in her view the
        // complement is Bob and Charlie.
        let builder = alice_group
            .recover_fork_by_readding(&[alice_group.own_leaf_index()])
            .unwrap();
        let key_packages = builder
            .complement_partition()
            .iter()
            .map(|member| match member.credential.serialized_content() {
                b"Bob" => bob_new_kpb.key_package().clone(),
                b"Charlie" => charlie_new_kpb.key_package().clone(),
                _ => unreachable!(),
            })
            .collect();

        let message_bundle = builder
            .provide_key_packages(key_packages)
            .load_psks(alice_provider.storage())
            .unwrap()
            .build(
                alice_provider.rand(),
                alice_provider.crypto(),
                &alice_signer,
                |_| true,
            )
            .unwrap()
            .stage_commit(alice_provider)
            .unwrap();

        let (_commit, welcome, _group_info) = message_bundle.into_messages();
        alice_group.merge_pending_commit(alice_provider).unwrap();

        let welcome = welcome.unwrap();
        let welcome: MlsMessageIn = welcome.into();
        let welcome = welcome.into_welcome().unwrap();
        let ratchet_tree = alice_group.export_ratchet_tree();

        let new_bob_group = StagedWelcome::new_from_welcome(
            bob_provider,
            alice_group.configuration(),
            welcome.clone(),
            Some(ratchet_tree.clone().into()),
        )
        .unwrap()
        .into_group(bob_provider)
        .unwrap();

        // Charlie joins through his own provider for both steps; using
        // Bob's provider here would put Charlie's group state into Bob's
        // storage.
        let new_charlie_group = StagedWelcome::new_from_welcome(
            charlie_provider,
            alice_group.configuration(),
            welcome.clone(),
            Some(ratchet_tree.clone().into()),
        )
        .unwrap()
        .into_group(charlie_provider)
        .unwrap();

        // All three members must now derive the same exported secret.
        let alice_comparison = alice_group
            .export_secret(alice_provider.crypto(), "comparison", b"", 32)
            .unwrap();

        let bob_comparison = new_bob_group
            .export_secret(bob_provider.crypto(), "comparison", b"", 32)
            .unwrap();

        let charlie_comparison = new_charlie_group
            .export_secret(charlie_provider.crypto(), "comparison", b"", 32)
            .unwrap();

        assert_eq!(alice_comparison, bob_comparison);
        assert_eq!(alice_comparison, charlie_comparison);
    }
}