Skip to main content

openmls/test_utils/single_group_test_framework/
mod.rs

1use openmls_basic_credential::SignatureKeyPair;
2use openmls_traits::{signatures::Signer, types::SignatureScheme};
3pub use openmls_traits::{
4    storage::StorageProvider as StorageProviderTrait,
5    types::{Ciphersuite, HpkeKeyPair},
6    OpenMlsProvider,
7};
8
9pub use crate::utils::*;
10use crate::{
11    credentials::CredentialWithKey,
12    key_packages::KeyPackageBuilder,
13    prelude::{commit_builder::*, *},
14};
15
16use crate::test_utils::storage_state::GroupStorageState;
17
18mod assertions;
19
20mod errors;
21pub use errors::GroupError;
22use errors::*;
23
24use std::collections::HashMap;
25
/// Human-readable identifier of a party in the test framework (e.g. `"alice"`).
/// Also used as the key under which `GroupState` tracks its members.
type Name = &'static str;
28
29// TODO: only define this once
30/// Helper function for generating a credential.
31pub fn generate_credential(
32    identity: Vec<u8>,
33    signature_algorithm: SignatureScheme,
34    provider: &impl crate::storage::OpenMlsProvider,
35) -> (CredentialWithKey, SignatureKeyPair) {
36    let credential = BasicCredential::new(identity);
37    let signature_keys = SignatureKeyPair::new(signature_algorithm).unwrap();
38    signature_keys.store(provider.storage()).unwrap();
39
40    (
41        CredentialWithKey {
42            credential: credential.into(),
43            signature_key: signature_keys.to_public_vec().into(),
44        },
45        signature_keys,
46    )
47}
48
49// TODO: only define this once
50pub(crate) fn generate_key_package(
51    ciphersuite: Ciphersuite,
52    credential_with_key: CredentialWithKey,
53    extensions: Extensions<KeyPackage>,
54    provider: &impl crate::storage::OpenMlsProvider,
55    lifetime: impl Into<Option<Lifetime>>,
56    signer: &impl Signer,
57) -> KeyPackageBundle {
58    let mut builder = KeyPackage::builder().key_package_extensions(extensions);
59
60    if let Some(lifetime) = lifetime.into() {
61        builder = builder.key_package_lifetime(lifetime);
62    }
63
64    builder
65        .build(ciphersuite, provider, signer, credential_with_key)
66        .unwrap()
67}
68
/// Struct representing a party's global state: its name and the provider it uses.
///
/// Per-group state types (e.g. `PreGroupPartyState`) hold a shared reference to this.
pub struct CorePartyState<Provider> {
    /// Human-readable name of the party, e.g. `"alice"`.
    pub name: Name,
    /// Provider used for this party's storage, randomness and crypto operations.
    pub provider: Provider,
}
74
75impl<Provider: Default> CorePartyState<Provider> {
76    pub fn new(name: Name) -> Self {
77        Self {
78            name,
79            provider: Provider::default(),
80        }
81    }
82}
83
/// Struct representing a party's state before joining a group: a credential, the signing
/// key pair behind it, and a key package with which the party can be added to a group.
pub struct PreGroupPartyState<'a, Provider> {
    /// The party's credential together with its public signature key.
    pub credential_with_key: CredentialWithKey,
    // TODO: regenerate?
    /// Key package bundle generated when this state was built.
    pub key_package_bundle: KeyPackageBundle,
    /// Signature key pair used to sign on behalf of this party.
    pub signer: SignatureKeyPair,
    /// The party's global state (name and provider), borrowed for `'a`.
    pub core_state: &'a CorePartyState<Provider>,
}
92
/// Builder for a `PreGroupPartyState`, obtained from `CorePartyState::pre_group_builder`.
///
/// Optional settings left as `None` fall back to defaults when the state is built.
pub struct PreGroupPartyStateBuilder<'a, Provider: OpenMlsProvider> {
    /// Ciphersuite used for the generated credential and key package.
    ciphersuite: Ciphersuite,
    /// Optional key package lifetime.
    lifetime: Option<Lifetime>,
    /// Optional key package extensions.
    key_package_extensions: Option<Extensions<KeyPackage>>,
    /// Optional leaf node extensions.
    leaf_node_extensions: Option<Extensions<LeafNode>>,
    /// Optional leaf node capabilities.
    leaf_node_capabilities: Option<Capabilities>,
    /// The party the pre-group state is built for.
    core_state: &'a CorePartyState<Provider>,
}
101
102impl<'a, Provider: OpenMlsProvider> PreGroupPartyStateBuilder<'a, Provider> {
103    pub fn with_lifetime(mut self, lifetime: impl Into<Option<Lifetime>>) -> Self {
104        self.lifetime = lifetime.into();
105
106        self
107    }
108    pub fn with_key_package_extensions(
109        mut self,
110        extensions: impl Into<Option<Extensions<KeyPackage>>>,
111    ) -> Self {
112        self.key_package_extensions = extensions.into();
113
114        self
115    }
116    pub fn with_leaf_node_extensions(
117        mut self,
118        extensions: impl Into<Option<Extensions<LeafNode>>>,
119    ) -> Self {
120        self.leaf_node_extensions = extensions.into();
121
122        self
123    }
124    pub fn with_leaf_node_capabilities(
125        mut self,
126        capabilities: impl Into<Option<Capabilities>>,
127    ) -> Self {
128        self.leaf_node_capabilities = capabilities.into();
129
130        self
131    }
132
133    pub fn build(self) -> PreGroupPartyState<'a, Provider> {
134        let (credential_with_key, signer) = generate_credential(
135            self.core_state.name.into(),
136            self.ciphersuite.signature_algorithm(),
137            &self.core_state.provider,
138        );
139        let mut builder = KeyPackage::builder()
140            .leaf_node_extensions(self.leaf_node_extensions.unwrap_or_default())
141            .key_package_extensions(self.key_package_extensions.unwrap_or_default())
142            .leaf_node_capabilities(self.leaf_node_capabilities.unwrap_or_default());
143
144        if let Some(lifetime) = self.lifetime {
145            builder = builder.key_package_lifetime(lifetime);
146        }
147
148        let key_package_bundle = builder
149            .build(
150                self.ciphersuite,
151                &self.core_state.provider,
152                &signer,
153                credential_with_key.clone(),
154            )
155            .unwrap();
156
157        PreGroupPartyState {
158            credential_with_key,
159            key_package_bundle,
160            signer,
161            core_state: self.core_state,
162        }
163    }
164}
165
166impl<Provider: OpenMlsProvider> CorePartyState<Provider> {
167    /// Returns a builder for the [`CorePartyState`].
168    pub fn pre_group_builder<'a>(
169        &'a self,
170        ciphersuite: Ciphersuite,
171    ) -> PreGroupPartyStateBuilder<'a, Provider> {
172        PreGroupPartyStateBuilder {
173            ciphersuite,
174            lifetime: None,
175            key_package_extensions: None,
176            leaf_node_extensions: None,
177            leaf_node_capabilities: None,
178            core_state: self,
179        }
180    }
181
182    /// Generates a simple pre-group state for a `CorePartyState`
183    pub fn generate_pre_group(&self, ciphersuite: Ciphersuite) -> PreGroupPartyState<'_, Provider> {
184        self.pre_group_builder(ciphersuite).build()
185    }
186}
187
/// Represents a group member's `MlsGroup` instance and pre-group state
pub struct MemberState<'a, Provider> {
    /// The party's credential, signer, key package and core state.
    pub party: PreGroupPartyState<'a, Provider>,
    /// This member's local `MlsGroup` instance.
    pub group: MlsGroup,
}
193
194impl<Provider: OpenMlsProvider> MemberState<'_, Provider> {
195    /// Get member's `SignatureKeyPair` if available
196    pub fn get_storage_signature_key_pair(&self) -> Option<SignatureKeyPair> {
197        let ciphersuite = self
198            .party
199            .key_package_bundle
200            .key_package()
201            .ciphersuite()
202            .into();
203
204        SignatureKeyPair::read(
205            self.party.core_state.provider.storage(),
206            self.party.signer.public(),
207            ciphersuite,
208        )
209    }
210    /// Get the `GroupStorageState` for this group
211    pub fn group_storage_state(&self) -> GroupStorageState {
212        let storage_provider = self.party.core_state.provider.storage();
213        let group_id = self.group.group_id();
214
215        GroupStorageState::from_storage(storage_provider, group_id)
216    }
217    /// Deliver_and_apply a message to this member's `MlsGroup`
218    pub fn deliver_and_apply(&mut self, message: MlsMessageIn) -> Result<(), GroupError<Provider>> {
219        let message = message.try_into_protocol_message()?;
220
221        // process message
222        let processed_message = self
223            .group
224            .process_message(&self.party.core_state.provider, message)?;
225
226        match processed_message.into_content() {
227            ProcessedMessageContent::ApplicationMessage(_) => todo!(),
228            ProcessedMessageContent::ProposalMessage(_) => todo!(),
229            ProcessedMessageContent::ExternalJoinProposalMessage(_) => todo!(),
230            ProcessedMessageContent::StagedCommitMessage(m) => self
231                .group
232                .merge_staged_commit(&self.party.core_state.provider, *m)?,
233        };
234
235        Ok(())
236    }
237}
238
239impl<'commit_builder, 'b: 'commit_builder, 'a: 'b, Provider> MemberState<'a, Provider>
240where
241    Provider: openmls_traits::OpenMlsProvider,
242{
243    /// Build and stage a commit, using the provided closure to add proposals
244    pub fn build_commit_and_stage(
245        &'b mut self,
246        f: impl FnOnce(
247            CommitBuilder<'commit_builder, Initial>,
248        ) -> CommitBuilder<'commit_builder, Initial>,
249    ) -> Result<CommitMessageBundle, GroupError<Provider>> {
250        let commit_builder = f(self.group.commit_builder());
251
252        let provider = &self.party.core_state.provider;
253
254        // TODO: most of the steps here cannot be done via the closure (yet)
255        let bundle = commit_builder
256            .load_psks(provider.storage())?
257            .build(
258                provider.rand(),
259                provider.crypto(),
260                &self.party.signer,
261                |_| true,
262            )?
263            .stage_commit(provider)?;
264
265        Ok(bundle)
266    }
267}
268
269impl<'a, Provider: OpenMlsProvider> MemberState<'a, Provider> {
270    /// Create a `MemberState` from a `PreGroupPartyState`. This creates a new `MlsGroup` with one
271    /// member
272    pub fn create_from_pre_group(
273        party: PreGroupPartyState<'a, Provider>,
274        mls_group_create_config: MlsGroupCreateConfig,
275        group_id: GroupId,
276    ) -> Result<Self, GroupError<Provider>> {
277        // initialize MlsGroup
278        let group = MlsGroup::new_with_group_id(
279            &party.core_state.provider,
280            &party.signer,
281            &mls_group_create_config,
282            group_id,
283            party.credential_with_key.clone(),
284        )?;
285
286        Ok(Self { party, group })
287    }
288    /// Create a `MemberState` from a `Welcome`, which creates a new `MlsGroup` using a `Welcome`
289    /// invitation from an existing group
290    pub fn join_from_pre_group(
291        party: PreGroupPartyState<'a, Provider>,
292        mls_group_join_config: MlsGroupJoinConfig,
293        welcome: Welcome,
294        tree: Option<RatchetTreeIn>,
295    ) -> Result<Self, GroupError<Provider>> {
296        let staged_join = StagedWelcome::new_from_welcome(
297            &party.core_state.provider,
298            &mls_group_join_config,
299            welcome,
300            tree,
301        )?;
302
303        let group = staged_join.into_group(&party.core_state.provider)?;
304
305        Ok(Self { party, group })
306    }
307}
308
/// All of the state for a group and its members
///
/// Members are keyed by party name; messages delivered through this type reach only the
/// members currently in `members`.
pub struct GroupState<'a, Provider> {
    /// ID of the group (copies are handed out by `group_id()`).
    group_id: GroupId,
    /// Tracked member states, keyed by party name.
    members: HashMap<Name, MemberState<'a, Provider>>,
}
314
315impl<'a, Provider: OpenMlsProvider> GroupState<'a, Provider> {
316    /// Create a new `GroupState` from a single party
317    pub fn new_from_party(
318        group_id: GroupId,
319        pre_group_state: PreGroupPartyState<'a, Provider>,
320        mls_group_create_config: MlsGroupCreateConfig,
321    ) -> Result<Self, GroupError<Provider>> {
322        let mut members = HashMap::new();
323
324        let name = pre_group_state.core_state.name;
325        let member_state = MemberState::create_from_pre_group(
326            pre_group_state,
327            mls_group_create_config,
328            group_id.clone(),
329        )?;
330
331        members.insert(name, member_state);
332
333        Ok(Self { group_id, members })
334    }
335
336    /// Get mutable references to specified `MemberState`s as a fixed-size array,
337    /// in the order of the names provided in `names`.
338    /// At least one member must be requested.
339    pub fn members_mut<const N: usize>(
340        &mut self,
341        names: &[Name; N],
342    ) -> [&mut MemberState<'a, Provider>; N] {
343        assert!(N > 0, "must request at least one member");
344        assert!(
345            N <= self.members.len(),
346            "cannot request more members than available"
347        );
348
349        // map each member in `self.members` to its name's index in `names`
350        let mut members: [(_, _); N] = self
351            .members
352            .iter_mut()
353            .filter_map(|(member_name, member)| {
354                // Find the index of the member's name in `names`
355                // NOTE: the list of names provided to this method will generally be short,
356                // so not many comparisons are made here.
357                let index = names.iter().position(|name| name == member_name)?;
358
359                Some((index, member))
360            })
361            // collect into Vec, then into fixed-size array
362            .collect::<Vec<_>>()
363            .try_into()
364            .ok()
365            .expect("At least one requested member not found");
366
367        // sort by index
368        members.sort_by_key(|(pos, _member)| *pos);
369
370        members.map(|(_pos, member)| member)
371    }
372
373    /// Deliver_and_apply a message to all parties
374    pub fn deliver_and_apply(&mut self, message: MlsMessageIn) -> Result<(), GroupError<Provider>> {
375        self.deliver_and_apply_if(message, |_| true)
376    }
377    /// Deliver_and_apply a message to all parties if a provided condition is met
378    pub fn deliver_and_apply_if(
379        &mut self,
380        message: MlsMessageIn,
381        condition: impl Fn(&MemberState<'a, Provider>) -> bool,
382    ) -> Result<(), GroupError<Provider>> {
383        self.members
384            .values_mut()
385            .filter(|member| condition(member))
386            .try_for_each(|member| member.deliver_and_apply(message.clone()))?;
387
388        Ok(())
389    }
390
391    /// Deliver_and_apply a welcome to a single party, and initialize a group for that party
392    pub fn deliver_and_apply_welcome(
393        &mut self,
394        recipient: PreGroupPartyState<'a, Provider>,
395        mls_group_join_config: MlsGroupJoinConfig,
396        welcome: Welcome,
397        tree: Option<RatchetTreeIn>,
398    ) -> Result<(), GroupError<Provider>> {
399        // create new group
400        let name = recipient.core_state.name;
401
402        let member_state =
403            MemberState::join_from_pre_group(recipient, mls_group_join_config, welcome, tree)?;
404
405        // insert after success
406        self.members.insert(name, member_state);
407
408        Ok(())
409    }
410
411    /// Drop a member from the internal hashmap. This does not delete the member from any
412    /// `MlsGroup`
413    pub fn untrack_member(&mut self, name: Name) {
414        let _ = self.members.remove(&name);
415    }
416
417    pub fn add_member(
418        &mut self,
419        add_config: AddMemberConfig<'a, Provider>,
420    ) -> Result<(), GroupError<Provider>> {
421        let adder = self
422            .members
423            .get_mut(add_config.adder)
424            .ok_or(TestError::NoSuchMember)?;
425
426        let key_packages: Vec<_> = add_config
427            .addees
428            .iter()
429            .map(|addee| addee.key_package_bundle.key_package.clone())
430            .collect();
431
432        let (commit, welcome, _) = adder.group.add_members(
433            &adder.party.core_state.provider,
434            &adder.party.signer,
435            &key_packages,
436        )?;
437
438        // Deliver_and_apply to all members but adder
439        self.deliver_and_apply_if(commit.into(), |member| {
440            member.party.core_state.name != add_config.adder
441        })?;
442
443        // Deliver_and_apply welcome to addee
444        let welcome = match welcome.body() {
445            MlsMessageBodyOut::Welcome(welcome) => welcome.clone(),
446            _ => panic!("No welcome returned"),
447        };
448
449        for addee in add_config.addees.into_iter() {
450            self.deliver_and_apply_welcome(
451                addee,
452                add_config.join_config.clone(),
453                welcome.clone(),
454                None,
455            )?;
456        }
457
458        let adder = self
459            .members
460            .get_mut(add_config.adder)
461            .ok_or(TestError::NoSuchMember)?;
462
463        let staged_commit = adder.group.pending_commit().unwrap().clone();
464
465        adder
466            .group
467            .merge_staged_commit(&adder.party.core_state.provider, staged_commit)?;
468
469        Ok(())
470    }
471
472    /// Returns a copy of the GroupId
473    pub fn group_id(&self) -> GroupId {
474        self.group_id.clone()
475    }
476}
477
478impl MlsGroupCreateConfig {
479    /// Default config for test framework
480    pub fn test_default_from_ciphersuite(ciphersuite: Ciphersuite) -> Self {
481        MlsGroupCreateConfig::builder()
482            .ciphersuite(ciphersuite)
483            .use_ratchet_tree_extension(true)
484            .wire_format_policy(PURE_PLAINTEXT_WIRE_FORMAT_POLICY) // Important because the secret tree might diverge otherwise
485            .build()
486    }
487}
488
/// Configuration for adding new members to a `GroupState` via `GroupState::add_member`.
pub struct AddMemberConfig<'a, Provider> {
    /// Name of the existing, tracked member that creates the commit adding the new members.
    pub adder: Name,
    /// Pre-group states of the parties to add; their key packages are used for the adds.
    pub addees: Vec<PreGroupPartyState<'a, Provider>>,
    /// Join config each addee uses when processing the `Welcome`.
    pub join_config: MlsGroupJoinConfig,
    /// Optional ratchet tree intended to accompany the `Welcome` for the addees.
    /// NOTE(review): confirm `GroupState::add_member` forwards this to the joiners.
    pub tree: Option<RatchetTreeIn>,
}
495
// Tests for the single-group test framework. Each `#[openmls_test]` body uses a `Provider`
// type and a `ciphersuite` value it does not define itself; presumably both are injected by
// the `openmls_test` macro (its expansion is not visible here).
#[cfg(test)]
mod test {

