openmls/test_utils/single_group_test_framework/
mod.rs

1use openmls_basic_credential::SignatureKeyPair;
2use openmls_traits::{signatures::Signer, types::SignatureScheme};
3pub use openmls_traits::{
4    storage::StorageProvider as StorageProviderTrait,
5    types::{Ciphersuite, HpkeKeyPair},
6    OpenMlsProvider,
7};
8
9pub use crate::utils::*;
10use crate::{
11    credentials::CredentialWithKey,
12    key_packages::KeyPackageBuilder,
13    prelude::{commit_builder::*, *},
14};
15
16use crate::test_utils::storage_state::GroupStorageState;
17
18mod assertions;
19
20mod errors;
21pub use errors::GroupError;
22use errors::*;
23
24use std::collections::HashMap;
25
/// Human-readable name identifying a party; used as the key of `GroupState`'s member map.
type Name = &'static str;
28
29// TODO: only define this once
30/// Helper function for generating a credential.
31pub fn generate_credential(
32    identity: Vec<u8>,
33    signature_algorithm: SignatureScheme,
34    provider: &impl crate::storage::OpenMlsProvider,
35) -> (CredentialWithKey, SignatureKeyPair) {
36    let credential = BasicCredential::new(identity);
37    let signature_keys = SignatureKeyPair::new(signature_algorithm).unwrap();
38    signature_keys.store(provider.storage()).unwrap();
39
40    (
41        CredentialWithKey {
42            credential: credential.into(),
43            signature_key: signature_keys.to_public_vec().into(),
44        },
45        signature_keys,
46    )
47}
48
49// TODO: only define this once
50pub(crate) fn generate_key_package(
51    ciphersuite: Ciphersuite,
52    credential_with_key: CredentialWithKey,
53    extensions: Extensions,
54    provider: &impl crate::storage::OpenMlsProvider,
55    lifetime: impl Into<Option<Lifetime>>,
56    signer: &impl Signer,
57) -> KeyPackageBundle {
58    let mut builder = KeyPackage::builder().key_package_extensions(extensions);
59
60    if let Some(lifetime) = lifetime.into() {
61        builder = builder.key_package_lifetime(lifetime);
62    }
63
64    builder
65        .build(ciphersuite, provider, signer, credential_with_key)
66        .unwrap()
67}
68
/// Struct representing a party's global state, i.e. the state that is shared by all of the
/// party's pre-group and member states.
pub struct CorePartyState<Provider> {
    /// The party's name; also used as the identity of the party's credential.
    pub name: Name,
    /// The party's provider, which supplies crypto, randomness and storage.
    pub provider: Provider,
}
74
75impl<Provider: Default> CorePartyState<Provider> {
76    pub fn new(name: Name) -> Self {
77        Self {
78            name,
79            provider: Provider::default(),
80        }
81    }
82}
83
/// Struct representing a party's state before joining a group
pub struct PreGroupPartyState<'a, Provider> {
    /// The party's credential together with its signature public key.
    pub credential_with_key: CredentialWithKey,
    /// The party's key package bundle; its key package is what an existing member uses to add
    /// this party to a group.
    // TODO: regenerate?
    pub key_package_bundle: KeyPackageBundle,
    /// The signature key pair whose public key is in `credential_with_key`.
    pub signer: SignatureKeyPair,
    /// The party's core state (name and provider) that this pre-group state was generated from.
    pub core_state: &'a CorePartyState<Provider>,
}
92
93// XXX: This should probably get a builder at some point.
94impl<Provider: OpenMlsProvider> CorePartyState<Provider> {
95    /// Generates the pre-group state for a `CorePartyState`
96    pub fn generate_pre_group(&self, ciphersuite: Ciphersuite) -> PreGroupPartyState<'_, Provider> {
97        self.generate_pre_group_lifetime(ciphersuite, None)
98    }
99
100    /// Generates the pre-group state for a `CorePartyState`
101    pub fn generate_pre_group_lifetime(
102        &self,
103        ciphersuite: Ciphersuite,
104        lifetime: impl Into<Option<Lifetime>>,
105    ) -> PreGroupPartyState<'_, Provider> {
106        let (credential_with_key, signer) = generate_credential(
107            self.name.into(),
108            ciphersuite.signature_algorithm(),
109            &self.provider,
110        );
111
112        let key_package_bundle = generate_key_package(
113            ciphersuite,
114            credential_with_key.clone(),
115            Extensions::default(), // TODO: provide as argument?
116            &self.provider,
117            lifetime,
118            &signer,
119        );
120
121        PreGroupPartyState {
122            credential_with_key,
123            key_package_bundle,
124            signer,
125            core_state: self,
126        }
127    }
128}
129
/// Represents a group member's `MlsGroup` instance and pre-group state
pub struct MemberState<'a, Provider> {
    /// The party's state from before it became a member (credential, key package, signer and
    /// core state).
    pub party: PreGroupPartyState<'a, Provider>,
    /// This member's instance of the group.
    pub group: MlsGroup,
}
135
136impl<Provider: OpenMlsProvider> MemberState<'_, Provider> {
137    /// Get member's `SignatureKeyPair` if available
138    pub fn get_storage_signature_key_pair(&self) -> Option<SignatureKeyPair> {
139        let ciphersuite = self
140            .party
141            .key_package_bundle
142            .key_package()
143            .ciphersuite()
144            .into();
145
146        SignatureKeyPair::read(
147            self.party.core_state.provider.storage(),
148            self.party.signer.public(),
149            ciphersuite,
150        )
151    }
152    /// Get the `GroupStorageState` for this group
153    pub fn group_storage_state(&self) -> GroupStorageState {
154        let storage_provider = self.party.core_state.provider.storage();
155        let group_id = self.group.group_id();
156
157        GroupStorageState::from_storage(storage_provider, group_id)
158    }
159    /// Deliver_and_apply a message to this member's `MlsGroup`
160    pub fn deliver_and_apply(&mut self, message: MlsMessageIn) -> Result<(), GroupError<Provider>> {
161        let message = message.try_into_protocol_message()?;
162
163        // process message
164        let processed_message = self
165            .group
166            .process_message(&self.party.core_state.provider, message)?;
167
168        match processed_message.into_content() {
169            ProcessedMessageContent::ApplicationMessage(_) => todo!(),
170            ProcessedMessageContent::ProposalMessage(_) => todo!(),
171            ProcessedMessageContent::ExternalJoinProposalMessage(_) => todo!(),
172            ProcessedMessageContent::StagedCommitMessage(m) => self
173                .group
174                .merge_staged_commit(&self.party.core_state.provider, *m)?,
175        };
176
177        Ok(())
178    }
179}
180
181impl<'commit_builder, 'b: 'commit_builder, 'a: 'b, Provider> MemberState<'a, Provider>
182where
183    Provider: openmls_traits::OpenMlsProvider,
184{
185    /// Build and stage a commit, using the provided closure to add proposals
186    pub fn build_commit_and_stage(
187        &'b mut self,
188        f: impl FnOnce(
189            CommitBuilder<'commit_builder, Initial>,
190        ) -> CommitBuilder<'commit_builder, Initial>,
191    ) -> Result<CommitMessageBundle, GroupError<Provider>> {
192        let commit_builder = f(self.group.commit_builder());
193
194        let provider = &self.party.core_state.provider;
195
196        // TODO: most of the steps here cannot be done via the closure (yet)
197        let bundle = commit_builder
198            .load_psks(provider.storage())?
199            .build(
200                provider.rand(),
201                provider.crypto(),
202                &self.party.signer,
203                |_| true,
204            )?
205            .stage_commit(provider)?;
206
207        Ok(bundle)
208    }
209}
210
211impl<'a, Provider: OpenMlsProvider> MemberState<'a, Provider> {
212    /// Create a `MemberState` from a `PreGroupPartyState`. This creates a new `MlsGroup` with one
213    /// member
214    pub fn create_from_pre_group(
215        party: PreGroupPartyState<'a, Provider>,
216        mls_group_create_config: MlsGroupCreateConfig,
217        group_id: GroupId,
218    ) -> Result<Self, GroupError<Provider>> {
219        // initialize MlsGroup
220        let group = MlsGroup::new_with_group_id(
221            &party.core_state.provider,
222            &party.signer,
223            &mls_group_create_config,
224            group_id,
225            party.credential_with_key.clone(),
226        )?;
227
228        Ok(Self { party, group })
229    }
230    /// Create a `MemberState` from a `Welcome`, which creates a new `MlsGroup` using a `Welcome`
231    /// invitation from an existing group
232    pub fn join_from_pre_group(
233        party: PreGroupPartyState<'a, Provider>,
234        mls_group_join_config: MlsGroupJoinConfig,
235        welcome: Welcome,
236        tree: Option<RatchetTreeIn>,
237    ) -> Result<Self, GroupError<Provider>> {
238        let staged_join = StagedWelcome::new_from_welcome(
239            &party.core_state.provider,
240            &mls_group_join_config,
241            welcome,
242            tree,
243        )?;
244
245        let group = staged_join.into_group(&party.core_state.provider)?;
246
247        Ok(Self { party, group })
248    }
249}
250
/// All of the state for a group and its members
pub struct GroupState<'a, Provider> {
    /// Identifier of the group the tracked members belong to.
    group_id: GroupId,
    /// The tracked members, keyed by party name.
    members: HashMap<Name, MemberState<'a, Provider>>,
}
256
257impl<'a, Provider: OpenMlsProvider> GroupState<'a, Provider> {
258    /// Create a new `GroupState` from a single party
259    pub fn new_from_party(
260        group_id: GroupId,
261        pre_group_state: PreGroupPartyState<'a, Provider>,
262        mls_group_create_config: MlsGroupCreateConfig,
263    ) -> Result<Self, GroupError<Provider>> {
264        let mut members = HashMap::new();
265
266        let name = pre_group_state.core_state.name;
267        let member_state = MemberState::create_from_pre_group(
268            pre_group_state,
269            mls_group_create_config,
270            group_id.clone(),
271        )?;
272
273        members.insert(name, member_state);
274
275        Ok(Self { group_id, members })
276    }
277
278    /// Get mutable references to specified `MemberState`s as a fixed-size array,
279    /// in the order of the names provided in `names`.
280    /// At least one member must be requested.
281    pub fn members_mut<const N: usize>(
282        &mut self,
283        names: &[Name; N],
284    ) -> [&mut MemberState<'a, Provider>; N] {
285        assert!(N > 0, "must request at least one member");
286        assert!(
287            N <= self.members.len(),
288            "cannot request more members than available"
289        );
290
291        // map each member in `self.members` to its name's index in `names`
292        let mut members: [(_, _); N] = self
293            .members
294            .iter_mut()
295            .filter_map(|(member_name, member)| {
296                // Find the index of the member's name in `names`
297                // NOTE: the list of names provided to this method will generally be short,
298                // so not many comparisons are made here.
299                let index = names.iter().position(|name| name == member_name)?;
300
301                Some((index, member))
302            })
303            // collect into Vec, then into fixed-size array
304            .collect::<Vec<_>>()
305            .try_into()
306            .ok()
307            .expect("At least one requested member not found");
308
309        // sort by index
310        members.sort_by_key(|(pos, _member)| *pos);
311
312        members.map(|(_pos, member)| member)
313    }
314
315    /// Deliver_and_apply a message to all parties
316    pub fn deliver_and_apply(&mut self, message: MlsMessageIn) -> Result<(), GroupError<Provider>> {
317        self.deliver_and_apply_if(message, |_| true)
318    }
319    /// Deliver_and_apply a message to all parties if a provided condition is met
320    pub fn deliver_and_apply_if(
321        &mut self,
322        message: MlsMessageIn,
323        condition: impl Fn(&MemberState<'a, Provider>) -> bool,
324    ) -> Result<(), GroupError<Provider>> {
325        self.members
326            .values_mut()
327            .filter(|member| condition(member))
328            .try_for_each(|member| member.deliver_and_apply(message.clone()))?;
329
330        Ok(())
331    }
332
333    /// Deliver_and_apply a welcome to a single party, and initialize a group for that party
334    pub fn deliver_and_apply_welcome(
335        &mut self,
336        recipient: PreGroupPartyState<'a, Provider>,
337        mls_group_join_config: MlsGroupJoinConfig,
338        welcome: Welcome,
339        tree: Option<RatchetTreeIn>,
340    ) -> Result<(), GroupError<Provider>> {
341        // create new group
342        let name = recipient.core_state.name;
343
344        let member_state =
345            MemberState::join_from_pre_group(recipient, mls_group_join_config, welcome, tree)?;
346
347        // insert after success
348        self.members.insert(name, member_state);
349
350        Ok(())
351    }
352
353    /// Drop a member from the internal hashmap. This does not delete the member from any
354    /// `MlsGroup`
355    pub fn untrack_member(&mut self, name: Name) {
356        let _ = self.members.remove(&name);
357    }
358
359    pub fn add_member(
360        &mut self,
361        add_config: AddMemberConfig<'a, Provider>,
362    ) -> Result<(), GroupError<Provider>> {
363        let adder = self
364            .members
365            .get_mut(add_config.adder)
366            .ok_or(TestError::NoSuchMember)?;
367
368        let key_packages: Vec<_> = add_config
369            .addees
370            .iter()
371            .map(|addee| addee.key_package_bundle.key_package.clone())
372            .collect();
373
374        let (commit, welcome, _) = adder.group.add_members(
375            &adder.party.core_state.provider,
376            &adder.party.signer,
377            &key_packages,
378        )?;
379
380        // Deliver_and_apply to all members but adder
381        self.deliver_and_apply_if(commit.into(), |member| {
382            member.party.core_state.name != add_config.adder
383        })?;
384
385        // Deliver_and_apply welcome to addee
386        let welcome = match welcome.body() {
387            MlsMessageBodyOut::Welcome(welcome) => welcome.clone(),
388            _ => panic!("No welcome returned"),
389        };
390
391        for addee in add_config.addees.into_iter() {
392            self.deliver_and_apply_welcome(
393                addee,
394                add_config.join_config.clone(),
395                welcome.clone(),
396                None,
397            )?;
398        }
399
400        let adder = self
401            .members
402            .get_mut(add_config.adder)
403            .ok_or(TestError::NoSuchMember)?;
404
405        let staged_commit = adder.group.pending_commit().unwrap().clone();
406
407        adder
408            .group
409            .merge_staged_commit(&adder.party.core_state.provider, staged_commit)?;
410
411        Ok(())
412    }
413
414    /// Returns a copy of the GroupId
415    pub fn group_id(&self) -> GroupId {
416        self.group_id.clone()
417    }
418}
419
420impl MlsGroupCreateConfig {
421    /// Default config for test framework
422    pub fn test_default_from_ciphersuite(ciphersuite: Ciphersuite) -> Self {
423        MlsGroupCreateConfig::builder()
424            .ciphersuite(ciphersuite)
425            .use_ratchet_tree_extension(true)
426            .wire_format_policy(PURE_PLAINTEXT_WIRE_FORMAT_POLICY) // Important because the secret tree might diverge otherwise
427            .build()
428    }
429}
430
/// Configuration for adding parties to a group with `GroupState::add_member`.
pub struct AddMemberConfig<'a, Provider> {
    /// Name of the tracked member that creates the add commit.
    pub adder: Name,
    /// Pre-group states of the parties to add; their key packages go into the commit.
    pub addees: Vec<PreGroupPartyState<'a, Provider>>,
    /// Join config each addee uses when joining from the `Welcome`.
    pub join_config: MlsGroupJoinConfig,
    /// Optional ratchet tree for the addees to use when joining from the `Welcome`.
    /// NOTE(review): confirm `GroupState::add_member` forwards this value to the addees.
    pub tree: Option<RatchetTreeIn>,
}
437
#[cfg(test)]
mod test {
    // Tests for the single-group test framework.

