1use openmls_basic_credential::SignatureKeyPair;
2use openmls_traits::{signatures::Signer, types::SignatureScheme};
3pub use openmls_traits::{
4 storage::StorageProvider as StorageProviderTrait,
5 types::{Ciphersuite, HpkeKeyPair},
6 OpenMlsProvider,
7};
8
9pub use crate::utils::*;
10use crate::{
11 credentials::CredentialWithKey,
12 key_packages::KeyPackageBuilder,
13 prelude::{commit_builder::*, *},
14};
15
16use crate::test_utils::storage_state::GroupStorageState;
17
18mod assertions;
19
20mod errors;
21pub use errors::GroupError;
22use errors::*;
23
24use std::collections::HashMap;
25
/// Human-readable name identifying a test party.
///
/// It is the key of a `GroupState`'s member map and is converted into the
/// identity bytes of the party's basic credential when pre-group state is built.
type Name = &'static str;
28
29pub fn generate_credential(
32 identity: Vec<u8>,
33 signature_algorithm: SignatureScheme,
34 provider: &impl crate::storage::OpenMlsProvider,
35) -> (CredentialWithKey, SignatureKeyPair) {
36 let credential = BasicCredential::new(identity);
37 let signature_keys = SignatureKeyPair::new(signature_algorithm).unwrap();
38 signature_keys.store(provider.storage()).unwrap();
39
40 (
41 CredentialWithKey {
42 credential: credential.into(),
43 signature_key: signature_keys.to_public_vec().into(),
44 },
45 signature_keys,
46 )
47}
48
49pub(crate) fn generate_key_package(
51 ciphersuite: Ciphersuite,
52 credential_with_key: CredentialWithKey,
53 extensions: Extensions<KeyPackage>,
54 provider: &impl crate::storage::OpenMlsProvider,
55 lifetime: impl Into<Option<Lifetime>>,
56 signer: &impl Signer,
57) -> KeyPackageBundle {
58 let mut builder = KeyPackage::builder().key_package_extensions(extensions);
59
60 if let Some(lifetime) = lifetime.into() {
61 builder = builder.key_package_lifetime(lifetime);
62 }
63
64 builder
65 .build(ciphersuite, provider, signer, credential_with_key)
66 .unwrap()
67}
68
69pub struct CorePartyState<Provider> {
71 pub name: Name,
72 pub provider: Provider,
73}
74
75impl<Provider: Default> CorePartyState<Provider> {
76 pub fn new(name: Name) -> Self {
77 Self {
78 name,
79 provider: Provider::default(),
80 }
81 }
82}
83
/// A party that has generated a credential and a key package but is not (yet)
/// tracked as a member of a group.
pub struct PreGroupPartyState<'a, Provider> {
    /// The party's credential together with its public signature key.
    pub credential_with_key: CredentialWithKey,
    /// Key package bundle generated for this party; its key package is what an
    /// adder hands to `add_members` to add the party.
    pub key_package_bundle: KeyPackageBundle,
    /// Signature key pair; when built via `PreGroupPartyStateBuilder::build` its
    /// public key is the one stored in `credential_with_key`.
    pub signer: SignatureKeyPair,
    /// The long-lived party state this pre-group state was derived from.
    pub core_state: &'a CorePartyState<Provider>,
}
92
93pub struct PreGroupPartyStateBuilder<'a, Provider: OpenMlsProvider> {
94 ciphersuite: Ciphersuite,
95 lifetime: Option<Lifetime>,
96 key_package_extensions: Option<Extensions<KeyPackage>>,
97 leaf_node_extensions: Option<Extensions<LeafNode>>,
98 leaf_node_capabilities: Option<Capabilities>,
99 core_state: &'a CorePartyState<Provider>,
100}
101
102impl<'a, Provider: OpenMlsProvider> PreGroupPartyStateBuilder<'a, Provider> {
103 pub fn with_lifetime(mut self, lifetime: impl Into<Option<Lifetime>>) -> Self {
104 self.lifetime = lifetime.into();
105
106 self
107 }
108 pub fn with_key_package_extensions(
109 mut self,
110 extensions: impl Into<Option<Extensions<KeyPackage>>>,
111 ) -> Self {
112 self.key_package_extensions = extensions.into();
113
114 self
115 }
116 pub fn with_leaf_node_extensions(
117 mut self,
118 extensions: impl Into<Option<Extensions<LeafNode>>>,
119 ) -> Self {
120 self.leaf_node_extensions = extensions.into();
121
122 self
123 }
124 pub fn with_leaf_node_capabilities(
125 mut self,
126 capabilities: impl Into<Option<Capabilities>>,
127 ) -> Self {
128 self.leaf_node_capabilities = capabilities.into();
129
130 self
131 }
132
133 pub fn build(self) -> PreGroupPartyState<'a, Provider> {
134 let (credential_with_key, signer) = generate_credential(
135 self.core_state.name.into(),
136 self.ciphersuite.signature_algorithm(),
137 &self.core_state.provider,
138 );
139 let mut builder = KeyPackage::builder()
140 .leaf_node_extensions(self.leaf_node_extensions.unwrap_or_default())
141 .key_package_extensions(self.key_package_extensions.unwrap_or_default())
142 .leaf_node_capabilities(self.leaf_node_capabilities.unwrap_or_default());
143
144 if let Some(lifetime) = self.lifetime {
145 builder = builder.key_package_lifetime(lifetime);
146 }
147
148 let key_package_bundle = builder
149 .build(
150 self.ciphersuite,
151 &self.core_state.provider,
152 &signer,
153 credential_with_key.clone(),
154 )
155 .unwrap();
156
157 PreGroupPartyState {
158 credential_with_key,
159 key_package_bundle,
160 signer,
161 core_state: self.core_state,
162 }
163 }
164}
165
166impl<Provider: OpenMlsProvider> CorePartyState<Provider> {
167 pub fn pre_group_builder<'a>(
169 &'a self,
170 ciphersuite: Ciphersuite,
171 ) -> PreGroupPartyStateBuilder<'a, Provider> {
172 PreGroupPartyStateBuilder {
173 ciphersuite,
174 lifetime: None,
175 key_package_extensions: None,
176 leaf_node_extensions: None,
177 leaf_node_capabilities: None,
178 core_state: self,
179 }
180 }
181
182 pub fn generate_pre_group(&self, ciphersuite: Ciphersuite) -> PreGroupPartyState<'_, Provider> {
184 self.pre_group_builder(ciphersuite).build()
185 }
186}
187
/// A party that is a member of a group, together with its own view of that group.
pub struct MemberState<'a, Provider> {
    /// The pre-group state the member created or joined the group from.
    pub party: PreGroupPartyState<'a, Provider>,
    /// This member's own `MlsGroup` instance.
    pub group: MlsGroup,
}
193
194impl<Provider: OpenMlsProvider> MemberState<'_, Provider> {
195 pub fn get_storage_signature_key_pair(&self) -> Option<SignatureKeyPair> {
197 let ciphersuite = self
198 .party
199 .key_package_bundle
200 .key_package()
201 .ciphersuite()
202 .into();
203
204 SignatureKeyPair::read(
205 self.party.core_state.provider.storage(),
206 self.party.signer.public(),
207 ciphersuite,
208 )
209 }
210 pub fn group_storage_state(&self) -> GroupStorageState {
212 let storage_provider = self.party.core_state.provider.storage();
213 let group_id = self.group.group_id();
214
215 GroupStorageState::from_storage(storage_provider, group_id)
216 }
217 pub fn deliver_and_apply(&mut self, message: MlsMessageIn) -> Result<(), GroupError<Provider>> {
219 let message = message.try_into_protocol_message()?;
220
221 let processed_message = self
223 .group
224 .process_message(&self.party.core_state.provider, message)?;
225
226 match processed_message.into_content() {
227 ProcessedMessageContent::ApplicationMessage(_) => todo!(),
228 ProcessedMessageContent::ProposalMessage(_) => todo!(),
229 ProcessedMessageContent::ExternalJoinProposalMessage(_) => todo!(),
230 ProcessedMessageContent::StagedCommitMessage(m) => self
231 .group
232 .merge_staged_commit(&self.party.core_state.provider, *m)?,
233 };
234
235 Ok(())
236 }
237}
238
239impl<'commit_builder, 'b: 'commit_builder, 'a: 'b, Provider> MemberState<'a, Provider>
240where
241 Provider: openmls_traits::OpenMlsProvider,
242{
243 pub fn build_commit_and_stage(
245 &'b mut self,
246 f: impl FnOnce(
247 CommitBuilder<'commit_builder, Initial>,
248 ) -> CommitBuilder<'commit_builder, Initial>,
249 ) -> Result<CommitMessageBundle, GroupError<Provider>> {
250 let commit_builder = f(self.group.commit_builder());
251
252 let provider = &self.party.core_state.provider;
253
254 let bundle = commit_builder
256 .load_psks(provider.storage())?
257 .build(
258 provider.rand(),
259 provider.crypto(),
260 &self.party.signer,
261 |_| true,
262 )?
263 .stage_commit(provider)?;
264
265 Ok(bundle)
266 }
267}
268
269impl<'a, Provider: OpenMlsProvider> MemberState<'a, Provider> {
270 pub fn create_from_pre_group(
273 party: PreGroupPartyState<'a, Provider>,
274 mls_group_create_config: MlsGroupCreateConfig,
275 group_id: GroupId,
276 ) -> Result<Self, GroupError<Provider>> {
277 let group = MlsGroup::new_with_group_id(
279 &party.core_state.provider,
280 &party.signer,
281 &mls_group_create_config,
282 group_id,
283 party.credential_with_key.clone(),
284 )?;
285
286 Ok(Self { party, group })
287 }
288 pub fn join_from_pre_group(
291 party: PreGroupPartyState<'a, Provider>,
292 mls_group_join_config: MlsGroupJoinConfig,
293 welcome: Welcome,
294 tree: Option<RatchetTreeIn>,
295 ) -> Result<Self, GroupError<Provider>> {
296 let staged_join = StagedWelcome::new_from_welcome(
297 &party.core_state.provider,
298 &mls_group_join_config,
299 welcome,
300 tree,
301 )?;
302
303 let group = staged_join.into_group(&party.core_state.provider)?;
304
305 Ok(Self { party, group })
306 }
307}
308
/// Test-side view of one group: its id and every tracked member's own state.
pub struct GroupState<'a, Provider> {
    // Id passed to the founding member when the group is created.
    group_id: GroupId,
    // Tracked members keyed by their party name.
    members: HashMap<Name, MemberState<'a, Provider>>,
}
314
315impl<'a, Provider: OpenMlsProvider> GroupState<'a, Provider> {
316 pub fn new_from_party(
318 group_id: GroupId,
319 pre_group_state: PreGroupPartyState<'a, Provider>,
320 mls_group_create_config: MlsGroupCreateConfig,
321 ) -> Result<Self, GroupError<Provider>> {
322 let mut members = HashMap::new();
323
324 let name = pre_group_state.core_state.name;
325 let member_state = MemberState::create_from_pre_group(
326 pre_group_state,
327 mls_group_create_config,
328 group_id.clone(),
329 )?;
330
331 members.insert(name, member_state);
332
333 Ok(Self { group_id, members })
334 }
335
336 pub fn members_mut<const N: usize>(
340 &mut self,
341 names: &[Name; N],
342 ) -> [&mut MemberState<'a, Provider>; N] {
343 assert!(N > 0, "must request at least one member");
344 assert!(
345 N <= self.members.len(),
346 "cannot request more members than available"
347 );
348
349 let mut members: [(_, _); N] = self
351 .members
352 .iter_mut()
353 .filter_map(|(member_name, member)| {
354 let index = names.iter().position(|name| name == member_name)?;
358
359 Some((index, member))
360 })
361 .collect::<Vec<_>>()
363 .try_into()
364 .ok()
365 .expect("At least one requested member not found");
366
367 members.sort_by_key(|(pos, _member)| *pos);
369
370 members.map(|(_pos, member)| member)
371 }
372
373 pub fn deliver_and_apply(&mut self, message: MlsMessageIn) -> Result<(), GroupError<Provider>> {
375 self.deliver_and_apply_if(message, |_| true)
376 }
377 pub fn deliver_and_apply_if(
379 &mut self,
380 message: MlsMessageIn,
381 condition: impl Fn(&MemberState<'a, Provider>) -> bool,
382 ) -> Result<(), GroupError<Provider>> {
383 self.members
384 .values_mut()
385 .filter(|member| condition(member))
386 .try_for_each(|member| member.deliver_and_apply(message.clone()))?;
387
388 Ok(())
389 }
390
391 pub fn deliver_and_apply_welcome(
393 &mut self,
394 recipient: PreGroupPartyState<'a, Provider>,
395 mls_group_join_config: MlsGroupJoinConfig,
396 welcome: Welcome,
397 tree: Option<RatchetTreeIn>,
398 ) -> Result<(), GroupError<Provider>> {
399 let name = recipient.core_state.name;
401
402 let member_state =
403 MemberState::join_from_pre_group(recipient, mls_group_join_config, welcome, tree)?;
404
405 self.members.insert(name, member_state);
407
408 Ok(())
409 }
410
411 pub fn untrack_member(&mut self, name: Name) {
414 let _ = self.members.remove(&name);
415 }
416
417 pub fn add_member(
418 &mut self,
419 add_config: AddMemberConfig<'a, Provider>,
420 ) -> Result<(), GroupError<Provider>> {
421 let adder = self
422 .members
423 .get_mut(add_config.adder)
424 .ok_or(TestError::NoSuchMember)?;
425
426 let key_packages: Vec<_> = add_config
427 .addees
428 .iter()
429 .map(|addee| addee.key_package_bundle.key_package.clone())
430 .collect();
431
432 let (commit, welcome, _) = adder.group.add_members(
433 &adder.party.core_state.provider,
434 &adder.party.signer,
435 &key_packages,
436 )?;
437
438 self.deliver_and_apply_if(commit.into(), |member| {
440 member.party.core_state.name != add_config.adder
441 })?;
442
443 let welcome = match welcome.body() {
445 MlsMessageBodyOut::Welcome(welcome) => welcome.clone(),
446 _ => panic!("No welcome returned"),
447 };
448
449 for addee in add_config.addees.into_iter() {
450 self.deliver_and_apply_welcome(
451 addee,
452 add_config.join_config.clone(),
453 welcome.clone(),
454 None,
455 )?;
456 }
457
458 let adder = self
459 .members
460 .get_mut(add_config.adder)
461 .ok_or(TestError::NoSuchMember)?;
462
463 let staged_commit = adder.group.pending_commit().unwrap().clone();
464
465 adder
466 .group
467 .merge_staged_commit(&adder.party.core_state.provider, staged_commit)?;
468
469 Ok(())
470 }
471
472 pub fn group_id(&self) -> GroupId {
474 self.group_id.clone()
475 }
476}
477
478impl MlsGroupCreateConfig {
479 pub fn test_default_from_ciphersuite(ciphersuite: Ciphersuite) -> Self {
481 MlsGroupCreateConfig::builder()
482 .ciphersuite(ciphersuite)
483 .use_ratchet_tree_extension(true)
484 .wire_format_policy(PURE_PLAINTEXT_WIRE_FORMAT_POLICY) .build()
486 }
487}
488
/// Parameters for `GroupState::add_member`.
pub struct AddMemberConfig<'a, Provider> {
    /// Name of the tracked member that creates the add commit.
    pub adder: Name,
    /// Parties to add; their key packages go into the commit and each of them
    /// joins from the resulting welcome.
    pub addees: Vec<PreGroupPartyState<'a, Provider>>,
    /// Join config each addee uses when processing the welcome.
    pub join_config: MlsGroupJoinConfig,
    /// Optional out-of-band ratchet tree for the joiners.
    // NOTE(review): confirm `GroupState::add_member` forwards this value to the
    // welcome processing rather than always passing `None`.
    pub tree: Option<RatchetTreeIn>,
}
495
// Tests for the group test framework itself.
//
// NOTE(review): `Provider` and `ciphersuite` are not defined in these bodies;
// they are presumably injected by `#[openmls_test]` (likely once per
// provider/ciphersuite combination) — confirm against the macro.
#[cfg(test)]
mod test {