    use super::*;
    use openmls_test::openmls_test;

    // Adds bob, charlie and dave to a group created by alice. Then checks that
    // `GroupState::members_mut` returns members in exactly the order their names were
    // requested, for full and partial requests.
    #[openmls_test]
    fn test_members_mut() {
        let alice_party = CorePartyState::<Provider>::new("alice");
        let bob_party = CorePartyState::<Provider>::new("bob");
        let charlie_party = CorePartyState::<Provider>::new("charlie");
        let dave_party = CorePartyState::<Provider>::new("dave");

        let alice_pre_group = alice_party.generate_pre_group(ciphersuite);
        let bob_pre_group = bob_party.generate_pre_group(ciphersuite);
        let charlie_pre_group = charlie_party.generate_pre_group(ciphersuite);
        let dave_pre_group = dave_party.generate_pre_group(ciphersuite);

        // Create config
        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .use_ratchet_tree_extension(true)
            .build();

        // Join config
        let mls_group_join_config = mls_group_create_config.join_config().clone();

        // Initialize the group state
        let group_id = GroupId::from_slice(b"test");
        let mut group_state =
            GroupState::new_from_party(group_id, alice_pre_group, mls_group_create_config).unwrap();

        group_state
            .add_member(AddMemberConfig {
                adder: "alice",
                addees: vec![bob_pre_group, charlie_pre_group, dave_pre_group],
                join_config: mls_group_join_config.clone(),
                tree: None,
            })
            .expect("Could not add member");

        // test different orderings
        let [alice, bob, charlie, dave] =
            group_state.members_mut(&["alice", "bob", "charlie", "dave"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(charlie.party.core_state.name, "charlie");
        assert_eq!(dave.party.core_state.name, "dave");

        let [dave, charlie, bob, alice] =
            group_state.members_mut(&["dave", "charlie", "bob", "alice"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(charlie.party.core_state.name, "charlie");
        assert_eq!(dave.party.core_state.name, "dave");

        let [dave, bob, charlie, alice] =
            group_state.members_mut(&["dave", "bob", "charlie", "alice"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(charlie.party.core_state.name, "charlie");
        assert_eq!(dave.party.core_state.name, "dave");

        // partial requests (fewer names than tracked members)
        let [dave, bob] = group_state.members_mut(&["dave", "bob"]);
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(dave.party.core_state.name, "dave");

        let [alice, charlie] = group_state.members_mut(&["alice", "charlie"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(charlie.party.core_state.name, "charlie");
    }
    // Uses the high-level `add_member` helper twice: alice adds bob and charlie, then bob
    // (a member who joined via Welcome) adds dave. `assert_membership` (from the
    // `assertions` module, not shown here) is checked after each step.
    #[openmls_test]
    pub fn simpler_example() {
        let alice_party = CorePartyState::<Provider>::new("alice");
        let bob_party = CorePartyState::<Provider>::new("bob");
        let charlie_party = CorePartyState::<Provider>::new("charlie");
        let dave_party = CorePartyState::<Provider>::new("dave");

        let alice_pre_group = alice_party.generate_pre_group(ciphersuite);
        let bob_pre_group = bob_party.generate_pre_group(ciphersuite);
        let charlie_pre_group = charlie_party.generate_pre_group(ciphersuite);
        let dave_pre_group = dave_party.generate_pre_group(ciphersuite);

        // Create config
        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .use_ratchet_tree_extension(true)
            .build();

        // Join config
        let mls_group_join_config = mls_group_create_config.join_config().clone();

        // Initialize the group state
        let group_id = GroupId::from_slice(b"test");
        let mut group_state =
            GroupState::new_from_party(group_id, alice_pre_group, mls_group_create_config).unwrap();

        group_state
            .add_member(AddMemberConfig {
                adder: "alice",
                addees: vec![bob_pre_group, charlie_pre_group],
                join_config: mls_group_join_config.clone(),
                tree: None,
            })
            .expect("Could not add member");

        group_state.assert_membership();

        group_state
            .add_member(AddMemberConfig {
                adder: "bob",
                addees: vec![dave_pre_group],
                join_config: mls_group_join_config,
                tree: None,
            })
            .expect("Could not add member");

        group_state.assert_membership();
    }

    // Performs the add step by step instead of via `add_member`. Alice builds and stages a
    // commit with an add proposal for bob. Bob joins from the Welcome. Then alice merges
    // her pending commit manually, because `build_commit_and_stage` only stages it.
    #[openmls_test]
    pub fn simple_example() {
        let alice_party = CorePartyState::<Provider>::new("alice");
        let bob_party = CorePartyState::<Provider>::new("bob");

        let alice_pre_group = alice_party.generate_pre_group(ciphersuite);
        let bob_pre_group = bob_party.generate_pre_group(ciphersuite);

        // Get the key package for Bob
        // TODO: should key package be regenerated each time?
        let bob_key_package = bob_pre_group.key_package_bundle.key_package.clone();

        // Create config
        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .use_ratchet_tree_extension(true)
            .build();

        // Join config
        let mls_group_join_config = mls_group_create_config.join_config().clone();

        // Initialize the group state
        let group_id = GroupId::from_slice(b"test");
        let mut group_state =
            GroupState::new_from_party(group_id, alice_pre_group, mls_group_create_config).unwrap();

        // Get a mutable reference to Alice's group representation
        let [alice] = group_state.members_mut(&["alice"]);

        // Build a commit with a single add proposal
        let bundle = alice
            .build_commit_and_stage(move |builder| {
                let add_proposal = Proposal::add(AddProposal {
                    key_package: bob_key_package,
                });

                // ...add more proposals here...

                builder
                    .consume_proposal_store(false)
                    .add_proposal(add_proposal)
            })
            .expect("Could not stage commit");

        // Deliver and apply welcome to Bob
        let welcome = bundle.welcome().unwrap().clone();
        group_state
            .deliver_and_apply_welcome(bob_pre_group, mls_group_join_config, welcome, None)
            .expect("Error delivering and applying welcome");

        // Re-borrow alice: the earlier `members_mut` borrow ended before the welcome delivery
        let [alice] = group_state.members_mut(&["alice"]);

        let staged_commit = alice.group.pending_commit().unwrap().clone();

        alice
            .group
            .merge_staged_commit(&alice.party.core_state.provider, staged_commit)
            .expect("Error merging staged commit");

        group_state.assert_membership();
    }
}