openmls/test_utils/test_framework/client.rs

1//! This module provides the `Client` datastructure, which contains the state
2//! associated with a client in the context of MLS, along with functions to have
3//! that client perform certain MLS operations.
4use std::{collections::HashMap, sync::RwLock};
5
6use commit_builder::CommitMessageBundle;
7use openmls_basic_credential::SignatureKeyPair;
8use openmls_traits::{
9    types::{Ciphersuite, HpkeKeyPair, SignatureScheme},
10    OpenMlsProvider as _,
11};
12use tls_codec::{Deserialize, Serialize};
13
14use super::OpenMlsRustCrypto;
15
16use crate::{
17    binary_tree::array_representation::LeafNodeIndex,
18    ciphersuite::hash_ref::KeyPackageRef,
19    credentials::*,
20    extensions::*,
21    framing::*,
22    group::*,
23    key_packages::*,
24    messages::{group_info::GroupInfo, *},
25    storage::OpenMlsProvider,
26    treesync::{
27        node::{leaf_node::Capabilities, Node},
28        LeafNode, LeafNodeParameters, RatchetTree, RatchetTreeIn,
29    },
30    versions::ProtocolVersion,
31};
32
33use super::{errors::ClientError, ActionType};
34
35#[derive(Debug)]
36/// The client contains the necessary state for a client in the context of MLS.
37/// It contains the group states, as well as a reference to a `KeyStore`
38/// containing its `CredentialWithKey`s. The `key_package_bundles` field
39/// contains generated `KeyPackageBundle`s that are waiting to be used for new
40/// groups.
41pub struct Client<Provider: OpenMlsProvider> {
42    /// Name of the client.
43    pub identity: Vec<u8>,
44    /// Ciphersuites supported by the client.
45    pub credentials: HashMap<Ciphersuite, CredentialWithKey>,
46    pub provider: Provider,
47    pub groups: RwLock<HashMap<GroupId, MlsGroup>>,
48}
49
50impl<Provider: OpenMlsProvider> Client<Provider> {
51    /// Generate a fresh key package and return it.
52    /// The first ciphersuite determines the
53    /// credential used to generate the `KeyPackage`.
54    pub fn get_fresh_key_package(
55        &self,
56        ciphersuite: Ciphersuite,
57    ) -> Result<KeyPackage, ClientError<Provider::StorageError>> {
58        let credential_with_key = self
59            .credentials
60            .get(&ciphersuite)
61            .ok_or(ClientError::CiphersuiteNotSupported)?;
62        let keys = SignatureKeyPair::read(
63            self.provider.storage(),
64            credential_with_key.signature_key.as_slice(),
65            ciphersuite.signature_algorithm(),
66        )
67        .unwrap();
68
69        let key_package = KeyPackage::builder()
70            .build(
71                ciphersuite,
72                &self.provider,
73                &keys,
74                credential_with_key.clone(),
75            )
76            .unwrap();
77
78        Ok(key_package.key_package)
79    }
80
81    /// Create a group with the given [MlsGroupCreateConfig] and [Ciphersuite], and return the created [GroupId].
82    ///
83    /// Returns an error if the client doesn't support the `ciphersuite`.
84    pub fn create_group(
85        &self,
86        mls_group_create_config: MlsGroupCreateConfig,
87        ciphersuite: Ciphersuite,
88    ) -> Result<GroupId, ClientError<Provider::StorageError>> {
89        let credential_with_key = self
90            .credentials
91            .get(&ciphersuite)
92            .ok_or(ClientError::CiphersuiteNotSupported);
93        let credential_with_key = credential_with_key?;
94        let signer = SignatureKeyPair::read(
95            self.provider.storage(),
96            credential_with_key.signature_key.as_slice(),
97            ciphersuite.signature_algorithm(),
98        )
99        .unwrap();
100
101        let group_state = MlsGroup::new(
102            &self.provider,
103            &signer,
104            &mls_group_create_config,
105            credential_with_key.clone(),
106        )?;
107        let group_id = group_state.group_id().clone();
108        self.groups
109            .write()
110            .expect("An unexpected error occurred.")
111            .insert(group_state.group_id().clone(), group_state);
112        Ok(group_id)
113    }
114
115    /// Join a group based on the given `welcome` and `ratchet_tree`. The group
116    /// is created with the given `MlsGroupCreateConfig`. Throws an error if no
117    /// `KeyPackage` exists matching the `Welcome`, if the client doesn't
118    /// support the ciphersuite, or if an error occurs processing the `Welcome`.
119    pub fn join_group(
120        &self,
121        mls_group_config: MlsGroupJoinConfig,
122        welcome: Welcome,
123        ratchet_tree: Option<RatchetTreeIn>,
124    ) -> Result<(), ClientError<Provider::StorageError>> {
125        let staged_join = StagedWelcome::new_from_welcome(
126            &self.provider,
127            &mls_group_config,
128            welcome,
129            ratchet_tree,
130        )?;
131        let new_group = staged_join.into_group(&self.provider)?;
132        self.groups
133            .write()
134            .expect("An unexpected error occurred.")
135            .insert(new_group.group_id().to_owned(), new_group);
136        Ok(())
137    }
138
139    /// Have the client process the given messages. Returns an error if an error
140    /// occurs during message processing or if no group exists for one of the
141    /// messages.
142    pub fn receive_messages_for_group<AS: Fn(&Credential) -> bool>(
143        &self,
144        message: &ProtocolMessage,
145        sender_id: &[u8],
146        authentication_service: &AS,
147    ) -> Result<(), ClientError<Provider::StorageError>> {
148        let mut group_states = self.groups.write().expect("An unexpected error occurred.");
149        let group_id = message.group_id();
150        let group_state = group_states
151            .get_mut(group_id)
152            .ok_or(ClientError::NoMatchingGroup)?;
153        if sender_id == self.identity && message.content_type() == ContentType::Commit {
154            group_state.merge_pending_commit(&self.provider)?
155        } else {
156            if message.content_type() == ContentType::Commit {
157                // Clear any potential pending commits.
158                group_state.clear_pending_commit(self.provider.storage())?;
159            }
160            // Process the message.
161            let processed_message = group_state
162                .process_message(&self.provider, message.clone())
163                .map_err(ClientError::ProcessMessageError)?;
164
165            match processed_message.into_content() {
166                ProcessedMessageContent::ApplicationMessage(_) => {}
167                ProcessedMessageContent::ProposalMessage(staged_proposal) => {
168                    group_state
169                        .store_pending_proposal(self.provider.storage(), *staged_proposal)?;
170                }
171                ProcessedMessageContent::ExternalJoinProposalMessage(staged_proposal) => {
172                    group_state
173                        .store_pending_proposal(self.provider.storage(), *staged_proposal)?;
174                }
175                ProcessedMessageContent::StagedCommitMessage(staged_commit) => {
176                    for credential in staged_commit.credentials_to_verify() {
177                        if !authentication_service(credential) {
178                            println!(
179                                "authentication service callback denied credential {credential:?}"
180                            );
181                            return Err(ClientError::NoMatchingCredential);
182                        }
183                    }
184                    group_state.merge_staged_commit(&self.provider, *staged_commit)?;
185                }
186            }
187        }
188
189        Ok(())
190    }
191
192    /// Get the credential and the index of each group member of the group with
193    /// the given id. Returns an error if no group exists with the given group
194    /// id.
195    pub fn get_members_of_group(
196        &self,
197        group_id: &GroupId,
198    ) -> Result<Vec<Member>, ClientError<Provider::StorageError>> {
199        let groups = self.groups.read().expect("An unexpected error occurred.");
200        let group = groups.get(group_id).ok_or(ClientError::NoMatchingGroup)?;
201        let members = group.members().collect();
202        Ok(members)
203    }
204
205    /// Have the client either propose or commit (depending on the
206    /// `action_type`) a self update in the group with the given group id.
207    /// Optionally, a `HpkeKeyPair` can be provided, which the client will
208    /// update their leaf with. Returns an error if no group with the given
209    /// group id can be found or if an error occurs while creating the update.
210    #[allow(clippy::type_complexity)]
211    pub fn self_update(
212        &self,
213        action_type: ActionType,
214        group_id: &GroupId,
215        leaf_node_parameters: LeafNodeParameters,
216    ) -> Result<
217        (MlsMessageOut, Option<Welcome>, Option<GroupInfo>),
218        ClientError<Provider::StorageError>,
219    > {
220        let mut groups = self.groups.write().expect("An unexpected error occurred.");
221        let group = groups
222            .get_mut(group_id)
223            .ok_or(ClientError::NoMatchingGroup)?;
224        // Get the signature public key to read the signer from the
225        // key store.
226        let signature_pk = group.own_leaf().unwrap().signature_key();
227        let signer = SignatureKeyPair::read(
228            self.provider.storage(),
229            signature_pk.as_slice(),
230            group.ciphersuite().signature_algorithm(),
231        )
232        .unwrap();
233        let (msg, welcome_option, group_info) = match action_type {
234            ActionType::Commit => {
235                let bundle =
236                    group.self_update(&self.provider, &signer, LeafNodeParameters::default())?;
237
238                let welcome = bundle.to_welcome_msg();
239                let (msg, _, group_info) = bundle.into_contents();
240
241                (msg, welcome, group_info)
242            }
243            ActionType::Proposal => {
244                let (msg, _) =
245                    group.propose_self_update(&self.provider, &signer, leaf_node_parameters)?;
246
247                (msg, None, None)
248            }
249        };
250        Ok((
251            msg,
252            welcome_option.map(|w| w.into_welcome().expect("Unexpected message type.")),
253            group_info,
254        ))
255    }
256
257    /// Have the client either propose or commit (depending on the
258    /// `action_type`) adding the clients with the given `KeyPackage`s to the
259    /// group with the given group id. Returns an error if no group with the
260    /// given group id can be found or if an error occurs while performing the
261    /// add operation.
262    #[allow(clippy::type_complexity)]
263    pub fn add_members(
264        &self,
265        action_type: ActionType,
266        group_id: &GroupId,
267        key_packages: &[KeyPackage],
268    ) -> Result<
269        (Vec<MlsMessageOut>, Option<Welcome>, Option<GroupInfo>),
270        ClientError<Provider::StorageError>,
271    > {
272        let mut groups = self.groups.write().expect("An unexpected error occurred.");
273        let group = groups
274            .get_mut(group_id)
275            .ok_or(ClientError::NoMatchingGroup)?;
276        // Get the signature public key to read the signer from the
277        // key store.
278        let signature_pk = group.own_leaf().unwrap().signature_key();
279        let signer = SignatureKeyPair::read(
280            self.provider.storage(),
281            signature_pk.as_slice(),
282            group.ciphersuite().signature_algorithm(),
283        )
284        .unwrap();
285        let action_results = match action_type {
286            ActionType::Commit => {
287                let (messages, welcome_message, group_info) =
288                    group.add_members(&self.provider, &signer, key_packages)?;
289                (
290                    vec![messages],
291                    Some(
292                        welcome_message
293                            .into_welcome()
294                            .expect("Unexpected message type."),
295                    ),
296                    group_info,
297                )
298            }
299            ActionType::Proposal => {
300                let mut messages = Vec::new();
301                for key_package in key_packages {
302                    let message = group
303                        .propose_add_member(&self.provider, &signer, key_package)
304                        .map(|(out, _)| out)?;
305                    messages.push(message);
306                }
307                (messages, None, None)
308            }
309        };
310        Ok(action_results)
311    }
312
313    /// Have the client either propose or commit (depending on the
314    /// `action_type`) removing the clients with the given indices from the
315    /// group with the given group id. Returns an error if no group with the
316    /// given group id can be found or if an error occurs while performing the
317    /// remove operation.
318    #[allow(clippy::type_complexity)]
319    pub fn remove_members(
320        &self,
321        action_type: ActionType,
322        group_id: &GroupId,
323        targets: &[LeafNodeIndex],
324    ) -> Result<
325        (Vec<MlsMessageOut>, Option<Welcome>, Option<GroupInfo>),
326        ClientError<Provider::StorageError>,
327    > {
328        let mut groups = self.groups.write().expect("An unexpected error occurred.");
329        let group = groups
330            .get_mut(group_id)
331            .ok_or(ClientError::NoMatchingGroup)?;
332        // Get the signature public key to read the signer from the
333        // key store.
334        let signature_pk = group.own_leaf().unwrap().signature_key();
335        let signer = SignatureKeyPair::read(
336            self.provider.storage(),
337            signature_pk.as_slice(),
338            group.ciphersuite().signature_algorithm(),
339        )
340        .unwrap();
341        let action_results = match action_type {
342            ActionType::Commit => {
343                let (message, welcome_option, group_info) =
344                    group.remove_members(&self.provider, &signer, targets)?;
345                (
346                    vec![message],
347                    welcome_option.map(|w| w.into_welcome().expect("Unexpected message type.")),
348                    group_info,
349                )
350            }
351            ActionType::Proposal => {
352                let mut messages = Vec::new();
353                for target in targets {
354                    let message = group
355                        .propose_remove_member(&self.provider, &signer, *target)
356                        .map(|(out, _)| out)?;
357                    messages.push(message);
358                }
359                (messages, None, None)
360            }
361        };
362        Ok(action_results)
363    }
364
365    /// Get the identity of this client in the given group.
366    pub fn identity(&self, group_id: &GroupId) -> Option<Vec<u8>> {
367        let groups = self.groups.read().unwrap();
368        let group = groups.get(group_id).unwrap();
369        let leaf = group.own_leaf();
370        leaf.map(|l| {
371            let credential = BasicCredential::try_from(l.credential().clone()).unwrap();
372            credential.identity().to_vec()
373        })
374    }
375}