    use super::*;
    use openmls_test::openmls_test;

    // `members_mut` must return references in the order the names were requested, both for
    // permutations of all members and for subsets.
    #[openmls_test]
    fn test_members_mut() {
        let alice_party = CorePartyState::<Provider>::new("alice");
        let bob_party = CorePartyState::<Provider>::new("bob");
        let charlie_party = CorePartyState::<Provider>::new("charlie");
        let dave_party = CorePartyState::<Provider>::new("dave");

        let alice_pre_group = alice_party.generate_pre_group(ciphersuite);
        let bob_pre_group = bob_party.generate_pre_group(ciphersuite);
        let charlie_pre_group = charlie_party.generate_pre_group(ciphersuite);
        let dave_pre_group = dave_party.generate_pre_group(ciphersuite);

        // Create config
        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .use_ratchet_tree_extension(true)
            .build();

        // Join config
        let mls_group_join_config = mls_group_create_config.join_config().clone();

        // Initialize the group state
        let group_id = GroupId::from_slice(b"test");
        let mut group_state =
            GroupState::new_from_party(group_id, alice_pre_group, mls_group_create_config).unwrap();

        // Alice adds Bob, Charlie and Dave with a single commit.
        group_state
            .add_member(AddMemberConfig {
                adder: "alice",
                addees: vec![bob_pre_group, charlie_pre_group, dave_pre_group],
                join_config: mls_group_join_config.clone(),
                tree: None,
            })
            .expect("Could not add member");

        // test different orderings
        let [alice, bob, charlie, dave] =
            group_state.members_mut(&["alice", "bob", "charlie", "dave"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(charlie.party.core_state.name, "charlie");
        assert_eq!(dave.party.core_state.name, "dave");

        let [dave, charlie, bob, alice] =
            group_state.members_mut(&["dave", "charlie", "bob", "alice"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(charlie.party.core_state.name, "charlie");
        assert_eq!(dave.party.core_state.name, "dave");

        let [dave, bob, charlie, alice] =
            group_state.members_mut(&["dave", "bob", "charlie", "alice"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(charlie.party.core_state.name, "charlie");
        assert_eq!(dave.party.core_state.name, "dave");

        // test subsets of the members
        let [dave, bob] = group_state.members_mut(&["dave", "bob"]);
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(dave.party.core_state.name, "dave");

        let [alice, charlie] = group_state.members_mut(&["alice", "charlie"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(charlie.party.core_state.name, "charlie");
    }
    // Adds members through `GroupState::add_member` from two different adders, calling
    // `assert_membership` after each step.
    #[openmls_test]
    pub fn simpler_example() {
        let alice_party = CorePartyState::<Provider>::new("alice");
        let bob_party = CorePartyState::<Provider>::new("bob");
        let charlie_party = CorePartyState::<Provider>::new("charlie");
        let dave_party = CorePartyState::<Provider>::new("dave");

        let alice_pre_group = alice_party.generate_pre_group(ciphersuite);
        let bob_pre_group = bob_party.generate_pre_group(ciphersuite);
        let charlie_pre_group = charlie_party.generate_pre_group(ciphersuite);
        let dave_pre_group = dave_party.generate_pre_group(ciphersuite);

        // Create config
        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .use_ratchet_tree_extension(true)
            .build();

        // Join config
        let mls_group_join_config = mls_group_create_config.join_config().clone();

        // Initialize the group state
        let group_id = GroupId::from_slice(b"test");
        let mut group_state =
            GroupState::new_from_party(group_id, alice_pre_group, mls_group_create_config).unwrap();

        // Alice (the creator) adds Bob and Charlie.
        group_state
            .add_member(AddMemberConfig {
                adder: "alice",
                addees: vec![bob_pre_group, charlie_pre_group],
                join_config: mls_group_join_config.clone(),
                tree: None,
            })
            .expect("Could not add member");

        group_state.assert_membership();

        // Bob, who joined via a Welcome, adds Dave.
        group_state
            .add_member(AddMemberConfig {
                adder: "bob",
                addees: vec![dave_pre_group],
                join_config: mls_group_join_config,
                tree: None,
            })
            .expect("Could not add member");

        group_state.assert_membership();
    }

    // Drives an add by hand. Alice builds and stages a commit containing an add proposal for
    // Bob. Bob then joins from the Welcome, and finally Alice merges her pending commit.
    #[openmls_test]
    pub fn simple_example() {
        let alice_party = CorePartyState::<Provider>::new("alice");
        let bob_party = CorePartyState::<Provider>::new("bob");

        let alice_pre_group = alice_party.generate_pre_group(ciphersuite);
        let bob_pre_group = bob_party.generate_pre_group(ciphersuite);

        // Get the key package for Bob
        // TODO: should key package be regenerated each time?
        let bob_key_package = bob_pre_group.key_package_bundle.key_package.clone();

        // Create config
        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .use_ratchet_tree_extension(true)
            .build();

        // Join config
        let mls_group_join_config = mls_group_create_config.join_config().clone();

        // Initialize the group state
        let group_id = GroupId::from_slice(b"test");
        let mut group_state =
            GroupState::new_from_party(group_id, alice_pre_group, mls_group_create_config).unwrap();

        // Get a mutable reference to Alice's group representation
        let [alice] = group_state.members_mut(&["alice"]);

        // Build a commit with a single add proposal
        let bundle = alice
            .build_commit_and_stage(move |builder| {
                let add_proposal = Proposal::add(AddProposal {
                    key_package: bob_key_package,
                });

                // ...add more proposals here...

                builder
                    .consume_proposal_store(false)
                    .add_proposal(add_proposal)
            })
            .expect("Could not stage commit");

        // Deliver and apply welcome to Bob
        let welcome = bundle.welcome().unwrap().clone();
        group_state
            .deliver_and_apply_welcome(bob_pre_group, mls_group_join_config, welcome, None)
            .expect("Error delivering and applying welcome");

        // Borrow Alice again: the previous borrow ended when `group_state` was used above.
        let [alice] = group_state.members_mut(&["alice"]);

        // The commit was only staged, so Alice still has to merge it herself.
        let staged_commit = alice.group.pending_commit().unwrap().clone();

        alice
            .group
            .merge_staged_commit(&alice.party.core_state.provider, staged_commit)
            .expect("Error merging staged commit");

        group_state.assert_membership();
    }
}