    use super::*;
    use openmls_test::openmls_test;

    // `members_mut` must return members in the order the names were requested,
    // for both full and partial selections.
    #[openmls_test]
    fn test_members_mut() {
        let alice_party = CorePartyState::<Provider>::new("alice");
        let bob_party = CorePartyState::<Provider>::new("bob");
        let charlie_party = CorePartyState::<Provider>::new("charlie");
        let dave_party = CorePartyState::<Provider>::new("dave");

        let alice_pre_group = alice_party.generate_pre_group(ciphersuite);
        let bob_pre_group = bob_party.generate_pre_group(ciphersuite);
        let charlie_pre_group = charlie_party.generate_pre_group(ciphersuite);
        let dave_pre_group = dave_party.generate_pre_group(ciphersuite);

        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .use_ratchet_tree_extension(true)
            .build();

        let mls_group_join_config = mls_group_create_config.join_config().clone();

        // Alice founds the group and adds everyone else in a single commit.
        let group_id = GroupId::from_slice(b"test");
        let mut group_state =
            GroupState::new_from_party(group_id, alice_pre_group, mls_group_create_config).unwrap();

        group_state
            .add_member(AddMemberConfig {
                adder: "alice",
                addees: vec![bob_pre_group, charlie_pre_group, dave_pre_group],
                join_config: mls_group_join_config.clone(),
                tree: None,
            })
            .expect("Could not add member");

        // Requested order must be preserved regardless of HashMap iteration order.
        let [alice, bob, charlie, dave] =
            group_state.members_mut(&["alice", "bob", "charlie", "dave"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(charlie.party.core_state.name, "charlie");
        assert_eq!(dave.party.core_state.name, "dave");

        let [dave, charlie, bob, alice] =
            group_state.members_mut(&["dave", "charlie", "bob", "alice"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(charlie.party.core_state.name, "charlie");
        assert_eq!(dave.party.core_state.name, "dave");

        let [dave, bob, charlie, alice] =
            group_state.members_mut(&["dave", "bob", "charlie", "alice"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(charlie.party.core_state.name, "charlie");
        assert_eq!(dave.party.core_state.name, "dave");

        // Partial selections.
        let [dave, bob] = group_state.members_mut(&["dave", "bob"]);
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(dave.party.core_state.name, "dave");

        let [alice, charlie] = group_state.members_mut(&["alice", "charlie"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(charlie.party.core_state.name, "charlie");
    }
    // Uses `GroupState::add_member` twice (Alice adds Bob and Charlie, then Bob
    // adds Dave) and checks membership after each step. `assert_membership` is
    // defined in the `assertions` module, which is not shown here.
    #[openmls_test]
    pub fn simpler_example() {
        let alice_party = CorePartyState::<Provider>::new("alice");
        let bob_party = CorePartyState::<Provider>::new("bob");
        let charlie_party = CorePartyState::<Provider>::new("charlie");
        let dave_party = CorePartyState::<Provider>::new("dave");

        let alice_pre_group = alice_party.generate_pre_group(ciphersuite);
        let bob_pre_group = bob_party.generate_pre_group(ciphersuite);
        let charlie_pre_group = charlie_party.generate_pre_group(ciphersuite);
        let dave_pre_group = dave_party.generate_pre_group(ciphersuite);

        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .use_ratchet_tree_extension(true)
            .build();

        let mls_group_join_config = mls_group_create_config.join_config().clone();

        let group_id = GroupId::from_slice(b"test");
        let mut group_state =
            GroupState::new_from_party(group_id, alice_pre_group, mls_group_create_config).unwrap();

        group_state
            .add_member(AddMemberConfig {
                adder: "alice",
                addees: vec![bob_pre_group, charlie_pre_group],
                join_config: mls_group_join_config.clone(),
                tree: None,
            })
            .expect("Could not add member");

        group_state.assert_membership();

        // A member who joined via welcome can itself add further members.
        group_state
            .add_member(AddMemberConfig {
                adder: "bob",
                addees: vec![dave_pre_group],
                join_config: mls_group_join_config,
                tree: None,
            })
            .expect("Could not add member");

        group_state.assert_membership();
    }

    // Walks through an add by hand: Alice builds and stages a commit with an add
    // proposal for Bob, Bob joins from the welcome, then Alice merges her pending
    // commit.
    #[openmls_test]
    pub fn simple_example() {
        let alice_party = CorePartyState::<Provider>::new("alice");
        let bob_party = CorePartyState::<Provider>::new("bob");

        let alice_pre_group = alice_party.generate_pre_group(ciphersuite);
        let bob_pre_group = bob_party.generate_pre_group(ciphersuite);

        // Cloned up front because `bob_pre_group` is moved into the welcome step.
        let bob_key_package = bob_pre_group.key_package_bundle.key_package.clone();

        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .use_ratchet_tree_extension(true)
            .build();

        let mls_group_join_config = mls_group_create_config.join_config().clone();

        let group_id = GroupId::from_slice(b"test");
        let mut group_state =
            GroupState::new_from_party(group_id, alice_pre_group, mls_group_create_config).unwrap();

        let [alice] = group_state.members_mut(&["alice"]);

        let bundle = alice
            .build_commit_and_stage(move |builder| {
                let add_proposal = Proposal::add(AddProposal {
                    key_package: bob_key_package,
                });

                builder
                    .consume_proposal_store(false)
                    .add_proposal(add_proposal)
            })
            .expect("Could not stage commit");

        let welcome = bundle.welcome().unwrap().clone();
        group_state
            .deliver_and_apply_welcome(bob_pre_group, mls_group_join_config, welcome, None)
            .expect("Error delivering and applying welcome");

        // Re-borrow Alice: the previous `members_mut` borrow ended before the
        // welcome was delivered.
        let [alice] = group_state.members_mut(&["alice"]);

        // `build_commit_and_stage` only stages; Alice still has to merge.
        let staged_commit = alice.group.pending_commit().unwrap().clone();

        alice
            .group
            .merge_staged_commit(&alice.party.core_state.provider, staged_commit)
            .expect("Error merging staged commit");

        group_state.assert_membership();
    }
}