openmls/schedule/tests_and_kats/kats/
key_schedule.rs

1//! # Known Answer Tests for the key schedule
2//!
3//! See <https://github.com/mlswg/mls-implementations/blob/master/test-vectors.md>
4//! for more description on the test vectors.
5//!
6//! If values are not present, they are encoded as empty strings.
7
8use log::info;
9use openmls_traits::{random::OpenMlsRand, types::HpkeKeyPair, OpenMlsProvider};
10use serde::{self, Deserialize, Serialize};
11use tls_codec::Serialize as TlsSerializeTrait;
12
13#[cfg(test)]
14use crate::test_utils::write;
15use crate::{
16    ciphersuite::*,
17    extensions::Extensions,
18    group::*,
19    schedule::{errors::KsTestVectorError, CommitSecret, *},
20    test_utils::*,
21};
22
23#[derive(Serialize, Deserialize, Debug, Clone, Default)]
24struct Exporter {
25    label: String,
26    context: String,
27    length: u32,
28    secret: String,
29}
30
31#[derive(Serialize, Deserialize, Debug, Clone, Default)]
32struct Epoch {
33    // Chosen by the generator
34    tree_hash: String,
35    commit_secret: String,
36    psk_secret: String,
37    confirmed_transcript_hash: String,
38
39    // Computed values
40    group_context: String,
41    joiner_secret: String,
42    welcome_secret: String,
43    init_secret: String,
44    sender_data_secret: String,
45    encryption_secret: String,
46    exporter_secret: String,
47    epoch_authenticator: String,
48    external_secret: String,
49    confirmation_key: String,
50    membership_key: String,
51    resumption_psk: String,
52
53    external_pub: String,
54    exporter: Exporter,
55}
56
57#[derive(Serialize, Deserialize, Debug, Clone, Default)]
58pub struct KeyScheduleTestVector {
59    pub cipher_suite: u16,
60    group_id: String,
61    initial_init_secret: String,
62    epochs: Vec<Epoch>,
63}
64
65// Ignore clippy warning since this just used for testing
66#[allow(clippy::type_complexity)]
67fn generate(
68    ciphersuite: Ciphersuite,
69    init_secret: &InitSecret,
70    group_id: &[u8],
71    epoch: u64,
72) -> (
73    Vec<u8>,
74    CommitSecret,
75    PskSecret,
76    JoinerSecret,
77    WelcomeSecret,
78    EpochSecrets,
79    Vec<u8>,
80    GroupContext,
81    HpkeKeyPair,
82) {
83    let crypto = OpenMlsRustCrypto::default();
84    let tree_hash = crypto
85        .rand()
86        .random_vec(ciphersuite.hash_length())
87        .expect("An unexpected error occurred.");
88    let commit_secret = CommitSecret::random(ciphersuite, crypto.rand());
89
90    let confirmed_transcript_hash = crypto
91        .rand()
92        .random_vec(ciphersuite.hash_length())
93        .expect("An unexpected error occurred.");
94
95    // PSK secret can sometimes be the all zero vector
96    let a: [u8; 1] = crypto.rand().random_array().unwrap();
97    let psk_secret = if a[0] > 127 {
98        PskSecret::from(Secret::random(ciphersuite, crypto.rand()).unwrap())
99    } else {
100        PskSecret::from(Secret::zero(ciphersuite))
101    };
102
103    let group_context = GroupContext::new(
104        ciphersuite,
105        GroupId::from_slice(group_id),
106        epoch,
107        tree_hash.to_vec(),
108        confirmed_transcript_hash.clone(),
109        Extensions::empty(),
110    );
111
112    let joiner_secret = JoinerSecret::new(
113        crypto.crypto(),
114        ciphersuite,
115        commit_secret.clone(),
116        init_secret,
117        &group_context.tls_serialize_detached().unwrap(),
118    )
119    .expect("Could not create JoinerSecret.");
120    let mut key_schedule = KeySchedule::init(
121        ciphersuite,
122        crypto.crypto(),
123        &joiner_secret,
124        psk_secret.clone(),
125    )
126    .expect("Could not create KeySchedule.");
127    let welcome_secret = key_schedule
128        .welcome(crypto.crypto(), ciphersuite)
129        .expect("An unexpected error occurred.");
130
131    let serialized_group_context = group_context
132        .tls_serialize_detached()
133        .expect("Could not serialize group context.");
134
135    key_schedule
136        .add_context(crypto.crypto(), &serialized_group_context)
137        .expect("An unexpected error occurred.");
138    let EpochSecretsResult { epoch_secrets, .. } = key_schedule
139        .epoch_secrets(crypto.crypto(), ciphersuite)
140        .expect("An unexpected error occurred.");
141
142    // Calculate external HPKE key pair
143    let external_key_pair = epoch_secrets
144        .external_secret()
145        .derive_external_keypair(crypto.crypto(), ciphersuite)
146        .expect("An unexpected crypto error occurred.");
147
148    (
149        confirmed_transcript_hash,
150        commit_secret,
151        psk_secret,
152        joiner_secret,
153        welcome_secret,
154        epoch_secrets,
155        tree_hash,
156        group_context,
157        external_key_pair,
158    )
159}
160
/// Generates a key-schedule test vector covering `n_epochs` consecutive epochs
/// of `ciphersuite`.
///
/// `provider` supplies the initial init secret, the 16-byte group id and the
/// exporter derivation. The per-epoch inputs and secrets come from the
/// file-local `generate` helper. Each epoch starts from the init secret derived
/// in the previous one.
///
/// # Panics
///
/// Panics if randomness, serialization or any key-schedule step fails.
#[cfg(any(feature = "test-utils", test))]
pub fn generate_test_vector(
    n_epochs: u64,
    ciphersuite: Ciphersuite,
    provider: &impl OpenMlsProvider,
) -> KeyScheduleTestVector {
    // Exporter parameters used for every epoch.
    const EXPORTER_LABEL: &str = "exporter label";
    const EXPORTER_CONTEXT: &[u8] = b"exporter context";
    const EXPORTER_LENGTH: u32 = 32;

    // Set up setting.
    let mut init_secret =
        InitSecret::random(ciphersuite, provider.rand()).expect("Not enough randomness.");
    let initial_init_secret = init_secret.clone();
    let group_id = provider
        .rand()
        .random_vec(16)
        .expect("An unexpected error occurred.");

    let mut epochs = Vec::new();

    // Generate info for all epochs
    for epoch in 0..n_epochs {
        // Log instead of printing: this is reusable (feature-gated) library code.
        info!("Generating epoch: {epoch:?}");
        let (
            confirmed_transcript_hash,
            commit_secret,
            psk_secret,
            joiner_secret,
            welcome_secret,
            epoch_secrets,
            tree_hash,
            group_context,
            external_key_pair,
        ) = generate(ciphersuite, &init_secret, &group_id, epoch);

        let exported = epoch_secrets
            .exporter_secret()
            .derive_exported_secret(
                ciphersuite,
                provider.crypto(),
                EXPORTER_LABEL,
                EXPORTER_CONTEXT,
                EXPORTER_LENGTH as usize,
            )
            .unwrap();

        epochs.push(Epoch {
            tree_hash: bytes_to_hex(&tree_hash),
            commit_secret: bytes_to_hex(commit_secret.as_slice()),
            psk_secret: bytes_to_hex(psk_secret.as_slice()),
            confirmed_transcript_hash: bytes_to_hex(&confirmed_transcript_hash),
            group_context: bytes_to_hex(
                &group_context
                    .tls_serialize_detached()
                    .expect("An unexpected error occurred."),
            ),
            joiner_secret: bytes_to_hex(joiner_secret.as_slice()),
            welcome_secret: bytes_to_hex(welcome_secret.as_slice()),
            init_secret: bytes_to_hex(epoch_secrets.init_secret().as_slice()),
            sender_data_secret: bytes_to_hex(epoch_secrets.sender_data_secret().as_slice()),
            encryption_secret: bytes_to_hex(epoch_secrets.encryption_secret().as_slice()),
            exporter_secret: bytes_to_hex(epoch_secrets.exporter_secret().as_slice()),
            epoch_authenticator: bytes_to_hex(epoch_secrets.epoch_authenticator().as_slice()),
            external_secret: bytes_to_hex(epoch_secrets.external_secret().as_slice()),
            confirmation_key: bytes_to_hex(epoch_secrets.confirmation_key().as_slice()),
            membership_key: bytes_to_hex(epoch_secrets.membership_key().as_slice()),
            resumption_psk: bytes_to_hex(epoch_secrets.resumption_psk().as_slice()),
            external_pub: bytes_to_hex(&external_key_pair.public),
            exporter: Exporter {
                label: EXPORTER_LABEL.into(),
                context: bytes_to_hex(EXPORTER_CONTEXT),
                length: EXPORTER_LENGTH,
                secret: bytes_to_hex(&exported),
            },
        });

        // Chain this epoch's derived init secret into the next epoch.
        init_secret = epoch_secrets.init_secret().clone();
    }

    KeyScheduleTestVector {
        cipher_suite: ciphersuite as u16,
        group_id: bytes_to_hex(&group_id),
        initial_init_secret: bytes_to_hex(initial_init_secret.as_slice()),
        epochs,
    }
}
250
/// Regenerates key-schedule vectors (two epochs each) for every ciphersuite the
/// default provider supports and writes them to `key-schedule-new.json`.
#[test]
fn write_test_vectors() {
    const NUM_EPOCHS: u64 = 2;
    let provider = OpenMlsRustCrypto::default();
    let tests: Vec<KeyScheduleTestVector> = provider
        .crypto()
        .supported_ciphersuites()
        .iter()
        .map(|&ciphersuite| generate_test_vector(NUM_EPOCHS, ciphersuite, &provider))
        .collect();
    write("test_vectors/key-schedule-new.json", &tests);
}
261
262#[openmls_test::openmls_test]
263fn read_test_vectors_key_schedule() {
264    let provider = &Provider::default();
265
266    let _ = pretty_env_logger::try_init();
267
268    let tests: Vec<KeyScheduleTestVector> =
269        read_json!("../../../../test_vectors/key-schedule.json");
270
271    for test_vector in tests {
272        match run_test_vector(test_vector, provider) {
273            Ok(_) => {}
274            Err(e) => panic!("Error while checking key schedule test vector.\n{e:?}"),
275        }
276    }
277}
278
/// Checks one key-schedule test vector against this implementation.
///
/// For every epoch in `test_vector`, the key schedule is re-run from the
/// vector's inputs (tree hash, commit secret, PSK secret, confirmed transcript
/// hash). Each computed value is then compared with the vector's expected value.
/// The init secret derived in one epoch is chained into the next.
///
/// Vectors for ciphersuites that `provider` does not support are skipped and
/// reported as `Ok(())`.
///
/// # Errors
///
/// Returns the [`KsTestVectorError`] variant naming the first value that does
/// not match.
///
/// # Panics
///
/// Panics in three cases:
/// - `cipher_suite` is not a known [`Ciphersuite`];
/// - a key-schedule step fails;
/// - in `cfg(test)` builds, any mismatch occurs (instead of returning an error).
#[cfg(any(feature = "test-utils", test))]
pub fn run_test_vector(
    test_vector: KeyScheduleTestVector,
    provider: &impl OpenMlsProvider,
) -> Result<(), KsTestVectorError> {
    let ciphersuite = Ciphersuite::try_from(test_vector.cipher_suite).expect("Invalid ciphersuite");
    log::trace!("  {test_vector:?}");

    if !provider
        .crypto()
        .supported_ciphersuites()
        .contains(&ciphersuite)
    {
        info!("Skipping unsupported ciphersuite `{ciphersuite:?}`.");
        return Ok(());
    }

    let group_id = hex_to_bytes(&test_vector.group_id);
    log::trace!(
        "  InitSecret from tve: {:?}",
        test_vector.initial_init_secret
    );
    let mut init_secret = InitSecret::from(Secret::from_slice(&hex_to_bytes(
        &test_vector.initial_init_secret,
    )));

    for (epoch_ctr, epoch) in test_vector.epochs.iter().enumerate() {
        let commit_secret = CommitSecret::from(PathSecret::from(Secret::from_slice(
            &hex_to_bytes(&epoch.commit_secret),
        )));
        log::trace!("    CommitSecret from tve {:?}", epoch.commit_secret);

        let group_context = GroupContext::new(
            ciphersuite,
            GroupId::from_slice(&group_id),
            GroupEpoch::from(epoch_ctr as u64),
            hex_to_bytes(&epoch.tree_hash),
            hex_to_bytes(&epoch.confirmed_transcript_hash),
            Extensions::empty(),
        );

        // Serialize the group context once and check it before deriving anything
        // from it. Both the joiner secret and the key-schedule context consume
        // these bytes, so a bad encoding is reported here rather than as a
        // misleading joiner-secret mismatch.
        let group_context_serialized = group_context
            .tls_serialize_detached()
            .expect("An unexpected error occurred.");
        let expected_group_context = hex_to_bytes(&epoch.group_context);
        if group_context_serialized != expected_group_context {
            log::error!("  Group context mismatch");
            log::debug!("    Computed: {group_context_serialized:x?}");
            log::debug!("    Expected: {expected_group_context:x?}");
            return report_ks_mismatch(
                "Group context",
                KsTestVectorError::GroupContextMismatch,
            );
        }

        let joiner_secret = JoinerSecret::new(
            provider.crypto(),
            ciphersuite,
            commit_secret,
            &init_secret,
            &group_context_serialized,
        )
        .expect("Could not create JoinerSecret.");
        ensure_ks_value_matches(
            &epoch.joiner_secret,
            joiner_secret.as_slice(),
            "Joiner secret",
            KsTestVectorError::JoinerSecretMismatch,
        )?;

        let psk_secret = PskSecret::from(Secret::from_slice(&hex_to_bytes(&epoch.psk_secret)));

        let mut key_schedule =
            KeySchedule::init(ciphersuite, provider.crypto(), &joiner_secret, psk_secret)
                .expect("Could not create KeySchedule.");
        let welcome_secret = key_schedule
            .welcome(provider.crypto(), ciphersuite)
            .expect("An unexpected error occurred.");
        ensure_ks_value_matches(
            &epoch.welcome_secret,
            welcome_secret.as_slice(),
            "Welcome secret",
            KsTestVectorError::WelcomeSecretMismatch,
        )?;

        key_schedule
            .add_context(provider.crypto(), &group_context_serialized)
            .expect("An unexpected error occurred.");

        let EpochSecretsResult { epoch_secrets, .. } = key_schedule
            .epoch_secrets(provider.crypto(), ciphersuite)
            .expect("An unexpected error occurred.");

        // Chain the derived init secret into the next epoch.
        init_secret = epoch_secrets.init_secret().clone();
        if hex_to_bytes(&epoch.init_secret) != init_secret.as_slice() {
            log_crypto!(
                debug,
                "    Epoch secret mismatch: {:x?} != {:x?}",
                hex_to_bytes(&epoch.init_secret),
                init_secret.as_slice()
            );
            return report_ks_mismatch("Init secret", KsTestVectorError::InitSecretMismatch);
        }

        ensure_ks_value_matches(
            &epoch.sender_data_secret,
            epoch_secrets.sender_data_secret().as_slice(),
            "Sender data secret",
            KsTestVectorError::SenderDataSecretMismatch,
        )?;
        ensure_ks_value_matches(
            &epoch.encryption_secret,
            epoch_secrets.encryption_secret().as_slice(),
            "Encryption secret",
            KsTestVectorError::EncryptionSecretMismatch,
        )?;
        ensure_ks_value_matches(
            &epoch.exporter_secret,
            epoch_secrets.exporter_secret().as_slice(),
            "Exporter secret",
            KsTestVectorError::ExporterSecretMismatch,
        )?;
        ensure_ks_value_matches(
            &epoch.epoch_authenticator,
            epoch_secrets.epoch_authenticator().as_slice(),
            "Epoch authenticator",
            KsTestVectorError::EpochAuthenticatorMismatch,
        )?;
        ensure_ks_value_matches(
            &epoch.external_secret,
            epoch_secrets.external_secret().as_slice(),
            "External secret",
            KsTestVectorError::ExternalSecretMismatch,
        )?;
        ensure_ks_value_matches(
            &epoch.confirmation_key,
            epoch_secrets.confirmation_key().as_slice(),
            "Confirmation key",
            KsTestVectorError::ConfirmationKeyMismatch,
        )?;
        ensure_ks_value_matches(
            &epoch.membership_key,
            epoch_secrets.membership_key().as_slice(),
            "Membership key",
            KsTestVectorError::MembershipKeyMismatch,
        )?;
        ensure_ks_value_matches(
            &epoch.resumption_psk,
            epoch_secrets.resumption_psk().as_slice(),
            "Resumption psk",
            KsTestVectorError::ResumptionPskMismatch,
        )?;

        // Calculate external HPKE key pair
        let external_key_pair = epoch_secrets
            .external_secret()
            .derive_external_keypair(provider.crypto(), ciphersuite)
            .expect("an unexpected crypto error occurred");
        if hex_to_bytes(&epoch.external_pub) != external_key_pair.public {
            log::error!("  External public key mismatch");
            log::debug!(
                "    Computed: {:x?}",
                HpkePublicKey::from(external_key_pair.public)
                    .tls_serialize_detached()
                    .expect("An unexpected error occurred.")
            );
            log::debug!("    Expected: {:x?}", hex_to_bytes(&epoch.external_pub));
            return report_ks_mismatch("External pub", KsTestVectorError::ExternalPubMismatch);
        }

        // Check exported secret
        let exported = epoch_secrets
            .exporter_secret()
            .derive_exported_secret(
                ciphersuite,
                provider.crypto(),
                &epoch.exporter.label,
                &hex_to_bytes(&epoch.exporter.context),
                epoch.exporter.length as usize,
            )
            .unwrap();
        ensure_ks_value_matches(
            &epoch.exporter.secret,
            &exported,
            "Exporter",
            KsTestVectorError::ExporterMismatch,
        )?;
    }
    Ok(())
}

/// Reports a key-schedule value that differs from the test vector.
///
/// In `cfg(test)` builds this panics with `"{what} mismatch"`, so the failing
/// value shows up directly in the test output. Otherwise it returns `error`.
#[cfg(any(feature = "test-utils", test))]
fn report_ks_mismatch(what: &str, error: KsTestVectorError) -> Result<(), KsTestVectorError> {
    if cfg!(test) {
        panic!("{what} mismatch");
    }
    Err(error)
}

/// Compares a hex-encoded expected value from a test vector with a computed
/// value. On mismatch, defers to `report_ks_mismatch`.
#[cfg(any(feature = "test-utils", test))]
fn ensure_ks_value_matches(
    expected_hex: &str,
    computed: &[u8],
    what: &str,
    error: KsTestVectorError,
) -> Result<(), KsTestVectorError> {
    if hex_to_bytes(expected_hex) == computed {
        Ok(())
    } else {
        report_ks_mismatch(what, error)
    }
}