1use openmls_basic_credential::SignatureKeyPair;
2use openmls_traits::{signatures::Signer, types::SignatureScheme};
3pub use openmls_traits::{
4 storage::StorageProvider as StorageProviderTrait,
5 types::{Ciphersuite, HpkeKeyPair},
6 OpenMlsProvider,
7};
8
9pub use crate::utils::*;
10use crate::{
11 credentials::CredentialWithKey,
12 key_packages::KeyPackageBuilder,
13 prelude::{commit_builder::*, *},
14};
15
16use crate::test_utils::storage_state::GroupStorageState;
17
18mod assertions;
19
20mod errors;
21pub use errors::GroupError;
22use errors::*;
23
24use std::collections::HashMap;
25
/// Human-readable identifier of a test party (e.g. `"alice"`); [`GroupState`] also uses it as
/// the key under which a member is tracked.
type Name = &'static str;
28
29pub fn generate_credential(
32 identity: Vec<u8>,
33 signature_algorithm: SignatureScheme,
34 provider: &impl crate::storage::OpenMlsProvider,
35) -> (CredentialWithKey, SignatureKeyPair) {
36 let credential = BasicCredential::new(identity);
37 let signature_keys = SignatureKeyPair::new(signature_algorithm).unwrap();
38 signature_keys.store(provider.storage()).unwrap();
39
40 (
41 CredentialWithKey {
42 credential: credential.into(),
43 signature_key: signature_keys.to_public_vec().into(),
44 },
45 signature_keys,
46 )
47}
48
49pub(crate) fn generate_key_package(
51 ciphersuite: Ciphersuite,
52 credential_with_key: CredentialWithKey,
53 extensions: Extensions,
54 provider: &impl crate::storage::OpenMlsProvider,
55 lifetime: impl Into<Option<Lifetime>>,
56 signer: &impl Signer,
57) -> KeyPackageBundle {
58 let mut builder = KeyPackage::builder().key_package_extensions(extensions);
59
60 if let Some(lifetime) = lifetime.into() {
61 builder = builder.key_package_lifetime(lifetime);
62 }
63
64 builder
65 .build(ciphersuite, provider, signer, credential_with_key)
66 .unwrap()
67}
68
69pub struct CorePartyState<Provider> {
71 pub name: Name,
72 pub provider: Provider,
73}
74
75impl<Provider: Default> CorePartyState<Provider> {
76 pub fn new(name: Name) -> Self {
77 Self {
78 name,
79 provider: Provider::default(),
80 }
81 }
82}
83
/// A party that has generated its credential and key package but has not yet created or
/// joined a group.
pub struct PreGroupPartyState<'a, Provider> {
    /// The party's credential together with its public signature key.
    pub credential_with_key: CredentialWithKey,
    /// The key package bundle generated for this party.
    pub key_package_bundle: KeyPackageBundle,
    /// Signature key pair whose public half is `credential_with_key.signature_key`.
    pub signer: SignatureKeyPair,
    /// Borrowed core state (name and provider) of the party.
    pub core_state: &'a CorePartyState<Provider>,
}
92
93impl<Provider: OpenMlsProvider> CorePartyState<Provider> {
95 pub fn generate_pre_group(&self, ciphersuite: Ciphersuite) -> PreGroupPartyState<'_, Provider> {
97 self.generate_pre_group_lifetime(ciphersuite, None)
98 }
99
100 pub fn generate_pre_group_lifetime(
102 &self,
103 ciphersuite: Ciphersuite,
104 lifetime: impl Into<Option<Lifetime>>,
105 ) -> PreGroupPartyState<'_, Provider> {
106 let (credential_with_key, signer) = generate_credential(
107 self.name.into(),
108 ciphersuite.signature_algorithm(),
109 &self.provider,
110 );
111
112 let key_package_bundle = generate_key_package(
113 ciphersuite,
114 credential_with_key.clone(),
115 Extensions::default(), &self.provider,
117 lifetime,
118 &signer,
119 );
120
121 PreGroupPartyState {
122 credential_with_key,
123 key_package_bundle,
124 signer,
125 core_state: self,
126 }
127 }
128}
129
/// A party that belongs to a group: its pre-group state plus its own local group instance.
pub struct MemberState<'a, Provider> {
    /// Credential, key package, signer and core state this member created or joined with.
    pub party: PreGroupPartyState<'a, Provider>,
    /// This member's local instance of the group.
    pub group: MlsGroup,
}
135
136impl<Provider: OpenMlsProvider> MemberState<'_, Provider> {
137 pub fn get_storage_signature_key_pair(&self) -> Option<SignatureKeyPair> {
139 let ciphersuite = self
140 .party
141 .key_package_bundle
142 .key_package()
143 .ciphersuite()
144 .into();
145
146 SignatureKeyPair::read(
147 self.party.core_state.provider.storage(),
148 self.party.signer.public(),
149 ciphersuite,
150 )
151 }
152 pub fn group_storage_state(&self) -> GroupStorageState {
154 let storage_provider = self.party.core_state.provider.storage();
155 let group_id = self.group.group_id();
156
157 GroupStorageState::from_storage(storage_provider, group_id)
158 }
159 pub fn deliver_and_apply(&mut self, message: MlsMessageIn) -> Result<(), GroupError<Provider>> {
161 let message = message.try_into_protocol_message()?;
162
163 let processed_message = self
165 .group
166 .process_message(&self.party.core_state.provider, message)?;
167
168 match processed_message.into_content() {
169 ProcessedMessageContent::ApplicationMessage(_) => todo!(),
170 ProcessedMessageContent::ProposalMessage(_) => todo!(),
171 ProcessedMessageContent::ExternalJoinProposalMessage(_) => todo!(),
172 ProcessedMessageContent::StagedCommitMessage(m) => self
173 .group
174 .merge_staged_commit(&self.party.core_state.provider, *m)?,
175 };
176
177 Ok(())
178 }
179}
180
// Commit helpers for a `MemberState`.
//
// The lifetime bounds tie the commit builder (`'commit_builder`) to the mutable borrow of the
// member (`'b`), which in turn may not outlive the party's borrowed core state (`'a`).
impl<'commit_builder, 'b: 'commit_builder, 'a: 'b, Provider> MemberState<'a, Provider>
where
    Provider: openmls_traits::OpenMlsProvider,
{
    /// Builds a commit from a caller-customised [`CommitBuilder`] and stages it.
    ///
    /// `f` receives the group's fresh commit builder in the `Initial` stage and returns it
    /// after adding whatever proposals or options the test needs. The builder is then taken
    /// through loading PSKs from this member's storage, building (signed with this member's
    /// signer, with a predicate that always returns `true`, presumably accepting every queued
    /// proposal), and staging.
    ///
    /// The commit is only staged here. Callers merge the pending commit themselves afterwards.
    ///
    /// # Errors
    ///
    /// Returns a [`GroupError`] if loading PSKs, building or staging fails.
    pub fn build_commit_and_stage(
        &'b mut self,
        f: impl FnOnce(
            CommitBuilder<'commit_builder, Initial>,
        ) -> CommitBuilder<'commit_builder, Initial>,
    ) -> Result<CommitMessageBundle, GroupError<Provider>> {
        // Let the caller configure the builder before any provider work happens.
        let commit_builder = f(self.group.commit_builder());

        // Borrows a field disjoint from `self.group`, which the builder still borrows.
        let provider = &self.party.core_state.provider;

        let bundle = commit_builder
            .load_psks(provider.storage())?
            .build(
                provider.rand(),
                provider.crypto(),
                &self.party.signer,
                |_| true,
            )?
            .stage_commit(provider)?;

        Ok(bundle)
    }
}
210
211impl<'a, Provider: OpenMlsProvider> MemberState<'a, Provider> {
212 pub fn create_from_pre_group(
215 party: PreGroupPartyState<'a, Provider>,
216 mls_group_create_config: MlsGroupCreateConfig,
217 group_id: GroupId,
218 ) -> Result<Self, GroupError<Provider>> {
219 let group = MlsGroup::new_with_group_id(
221 &party.core_state.provider,
222 &party.signer,
223 &mls_group_create_config,
224 group_id,
225 party.credential_with_key.clone(),
226 )?;
227
228 Ok(Self { party, group })
229 }
230 pub fn join_from_pre_group(
233 party: PreGroupPartyState<'a, Provider>,
234 mls_group_join_config: MlsGroupJoinConfig,
235 welcome: Welcome,
236 tree: Option<RatchetTreeIn>,
237 ) -> Result<Self, GroupError<Provider>> {
238 let staged_join = StagedWelcome::new_from_welcome(
239 &party.core_state.provider,
240 &mls_group_join_config,
241 welcome,
242 tree,
243 )?;
244
245 let group = staged_join.into_group(&party.core_state.provider)?;
246
247 Ok(Self { party, group })
248 }
249}
250
/// Test-side view of a group: its id and the members currently being tracked.
pub struct GroupState<'a, Provider> {
    /// Identifier of the group.
    group_id: GroupId,
    /// Tracked members, keyed by the name of each member's party.
    members: HashMap<Name, MemberState<'a, Provider>>,
}
256
257impl<'a, Provider: OpenMlsProvider> GroupState<'a, Provider> {
258 pub fn new_from_party(
260 group_id: GroupId,
261 pre_group_state: PreGroupPartyState<'a, Provider>,
262 mls_group_create_config: MlsGroupCreateConfig,
263 ) -> Result<Self, GroupError<Provider>> {
264 let mut members = HashMap::new();
265
266 let name = pre_group_state.core_state.name;
267 let member_state = MemberState::create_from_pre_group(
268 pre_group_state,
269 mls_group_create_config,
270 group_id.clone(),
271 )?;
272
273 members.insert(name, member_state);
274
275 Ok(Self { group_id, members })
276 }
277
278 pub fn members_mut<const N: usize>(
282 &mut self,
283 names: &[Name; N],
284 ) -> [&mut MemberState<'a, Provider>; N] {
285 assert!(N > 0, "must request at least one member");
286 assert!(
287 N <= self.members.len(),
288 "cannot request more members than available"
289 );
290
291 let mut members: [(_, _); N] = self
293 .members
294 .iter_mut()
295 .filter_map(|(member_name, member)| {
296 let index = names.iter().position(|name| name == member_name)?;
300
301 Some((index, member))
302 })
303 .collect::<Vec<_>>()
305 .try_into()
306 .ok()
307 .expect("At least one requested member not found");
308
309 members.sort_by_key(|(pos, _member)| *pos);
311
312 members.map(|(_pos, member)| member)
313 }
314
315 pub fn deliver_and_apply(&mut self, message: MlsMessageIn) -> Result<(), GroupError<Provider>> {
317 self.deliver_and_apply_if(message, |_| true)
318 }
319 pub fn deliver_and_apply_if(
321 &mut self,
322 message: MlsMessageIn,
323 condition: impl Fn(&MemberState<'a, Provider>) -> bool,
324 ) -> Result<(), GroupError<Provider>> {
325 self.members
326 .values_mut()
327 .filter(|member| condition(member))
328 .try_for_each(|member| member.deliver_and_apply(message.clone()))?;
329
330 Ok(())
331 }
332
333 pub fn deliver_and_apply_welcome(
335 &mut self,
336 recipient: PreGroupPartyState<'a, Provider>,
337 mls_group_join_config: MlsGroupJoinConfig,
338 welcome: Welcome,
339 tree: Option<RatchetTreeIn>,
340 ) -> Result<(), GroupError<Provider>> {
341 let name = recipient.core_state.name;
343
344 let member_state =
345 MemberState::join_from_pre_group(recipient, mls_group_join_config, welcome, tree)?;
346
347 self.members.insert(name, member_state);
349
350 Ok(())
351 }
352
353 pub fn untrack_member(&mut self, name: Name) {
356 let _ = self.members.remove(&name);
357 }
358
359 pub fn add_member(
360 &mut self,
361 add_config: AddMemberConfig<'a, Provider>,
362 ) -> Result<(), GroupError<Provider>> {
363 let adder = self
364 .members
365 .get_mut(add_config.adder)
366 .ok_or(TestError::NoSuchMember)?;
367
368 let key_packages: Vec<_> = add_config
369 .addees
370 .iter()
371 .map(|addee| addee.key_package_bundle.key_package.clone())
372 .collect();
373
374 let (commit, welcome, _) = adder.group.add_members(
375 &adder.party.core_state.provider,
376 &adder.party.signer,
377 &key_packages,
378 )?;
379
380 self.deliver_and_apply_if(commit.into(), |member| {
382 member.party.core_state.name != add_config.adder
383 })?;
384
385 let welcome = match welcome.body() {
387 MlsMessageBodyOut::Welcome(welcome) => welcome.clone(),
388 _ => panic!("No welcome returned"),
389 };
390
391 for addee in add_config.addees.into_iter() {
392 self.deliver_and_apply_welcome(
393 addee,
394 add_config.join_config.clone(),
395 welcome.clone(),
396 None,
397 )?;
398 }
399
400 let adder = self
401 .members
402 .get_mut(add_config.adder)
403 .ok_or(TestError::NoSuchMember)?;
404
405 let staged_commit = adder.group.pending_commit().unwrap().clone();
406
407 adder
408 .group
409 .merge_staged_commit(&adder.party.core_state.provider, staged_commit)?;
410
411 Ok(())
412 }
413
414 pub fn group_id(&self) -> GroupId {
416 self.group_id.clone()
417 }
418}
419
420impl MlsGroupCreateConfig {
421 pub fn test_default_from_ciphersuite(ciphersuite: Ciphersuite) -> Self {
423 MlsGroupCreateConfig::builder()
424 .ciphersuite(ciphersuite)
425 .use_ratchet_tree_extension(true)
426 .wire_format_policy(PURE_PLAINTEXT_WIRE_FORMAT_POLICY) .build()
428 }
429}
430
/// Parameters for [`GroupState::add_member`].
pub struct AddMemberConfig<'a, Provider> {
    /// Name of the already-tracked member that creates the add commit.
    pub adder: Name,
    /// Parties to add; each one joins from the resulting Welcome.
    pub addees: Vec<PreGroupPartyState<'a, Provider>>,
    /// Join configuration each addee uses when processing the Welcome.
    pub join_config: MlsGroupJoinConfig,
    /// Optional ratchet tree intended for the addees' Welcome processing.
    /// NOTE(review): confirm that `GroupState::add_member` actually forwards it.
    pub tree: Option<RatchetTreeIn>,
}
437
// Tests for the group test framework itself. The `#[openmls_test]` attribute appears to supply
// the `Provider` type and the `ciphersuite` value used in each test body.
#[cfg(test)]
mod test {

