openmls/group/fork_resolution/
readd.rs

1//! This module contains helpers for removing and re-adding members that merged a wrong
2//! commit. It is the responsibility of the application to determine which commit is the right one,
3//! as well as which members need to be re-added. This is a relatively cheap mechanism, but it
4//! requires knowing about the partitions.
5
6use crate::binary_tree::LeafNodeIndex;
7
8use crate::{
9    group::{
10        commit_builder::{CommitBuilder, Initial},
11        Member, MlsGroup,
12    },
13    prelude::KeyPackage,
14};
15
16/// A stage for the [`CommitBuilder`] for removing and re-adding members from other partitions.
17pub struct ReAddExpectKeyPackages {
18    complement_partition: Vec<Member>,
19}
20
21impl MlsGroup {
22    /// Create a [`CommitBuilder`] that is preparing to remove and re-add members from other fork
23    /// partitions. `own_partition` is the list of [`LeafNodeIndex`] that are members in the
24    /// partition that the initiating client is in. This should include the [`LeafNodeIndex`] of
25    /// the initiating client.
26    pub fn recover_fork_by_readding(
27        &'_ mut self,
28        own_partition: &[LeafNodeIndex],
29    ) -> Result<CommitBuilder<'_, ReAddExpectKeyPackages>, ReAddError> {
30        // Load member info. This is None if at least one of the indexes is not a valid member
31        let own_partition: Option<Vec<_>> = own_partition
32            .iter()
33            .cloned()
34            .map(|leaf_index| self.member_at(leaf_index))
35            .collect();
36
37        // Fail if there is a leaf node index that is not a member
38        let own_partition = own_partition.ok_or(ReAddError::InvalidLeafNodeIndex)?;
39
40        // Compute the complement partition, i.e. the list of members that are not in our partition
41        let complement_partition = complement(&own_partition, self.members()).collect();
42
43        let stage = ReAddExpectKeyPackages {
44            complement_partition,
45        };
46
47        Ok(self.commit_builder().into_stage(stage))
48    }
49}
50
51impl<'a> CommitBuilder<'a, ReAddExpectKeyPackages> {
52    /// Returns the complement partition, i.e. the list of members that are not in our partition.
53    pub fn complement_partition(&self) -> &[Member] {
54        self.stage().complement_partition.as_slice()
55    }
56
57    /// Takes the key packages needed to re-add the other members and returns the prepared
58    /// [`CommitBuilder`].
59    pub fn provide_key_packages(
60        self,
61        new_key_packages: Vec<KeyPackage>,
62    ) -> CommitBuilder<'a, Initial> {
63        let (stage, builder) = self.replace_stage(Initial::default());
64
65        builder
66            .propose_removals(stage.complement_partition.iter().map(|member| member.index))
67            .propose_adds(new_key_packages)
68    }
69}
70
71#[derive(Debug, thiserror::Error)]
72/// Indicates an error occurred during re-adding
73pub enum ReAddError {
74    /// An invalid leaf node index was provided
75    #[error("An invalid leaf node index was provided.")]
76    InvalidLeafNodeIndex,
77}
78
79/// Computes the complement partition of the provided list of members.
80// NOTE: If we require that the list of LeafNodeIndex is ordered, we can make this O(n) instead
81// of O(n^2).
82fn complement<'a, MembersIter>(
83    partition: &'a [Member],
84    members: MembersIter,
85) -> impl Iterator<Item = Member> + 'a
86where
87    MembersIter: IntoIterator<Item = Member> + 'a,
88{
89    members.into_iter().filter(|member| {
90        partition
91            .iter()
92            .all(|own_member| member.index != own_member.index)
93    })
94}
95
#[cfg(test)]
mod test {
    use crate::{
        framing::MlsMessageIn,
        group::{
            mls_group::tests_and_kats::utils::{setup_alice_bob_group, setup_client},
            tests_and_kats::utils::{generate_key_package, CredentialWithKeyAndSigner},
            Extensions, StagedWelcome,
        },
    };

