Skip to main content

openmls/schedule/tests_and_kats/kats/
key_schedule.rs

1//! # Known Answer Tests for the key schedule
2//!
3//! See <https://github.com/mlswg/mls-implementations/blob/master/test-vectors.md>
4//! for more description on the test vectors.
5//!
6//! If values are not present, they are encoded as empty strings.
7
8use log::info;
9use openmls_traits::{random::OpenMlsRand, types::HpkeKeyPair, OpenMlsProvider};
10use serde::{self, Deserialize, Serialize};
11use tls_codec::Serialize as TlsSerializeTrait;
12
13#[cfg(all(test, feature = "generate-kats"))]
14use crate::test_utils::write;
15use crate::{
16    ciphersuite::*,
17    extensions::Extensions,
18    group::*,
19    schedule::{errors::KsTestVectorError, CommitSecret, *},
20    test_utils::*,
21};
22
23#[derive(Serialize, Deserialize, Debug, Clone, Default)]
24struct Exporter {
25    label: String,
26    context: String,
27    length: u32,
28    secret: String,
29}
30
31#[derive(Serialize, Deserialize, Debug, Clone, Default)]
32struct Epoch {
33    // Chosen by the generator
34    tree_hash: String,
35    commit_secret: String,
36    psk_secret: String,
37    confirmed_transcript_hash: String,
38
39    // Computed values
40    group_context: String,
41    joiner_secret: String,
42    welcome_secret: String,
43    init_secret: String,
44    sender_data_secret: String,
45    encryption_secret: String,
46    exporter_secret: String,
47    epoch_authenticator: String,
48    external_secret: String,
49    confirmation_key: String,
50    membership_key: String,
51    resumption_psk: String,
52
53    external_pub: String,
54    exporter: Exporter,
55}
56
57#[derive(Serialize, Deserialize, Debug, Clone, Default)]
58pub struct KeyScheduleTestVector {
59    pub cipher_suite: u16,
60    group_id: String,
61    initial_init_secret: String,
62    epochs: Vec<Epoch>,
63}
64
65// Ignore clippy warning since this just used for testing
66#[allow(clippy::type_complexity)]
67fn generate(
68    ciphersuite: Ciphersuite,
69    init_secret: &InitSecret,
70    group_id: &[u8],
71    epoch: u64,
72) -> (
73    Vec<u8>,
74    CommitSecret,
75    PskSecret,
76    JoinerSecret,
77    WelcomeSecret,
78    EpochSecrets,
79    Vec<u8>,
80    GroupContext,
81    HpkeKeyPair,
82) {
83    let crypto = OpenMlsRustCrypto::default();
84    let tree_hash = crypto
85        .rand()
86        .random_vec(ciphersuite.hash_length())
87        .expect("An unexpected error occurred.");
88    let commit_secret = CommitSecret::random(ciphersuite, crypto.rand());
89
90    let confirmed_transcript_hash = crypto
91        .rand()
92        .random_vec(ciphersuite.hash_length())
93        .expect("An unexpected error occurred.");
94
95    // PSK secret can sometimes be the all zero vector
96    let a: [u8; 1] = crypto.rand().random_array().unwrap();
97    let psk_secret = if a[0] > 127 {
98        PskSecret::from(Secret::random(ciphersuite, crypto.rand()).unwrap())
99    } else {
100        PskSecret::from(Secret::zero(ciphersuite))
101    };
102
103    let group_context = GroupContext::new(
104        ciphersuite,
105        GroupId::from_slice(group_id),
106        epoch,
107        tree_hash.to_vec(),
108        confirmed_transcript_hash.clone(),
109        Extensions::empty(),
110    );
111
112    let joiner_secret = JoinerSecret::new(
113        crypto.crypto(),
114        ciphersuite,
115        commit_secret.clone(),
116        init_secret,
117        &group_context.tls_serialize_detached().unwrap(),
118    )
119    .expect("Could not create JoinerSecret.");
120    let mut key_schedule = KeySchedule::init(
121        ciphersuite,
122        crypto.crypto(),
123        &joiner_secret,
124        psk_secret.clone(),
125    )
126    .expect("Could not create KeySchedule.");
127    let welcome_secret = key_schedule
128        .welcome(crypto.crypto(), ciphersuite)
129        .expect("An unexpected error occurred.");
130
131    let serialized_group_context = group_context
132        .tls_serialize_detached()
133        .expect("Could not serialize group context.");
134
135    key_schedule
136        .add_context(crypto.crypto(), &serialized_group_context)
137        .expect("An unexpected error occurred.");
138    let EpochSecretsResult { epoch_secrets, .. } = key_schedule
139        .epoch_secrets(crypto.crypto(), ciphersuite)
140        .expect("An unexpected error occurred.");
141
142    // Calculate external HPKE key pair
143    let external_key_pair = epoch_secrets
144        .external_secret()
145        .derive_external_keypair(crypto.crypto(), ciphersuite)
146        .expect("An unexpected crypto error occurred.");
147
148    (
149        confirmed_transcript_hash,
150        commit_secret,
151        psk_secret,
152        joiner_secret,
153        welcome_secret,
154        epoch_secrets,
155        tree_hash,
156        group_context,
157        external_key_pair,
158    )
159}
160
161#[cfg(any(feature = "test-utils", test))]
162pub fn generate_test_vector(
163    n_epochs: u64,
164    ciphersuite: Ciphersuite,
165    provider: &impl OpenMlsProvider,
166) -> KeyScheduleTestVector {
167    use tls_codec::Serialize;
168
169    // Set up setting.
170    let mut init_secret =
171        InitSecret::random(ciphersuite, provider.rand()).expect("Not enough randomness.");
172    let initial_init_secret = init_secret.clone();
173    let group_id = provider
174        .rand()
175        .random_vec(16)
176        .expect("An unexpected error occurred.");
177
178    let mut epochs = Vec::new();
179
180    // Generate info for all epochs
181    for epoch in 0..n_epochs {
182        println!("Generating epoch: {epoch:?}");
183        let (
184            confirmed_transcript_hash,
185            commit_secret,
186            psk_secret,
187            joiner_secret,
188            welcome_secret,
189            epoch_secrets,
190            tree_hash,
191            group_context,
192            external_key_pair,
193        ) = generate(ciphersuite, &init_secret, &group_id, epoch);
194
195        // exporter
196        let exporter_label = "exporter label";
197        let exporter_length = 32u32;
198        let exporter_context = b"exporter context";
199        let exported = epoch_secrets
200            .exporter_secret()
201            .derive_exported_secret(
202                ciphersuite,
203                provider.crypto(),
204                exporter_label,
205                exporter_context,
206                exporter_length as usize,
207            )
208            .unwrap();
209
210        let epoch_info = Epoch {
211            tree_hash: bytes_to_hex(&tree_hash),
212            commit_secret: bytes_to_hex(commit_secret.as_slice()),
213            psk_secret: bytes_to_hex(psk_secret.as_slice()),
214            confirmed_transcript_hash: bytes_to_hex(&confirmed_transcript_hash),
215            group_context: bytes_to_hex(
216                &group_context
217                    .tls_serialize_detached()
218                    .expect("An unexpected error occurred."),
219            ),
220            joiner_secret: bytes_to_hex(joiner_secret.as_slice()),
221            welcome_secret: bytes_to_hex(welcome_secret.as_slice()),
222            init_secret: bytes_to_hex(epoch_secrets.init_secret().as_slice()),
223            sender_data_secret: bytes_to_hex(epoch_secrets.sender_data_secret().as_slice()),
224            encryption_secret: bytes_to_hex(epoch_secrets.encryption_secret().as_slice()),
225            exporter_secret: bytes_to_hex(epoch_secrets.exporter_secret().as_slice()),
226            epoch_authenticator: bytes_to_hex(epoch_secrets.epoch_authenticator().as_slice()),
227            external_secret: bytes_to_hex(epoch_secrets.external_secret().as_slice()),
228            confirmation_key: bytes_to_hex(epoch_secrets.confirmation_key().as_slice()),
229            membership_key: bytes_to_hex(epoch_secrets.membership_key().as_slice()),
230            resumption_psk: bytes_to_hex(epoch_secrets.resumption_psk().as_slice()),
231            external_pub: bytes_to_hex(&external_key_pair.public),
232            exporter: Exporter {
233                label: exporter_label.into(),
234                context: bytes_to_hex(exporter_context),
235                length: exporter_length,
236                secret: bytes_to_hex(&exported),
237            },
238        };
239        epochs.push(epoch_info);
240        init_secret = epoch_secrets.init_secret().clone();
241    }
242
243    KeyScheduleTestVector {
244        cipher_suite: ciphersuite as u16,
245        group_id: bytes_to_hex(&group_id),
246        initial_init_secret: bytes_to_hex(initial_init_secret.as_slice()),
247        epochs,
248    }
249}
250
/// Regenerates the key-schedule KATs for every ciphersuite the default
/// provider supports and writes them to `test_vectors/key-schedule-new.json`.
#[cfg(feature = "generate-kats")]
#[test]
fn write_test_vectors() {
    const NUM_EPOCHS: u64 = 2;
    let provider = OpenMlsRustCrypto::default();
    let tests: Vec<KeyScheduleTestVector> = provider
        .crypto()
        .supported_ciphersuites()
        .iter()
        .copied()
        .map(|ciphersuite| generate_test_vector(NUM_EPOCHS, ciphersuite, &provider))
        .collect();
    write("test_vectors/key-schedule-new.json", &tests);
}
262
263#[openmls_test::openmls_test]
264fn read_test_vectors_key_schedule() {
265    let provider = &Provider::default();
266
267    let _ = pretty_env_logger::try_init();
268
269    let tests: Vec<KeyScheduleTestVector> =
270        read_json!("../../../../test_vectors/key-schedule.json");
271
272    for test_vector in tests {
273        match run_test_vector(test_vector, provider) {
274            Ok(_) => {}
275            Err(e) => panic!("Error while checking key schedule test vector.\n{e:?}"),
276        }
277    }
278}
279
/// Verifies one key-schedule test vector against this implementation.
///
/// For each epoch the key schedule is re-derived from the vector's inputs
/// (tree hash, commit secret, PSK secret, confirmed transcript hash). The init
/// secret is carried over from the previous epoch, and every derived value is
/// compared with the vector. Vectors for ciphersuites that `provider` does not
/// support are skipped and reported as `Ok(())`.
///
/// # Errors
/// Returns the `KsTestVectorError` variant for the first mismatching value.
/// In `cfg(test)` builds a mismatch panics instead, naming the value.
///
/// # Panics
/// Panics if `cipher_suite` is not a known ciphersuite, or if a key-schedule
/// operation itself fails.
#[cfg(any(feature = "test-utils", test))]
pub fn run_test_vector(
    test_vector: KeyScheduleTestVector,
    provider: &impl OpenMlsProvider,
) -> Result<(), KsTestVectorError> {
    /// Compares the hex-encoded `expected` value with `computed`.
    ///
    /// On mismatch, panics with "<label> mismatch" in `cfg(test)` builds and
    /// returns `error` otherwise.
    fn check(
        label: &str,
        expected: &str,
        computed: &[u8],
        error: KsTestVectorError,
    ) -> Result<(), KsTestVectorError> {
        if hex_to_bytes(expected) == computed {
            return Ok(());
        }
        if cfg!(test) {
            panic!("{label} mismatch");
        }
        Err(error)
    }

    let ciphersuite = Ciphersuite::try_from(test_vector.cipher_suite).expect("Invalid ciphersuite");
    log::trace!("  {test_vector:?}");

    if !provider
        .crypto()
        .supported_ciphersuites()
        .contains(&ciphersuite)
    {
        info!("Skipping unsupported ciphersuite `{ciphersuite:?}`.");
        return Ok(());
    }

    let group_id = hex_to_bytes(&test_vector.group_id);
    log::trace!(
        "  InitSecret from tve: {:?}",
        test_vector.initial_init_secret
    );
    let mut init_secret = InitSecret::from(Secret::from_slice(&hex_to_bytes(
        &test_vector.initial_init_secret,
    )));

    for (epoch_ctr, epoch) in test_vector.epochs.iter().enumerate() {
        let commit_secret = CommitSecret::from(PathSecret::from(Secret::from_slice(
            &hex_to_bytes(&epoch.commit_secret),
        )));
        log::trace!("    CommitSecret from tve {:?}", epoch.commit_secret);

        let group_context = GroupContext::new(
            ciphersuite,
            GroupId::from_slice(&group_id),
            GroupEpoch::from(epoch_ctr as u64),
            hex_to_bytes(&epoch.tree_hash),
            hex_to_bytes(&epoch.confirmed_transcript_hash),
            Extensions::empty(),
        );
        // Serialize once: the same bytes feed the joiner secret, the group
        // context check and the key schedule context.
        let group_context_serialized = group_context
            .tls_serialize_detached()
            .expect("An unexpected error occurred.");

        let joiner_secret = JoinerSecret::new(
            provider.crypto(),
            ciphersuite,
            commit_secret,
            &init_secret,
            &group_context_serialized,
        )
        .expect("Could not create JoinerSecret.");
        check(
            "Joiner secret",
            &epoch.joiner_secret,
            joiner_secret.as_slice(),
            KsTestVectorError::JoinerSecretMismatch,
        )?;

        let psk_secret = PskSecret::from(Secret::from_slice(&hex_to_bytes(&epoch.psk_secret)));

        let mut key_schedule =
            KeySchedule::init(ciphersuite, provider.crypto(), &joiner_secret, psk_secret)
                .expect("Could not create KeySchedule.");
        let welcome_secret = key_schedule
            .welcome(provider.crypto(), ciphersuite)
            .expect("An unexpected error occurred.");
        check(
            "Welcome secret",
            &epoch.welcome_secret,
            welcome_secret.as_slice(),
            KsTestVectorError::WelcomeSecretMismatch,
        )?;

        // Checked inline because a mismatch also logs both encodings.
        let expected_group_context = hex_to_bytes(&epoch.group_context);
        if group_context_serialized != expected_group_context {
            log::error!("  Group context mismatch");
            log::debug!("    Computed: {group_context_serialized:x?}");
            log::debug!("    Expected: {expected_group_context:x?}");
            if cfg!(test) {
                panic!("Group context mismatch");
            }
            return Err(KsTestVectorError::GroupContextMismatch);
        }

        key_schedule
            .add_context(provider.crypto(), &group_context_serialized)
            .expect("An unexpected error occurred.");

        let EpochSecretsResult { epoch_secrets, .. } = key_schedule
            .epoch_secrets(provider.crypto(), ciphersuite)
            .expect("An unexpected error occurred.");

        // The derived init secret seeds the next epoch of the vector.
        init_secret = epoch_secrets.init_secret().clone();
        if hex_to_bytes(&epoch.init_secret) != init_secret.as_slice() {
            log_crypto!(
                debug,
                "    Epoch secret mismatch: {:x?} != {:x?}",
                hex_to_bytes(&epoch.init_secret),
                init_secret.as_slice()
            );
            if cfg!(test) {
                panic!("Init secret mismatch");
            }
            return Err(KsTestVectorError::InitSecretMismatch);
        }

        check(
            "Sender data secret",
            &epoch.sender_data_secret,
            epoch_secrets.sender_data_secret().as_slice(),
            KsTestVectorError::SenderDataSecretMismatch,
        )?;
        check(
            "Encryption secret",
            &epoch.encryption_secret,
            epoch_secrets.encryption_secret().as_slice(),
            KsTestVectorError::EncryptionSecretMismatch,
        )?;
        check(
            "Exporter secret",
            &epoch.exporter_secret,
            epoch_secrets.exporter_secret().as_slice(),
            KsTestVectorError::ExporterSecretMismatch,
        )?;
        check(
            "Epoch authenticator",
            &epoch.epoch_authenticator,
            epoch_secrets.epoch_authenticator().as_slice(),
            KsTestVectorError::EpochAuthenticatorMismatch,
        )?;
        check(
            "External secret",
            &epoch.external_secret,
            epoch_secrets.external_secret().as_slice(),
            KsTestVectorError::ExternalSecretMismatch,
        )?;
        check(
            "Confirmation key",
            &epoch.confirmation_key,
            epoch_secrets.confirmation_key().as_slice(),
            KsTestVectorError::ConfirmationKeyMismatch,
        )?;
        check(
            "Membership key",
            &epoch.membership_key,
            epoch_secrets.membership_key().as_slice(),
            KsTestVectorError::MembershipKeyMismatch,
        )?;
        check(
            "Resumption psk",
            &epoch.resumption_psk,
            epoch_secrets.resumption_psk().as_slice(),
            KsTestVectorError::ResumptionPskMismatch,
        )?;

        // External HPKE key pair derived from the external secret; checked
        // inline because a mismatch also logs the computed key.
        let external_key_pair = epoch_secrets
            .external_secret()
            .derive_external_keypair(provider.crypto(), ciphersuite)
            .expect("an unexpected crypto error occurred");
        if hex_to_bytes(&epoch.external_pub) != external_key_pair.public {
            log::error!("  External public key mismatch");
            log::debug!(
                "    Computed: {:x?}",
                HpkePublicKey::from(external_key_pair.public)
                    .tls_serialize_detached()
                    .expect("An unexpected error occurred.")
            );
            log::debug!("    Expected: {:x?}", hex_to_bytes(&epoch.external_pub));
            if cfg!(test) {
                panic!("External pub mismatch");
            }
            return Err(KsTestVectorError::ExternalPubMismatch);
        }

        // Check exported secret
        let exported = epoch_secrets
            .exporter_secret()
            .derive_exported_secret(
                ciphersuite,
                provider.crypto(),
                &epoch.exporter.label,
                &hex_to_bytes(&epoch.exporter.context),
                epoch.exporter.length as usize,
            )
            .unwrap();
        check(
            "Exporter",
            &epoch.exporter.secret,
            &exported,
            KsTestVectorError::ExporterMismatch,
        )?;
    }
    Ok(())
}