    use super::*;
    use openmls_test::openmls_test;

    // Checks that `members_mut` returns the requested members in the order they were named,
    // for full and partial selections.
    #[openmls_test]
    fn test_members_mut() {
        let alice_party = CorePartyState::<Provider>::new("alice");
        let bob_party = CorePartyState::<Provider>::new("bob");
        let charlie_party = CorePartyState::<Provider>::new("charlie");
        let dave_party = CorePartyState::<Provider>::new("dave");

        let alice_pre_group = alice_party.generate_pre_group(ciphersuite);
        let bob_pre_group = bob_party.generate_pre_group(ciphersuite);
        let charlie_pre_group = charlie_party.generate_pre_group(ciphersuite);
        let dave_pre_group = dave_party.generate_pre_group(ciphersuite);

        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .use_ratchet_tree_extension(true)
            .build();

        let mls_group_join_config = mls_group_create_config.join_config().clone();

        let group_id = GroupId::from_slice(b"test");
        let mut group_state =
            GroupState::new_from_party(group_id, alice_pre_group, mls_group_create_config).unwrap();

        // Alice adds everyone else in a single commit.
        group_state
            .add_member(AddMemberConfig {
                adder: "alice",
                addees: vec![bob_pre_group, charlie_pre_group, dave_pre_group],
                join_config: mls_group_join_config.clone(),
                tree: None,
            })
            .expect("Could not add member");

        let [alice, bob, charlie, dave] =
            group_state.members_mut(&["alice", "bob", "charlie", "dave"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(charlie.party.core_state.name, "charlie");
        assert_eq!(dave.party.core_state.name, "dave");

        let [dave, charlie, bob, alice] =
            group_state.members_mut(&["dave", "charlie", "bob", "alice"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(charlie.party.core_state.name, "charlie");
        assert_eq!(dave.party.core_state.name, "dave");

        let [dave, bob, charlie, alice] =
            group_state.members_mut(&["dave", "bob", "charlie", "alice"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(charlie.party.core_state.name, "charlie");
        assert_eq!(dave.party.core_state.name, "dave");

        // Partial selections.
        let [dave, bob] = group_state.members_mut(&["dave", "bob"]);
        assert_eq!(bob.party.core_state.name, "bob");
        assert_eq!(dave.party.core_state.name, "dave");

        let [alice, charlie] = group_state.members_mut(&["alice", "charlie"]);
        assert_eq!(alice.party.core_state.name, "alice");
        assert_eq!(charlie.party.core_state.name, "charlie");
    }

    // Alice adds Bob and Charlie, then Bob (not the founder) adds Dave. Membership is checked
    // after each step via `assert_membership` from the `assertions` module.
    #[openmls_test]
    pub fn simpler_example() {
        let alice_party = CorePartyState::<Provider>::new("alice");
        let bob_party = CorePartyState::<Provider>::new("bob");
        let charlie_party = CorePartyState::<Provider>::new("charlie");
        let dave_party = CorePartyState::<Provider>::new("dave");

        let alice_pre_group = alice_party.generate_pre_group(ciphersuite);
        let bob_pre_group = bob_party.generate_pre_group(ciphersuite);
        let charlie_pre_group = charlie_party.generate_pre_group(ciphersuite);
        let dave_pre_group = dave_party.generate_pre_group(ciphersuite);

        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .use_ratchet_tree_extension(true)
            .build();

        let mls_group_join_config = mls_group_create_config.join_config().clone();

        let group_id = GroupId::from_slice(b"test");
        let mut group_state =
            GroupState::new_from_party(group_id, alice_pre_group, mls_group_create_config).unwrap();

        group_state
            .add_member(AddMemberConfig {
                adder: "alice",
                addees: vec![bob_pre_group, charlie_pre_group],
                join_config: mls_group_join_config.clone(),
                tree: None,
            })
            .expect("Could not add member");

        group_state.assert_membership();

        group_state
            .add_member(AddMemberConfig {
                adder: "bob",
                addees: vec![dave_pre_group],
                join_config: mls_group_join_config,
                tree: None,
            })
            .expect("Could not add member");

        group_state.assert_membership();
    }

    // Drives an add by hand through `build_commit_and_stage`: Alice stages a commit with an
    // add proposal for Bob, Bob joins from the Welcome, then Alice merges her pending commit.
    #[openmls_test]
    pub fn simple_example() {
        let alice_party = CorePartyState::<Provider>::new("alice");
        let bob_party = CorePartyState::<Provider>::new("bob");

        let alice_pre_group = alice_party.generate_pre_group(ciphersuite);
        let bob_pre_group = bob_party.generate_pre_group(ciphersuite);

        let bob_key_package = bob_pre_group.key_package_bundle.key_package.clone();

        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .use_ratchet_tree_extension(true)
            .build();

        let mls_group_join_config = mls_group_create_config.join_config().clone();

        let group_id = GroupId::from_slice(b"test");
        let mut group_state =
            GroupState::new_from_party(group_id, alice_pre_group, mls_group_create_config).unwrap();

        let [alice] = group_state.members_mut(&["alice"]);

        let bundle = alice
            .build_commit_and_stage(move |builder| {
                let add_proposal = Proposal::add(AddProposal {
                    key_package: bob_key_package,
                });

                // Only the explicitly added proposal goes into this commit.
                builder
                    .consume_proposal_store(false)
                    .add_proposal(add_proposal)
            })
            .expect("Could not stage commit");

        let welcome = bundle.welcome().unwrap().clone();
        group_state
            .deliver_and_apply_welcome(bob_pre_group, mls_group_join_config, welcome, None)
            .expect("Error delivering and applying welcome");

        // `build_commit_and_stage` only stages; Alice still has to merge her own commit.
        let [alice] = group_state.members_mut(&["alice"]);

        let staged_commit = alice.group.pending_commit().unwrap().clone();

        alice
            .group
            .merge_staged_commit(&alice.party.core_state.provider, staged_commit)
            .expect("Error merging staged commit");

        group_state.assert_membership();
    }
}