    /// End-to-end example of fork recovery by re-adding.
    ///
    /// Alice and Bob each commit an add concurrently and both merge their own commit, which
    /// forks the group. Alice then treats herself as the only member of her partition,
    /// removes and re-adds everyone else in her view of the group (Bob and Charlie), and the
    /// test checks that all three end up exporting the same secret.
    #[openmls_test::openmls_test]
    fn example_readd() {
        let alice_provider = &Provider::default();
        let bob_provider = &Provider::default();
        let charlie_provider = &Provider::default();
        let dave_provider = &Provider::default();

        // Create group with alice and bob
        let (mut alice_group, alice_signer, mut bob_group, bob_signer, _alice_cwk, bob_cwk) =
            setup_alice_bob_group(ciphersuite, alice_provider, bob_provider);

        let (charlie_cwk, charlie_kpb, charlie_signer, _charlie_sig_pk) =
            setup_client("Charlie", ciphersuite, charlie_provider);

        let (_dave_cwk, dave_kpb, _dave_signer, _dave_sig_pk) =
            setup_client("Dave", ciphersuite, dave_provider);

        let bob_cwkas = CredentialWithKeyAndSigner {
            credential_with_key: bob_cwk.clone(),
            signer: bob_signer.clone(),
        };

        let charlie_cwkas = CredentialWithKeyAndSigner {
            credential_with_key: charlie_cwk.clone(),
            signer: charlie_signer.clone(),
        };

        // Alice and Bob concurrently invite someone and merge for whatever reason
        alice_group
            .commit_builder()
            .propose_adds(Some(charlie_kpb.key_package().clone()))
            .load_psks(alice_provider.storage())
            .unwrap()
            .build(
                alice_provider.rand(),
                alice_provider.crypto(),
                &alice_signer,
                |_| true,
            )
            .unwrap()
            .stage_commit(alice_provider)
            .unwrap();

        bob_group
            .commit_builder()
            .propose_adds(Some(dave_kpb.key_package().clone()))
            .load_psks(bob_provider.storage())
            .unwrap()
            .build(
                bob_provider.rand(),
                bob_provider.crypto(),
                &bob_signer,
                |_| true,
            )
            .unwrap()
            .stage_commit(bob_provider)
            .unwrap();

        alice_group.merge_pending_commit(alice_provider).unwrap();
        bob_group.merge_pending_commit(bob_provider).unwrap();

        // We are forked now! Let's try to recover by re-adding. First get new key packages
        // for everyone outside of Alice's partition.
        let bob_new_kpb =
            generate_key_package(ciphersuite, Extensions::empty(), bob_provider, bob_cwkas);

        let charlie_new_kpb = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            charlie_provider,
            charlie_cwkas,
        );

        // Now, remove and re-add Bob and Charlie. Alice's partition consists of only herself.
        let builder = alice_group
            .recover_fork_by_readding(&[alice_group.own_leaf_index()])
            .unwrap();

        // Pick the fresh key package that belongs to each member of the complement partition.
        let key_packages = builder
            .complement_partition()
            .iter()
            .map(|member| match member.credential.serialized_content() {
                b"Bob" => bob_new_kpb.key_package().clone(),
                b"Charlie" => charlie_new_kpb.key_package().clone(),
                _ => unreachable!(),
            })
            .collect();

        let message_bundle = builder
            .provide_key_packages(key_packages)
            .load_psks(alice_provider.storage())
            .unwrap()
            .build(
                alice_provider.rand(),
                alice_provider.crypto(),
                &alice_signer,
                |_| true,
            )
            .unwrap()
            .stage_commit(alice_provider)
            .unwrap();

        let (_commit, welcome, _group_info) = message_bundle.into_messages();
        alice_group.merge_pending_commit(alice_provider).unwrap();

        // Invite everyone
        let welcome = welcome.unwrap();
        let welcome: MlsMessageIn = welcome.into();
        let welcome = welcome.into_welcome().unwrap();
        let ratchet_tree = alice_group.export_ratchet_tree();

        let new_bob_group = StagedWelcome::new_from_welcome(
            bob_provider,
            alice_group.configuration(),
            welcome.clone(),
            Some(ratchet_tree.clone().into()),
        )
        .unwrap()
        .into_group(bob_provider)
        .unwrap();

        // Charlie's group must be staged and finalized with Charlie's own provider.
        let new_group_charlie = StagedWelcome::new_from_welcome(
            charlie_provider,
            alice_group.configuration(),
            welcome.clone(),
            Some(ratchet_tree.clone().into()),
        )
        .unwrap()
        .into_group(charlie_provider)
        .unwrap();

        // All three members must now agree on the group state.
        let alice_comparison = alice_group
            .export_secret(alice_provider.crypto(), "comparison", b"", 32)
            .unwrap();

        let bob_comparison = new_bob_group
            .export_secret(bob_provider.crypto(), "comparison", b"", 32)
            .unwrap();

        let charlie_comparison = new_group_charlie
            .export_secret(charlie_provider.crypto(), "comparison", b"", 32)
            .unwrap();

        assert_eq!(alice_comparison, bob_comparison);
        assert_eq!(alice_comparison, charlie_comparison);
    }
}