1use log::info;
9use openmls_traits::{random::OpenMlsRand, types::HpkeKeyPair, OpenMlsProvider};
10use serde::{self, Deserialize, Serialize};
11use tls_codec::Serialize as TlsSerializeTrait;
12
13#[cfg(all(test, feature = "generate-kats"))]
14use crate::test_utils::write;
15use crate::{
16 ciphersuite::*,
17 extensions::Extensions,
18 group::*,
19 schedule::{errors::KsTestVectorError, CommitSecret, *},
20 test_utils::*,
21};
22
23#[derive(Serialize, Deserialize, Debug, Clone, Default)]
24struct Exporter {
25 label: String,
26 context: String,
27 length: u32,
28 secret: String,
29}
30
31#[derive(Serialize, Deserialize, Debug, Clone, Default)]
32struct Epoch {
33 tree_hash: String,
35 commit_secret: String,
36 psk_secret: String,
37 confirmed_transcript_hash: String,
38
39 group_context: String,
41 joiner_secret: String,
42 welcome_secret: String,
43 init_secret: String,
44 sender_data_secret: String,
45 encryption_secret: String,
46 exporter_secret: String,
47 epoch_authenticator: String,
48 external_secret: String,
49 confirmation_key: String,
50 membership_key: String,
51 resumption_psk: String,
52
53 external_pub: String,
54 exporter: Exporter,
55}
56
57#[derive(Serialize, Deserialize, Debug, Clone, Default)]
58pub struct KeyScheduleTestVector {
59 pub cipher_suite: u16,
60 group_id: String,
61 initial_init_secret: String,
62 epochs: Vec<Epoch>,
63}
64
65#[allow(clippy::type_complexity)]
67fn generate(
68 ciphersuite: Ciphersuite,
69 init_secret: &InitSecret,
70 group_id: &[u8],
71 epoch: u64,
72) -> (
73 Vec<u8>,
74 CommitSecret,
75 PskSecret,
76 JoinerSecret,
77 WelcomeSecret,
78 EpochSecrets,
79 Vec<u8>,
80 GroupContext,
81 HpkeKeyPair,
82) {
83 let crypto = OpenMlsRustCrypto::default();
84 let tree_hash = crypto
85 .rand()
86 .random_vec(ciphersuite.hash_length())
87 .expect("An unexpected error occurred.");
88 let commit_secret = CommitSecret::random(ciphersuite, crypto.rand());
89
90 let confirmed_transcript_hash = crypto
91 .rand()
92 .random_vec(ciphersuite.hash_length())
93 .expect("An unexpected error occurred.");
94
95 let a: [u8; 1] = crypto.rand().random_array().unwrap();
97 let psk_secret = if a[0] > 127 {
98 PskSecret::from(Secret::random(ciphersuite, crypto.rand()).unwrap())
99 } else {
100 PskSecret::from(Secret::zero(ciphersuite))
101 };
102
103 let group_context = GroupContext::new(
104 ciphersuite,
105 GroupId::from_slice(group_id),
106 epoch,
107 tree_hash.to_vec(),
108 confirmed_transcript_hash.clone(),
109 Extensions::empty(),
110 );
111
112 let joiner_secret = JoinerSecret::new(
113 crypto.crypto(),
114 ciphersuite,
115 commit_secret.clone(),
116 init_secret,
117 &group_context.tls_serialize_detached().unwrap(),
118 )
119 .expect("Could not create JoinerSecret.");
120 let mut key_schedule = KeySchedule::init(
121 ciphersuite,
122 crypto.crypto(),
123 &joiner_secret,
124 psk_secret.clone(),
125 )
126 .expect("Could not create KeySchedule.");
127 let welcome_secret = key_schedule
128 .welcome(crypto.crypto(), ciphersuite)
129 .expect("An unexpected error occurred.");
130
131 let serialized_group_context = group_context
132 .tls_serialize_detached()
133 .expect("Could not serialize group context.");
134
135 key_schedule
136 .add_context(crypto.crypto(), &serialized_group_context)
137 .expect("An unexpected error occurred.");
138 let EpochSecretsResult { epoch_secrets, .. } = key_schedule
139 .epoch_secrets(crypto.crypto(), ciphersuite)
140 .expect("An unexpected error occurred.");
141
142 let external_key_pair = epoch_secrets
144 .external_secret()
145 .derive_external_keypair(crypto.crypto(), ciphersuite)
146 .expect("An unexpected crypto error occurred.");
147
148 (
149 confirmed_transcript_hash,
150 commit_secret,
151 psk_secret,
152 joiner_secret,
153 welcome_secret,
154 epoch_secrets,
155 tree_hash,
156 group_context,
157 external_key_pair,
158 )
159}
160
/// Builds a key schedule test vector spanning `n_epochs` consecutive epochs.
///
/// A random initial init secret and a random 16-byte group id are drawn from
/// `provider`. Each epoch is produced by `generate` and seeded with the init
/// secret of the epoch before it. For every epoch a 32-byte exporter secret is
/// derived with a fixed label and context and recorded next to the key
/// schedule values.
///
/// # Panics
///
/// Panics if randomness is unavailable or if any derivation or serialization
/// step fails.
#[cfg(any(feature = "test-utils", test))]
pub fn generate_test_vector(
    n_epochs: u64,
    ciphersuite: Ciphersuite,
    provider: &impl OpenMlsProvider,
) -> KeyScheduleTestVector {
    use tls_codec::Serialize;

    let first_init_secret =
        InitSecret::random(ciphersuite, provider.rand()).expect("Not enough randomness.");
    // Record the starting secret in hex before the chaining below replaces it.
    let initial_init_secret = bytes_to_hex(first_init_secret.as_slice());
    let mut chained_init_secret = first_init_secret;

    let group_id = provider
        .rand()
        .random_vec(16)
        .expect("An unexpected error occurred.");

    let mut epochs = Vec::new();

    for epoch in 0..n_epochs {
        println!("Generating epoch: {epoch:?}");
        let (
            transcript_hash,
            commit,
            psk,
            joiner,
            welcome,
            secrets,
            tree,
            context,
            external_keys,
        ) = generate(ciphersuite, &chained_init_secret, &group_id, epoch);

        // Fixed exporter parameters used for every epoch of the vector.
        let label = "exporter label";
        let length = 32u32;
        let context_bytes = b"exporter context";
        let exported = secrets
            .exporter_secret()
            .derive_exported_secret(
                ciphersuite,
                provider.crypto(),
                label,
                context_bytes,
                length as usize,
            )
            .unwrap();
        let exporter = Exporter {
            label: label.into(),
            context: bytes_to_hex(context_bytes),
            length,
            secret: bytes_to_hex(&exported),
        };

        let serialized_context = context
            .tls_serialize_detached()
            .expect("An unexpected error occurred.");

        epochs.push(Epoch {
            tree_hash: bytes_to_hex(&tree),
            commit_secret: bytes_to_hex(commit.as_slice()),
            psk_secret: bytes_to_hex(psk.as_slice()),
            confirmed_transcript_hash: bytes_to_hex(&transcript_hash),
            group_context: bytes_to_hex(&serialized_context),
            joiner_secret: bytes_to_hex(joiner.as_slice()),
            welcome_secret: bytes_to_hex(welcome.as_slice()),
            init_secret: bytes_to_hex(secrets.init_secret().as_slice()),
            sender_data_secret: bytes_to_hex(secrets.sender_data_secret().as_slice()),
            encryption_secret: bytes_to_hex(secrets.encryption_secret().as_slice()),
            exporter_secret: bytes_to_hex(secrets.exporter_secret().as_slice()),
            epoch_authenticator: bytes_to_hex(secrets.epoch_authenticator().as_slice()),
            external_secret: bytes_to_hex(secrets.external_secret().as_slice()),
            confirmation_key: bytes_to_hex(secrets.confirmation_key().as_slice()),
            membership_key: bytes_to_hex(secrets.membership_key().as_slice()),
            resumption_psk: bytes_to_hex(secrets.resumption_psk().as_slice()),
            external_pub: bytes_to_hex(&external_keys.public),
            exporter,
        });

        // The next epoch starts from the init secret this epoch produced.
        chained_init_secret = secrets.init_secret().clone();
    }

    KeyScheduleTestVector {
        cipher_suite: ciphersuite as u16,
        group_id: bytes_to_hex(&group_id),
        initial_init_secret,
        epochs,
    }
}
250
/// Generates a two-epoch test vector for every ciphersuite the default
/// provider supports and writes them to `test_vectors/key-schedule-new.json`.
#[cfg(feature = "generate-kats")]
#[test]
fn write_test_vectors() {
    const NUM_EPOCHS: u64 = 2;
    let provider = OpenMlsRustCrypto::default();
    let tests: Vec<KeyScheduleTestVector> = provider
        .crypto()
        .supported_ciphersuites()
        .iter()
        .map(|&ciphersuite| generate_test_vector(NUM_EPOCHS, ciphersuite, &provider))
        .collect();
    write("test_vectors/key-schedule-new.json", &tests);
}
262
263#[openmls_test::openmls_test]
264fn read_test_vectors_key_schedule() {
265 let provider = &Provider::default();
266
267 let _ = pretty_env_logger::try_init();
268
269 let tests: Vec<KeyScheduleTestVector> =
270 read_json!("../../../../test_vectors/key-schedule.json");
271
272 for test_vector in tests {
273 match run_test_vector(test_vector, provider) {
274 Ok(_) => {}
275 Err(e) => panic!("Error while checking key schedule test vector.\n{e:?}"),
276 }
277 }
278}
279
/// Reports a key schedule test vector mismatch.
///
/// When compiled with `cfg(test)` this panics with `what` as the message.
/// Otherwise it hands `err` back unchanged so the caller can return it.
#[cfg(any(feature = "test-utils", test))]
fn kat_mismatch(what: &str, err: KsTestVectorError) -> KsTestVectorError {
    if cfg!(test) {
        panic!("{what}");
    }
    err
}

/// Compares the hex-encoded `expected_hex` against the computed bytes in `actual`.
///
/// Returns `Ok(())` when they are equal; otherwise reports the mismatch through
/// [`kat_mismatch`] using `what` and `err`.
#[cfg(any(feature = "test-utils", test))]
fn ensure_kat_bytes(
    expected_hex: &str,
    actual: &[u8],
    what: &str,
    err: KsTestVectorError,
) -> Result<(), KsTestVectorError> {
    if hex_to_bytes(expected_hex).as_slice() != actual {
        return Err(kat_mismatch(what, err));
    }
    Ok(())
}

/// Replays a key schedule test vector and checks every recorded value.
///
/// For each epoch the function rebuilds the group context from the recorded
/// inputs, runs the key schedule starting from the previous epoch's init
/// secret, and compares every derived value with the one stored in the vector.
/// The comparison covers the joiner and welcome secrets, the serialized group
/// context, all epoch secrets, the external public key and the exported secret.
///
/// If `provider` does not support the vector's ciphersuite, the vector is
/// skipped and `Ok(())` is returned.
///
/// # Errors
///
/// Outside of `cfg(test)` builds, returns the [`KsTestVectorError`] variant
/// for the first value that does not match.
///
/// # Panics
///
/// * In `cfg(test)` builds, panics on the first mismatch instead of returning
///   an error.
/// * Panics if the ciphersuite identifier is invalid or if a key schedule or
///   serialization step fails.
#[cfg(any(feature = "test-utils", test))]
pub fn run_test_vector(
    test_vector: KeyScheduleTestVector,
    provider: &impl OpenMlsProvider,
) -> Result<(), KsTestVectorError> {
    let ciphersuite = Ciphersuite::try_from(test_vector.cipher_suite).expect("Invalid ciphersuite");
    log::trace!(" {test_vector:?}");

    if !provider
        .crypto()
        .supported_ciphersuites()
        .contains(&ciphersuite)
    {
        info!("Skipping unsupported ciphersuite `{ciphersuite:?}`.");
        return Ok(());
    }

    let group_id = hex_to_bytes(&test_vector.group_id);
    let init_secret = hex_to_bytes(&test_vector.initial_init_secret);
    log::trace!(
        " InitSecret from tve: {:?}",
        test_vector.initial_init_secret
    );
    let mut init_secret = InitSecret::from(Secret::from_slice(&init_secret));

    for (epoch_ctr, epoch) in test_vector.epochs.iter().enumerate() {
        let tree_hash = hex_to_bytes(&epoch.tree_hash);
        let secret = hex_to_bytes(&epoch.commit_secret);
        let commit_secret = CommitSecret::from(PathSecret::from(Secret::from_slice(&secret)));
        log::trace!(" CommitSecret from tve {:?}", epoch.commit_secret);

        let confirmed_transcript_hash = hex_to_bytes(&epoch.confirmed_transcript_hash);

        // Neither hash is needed again, so both are moved into the context.
        let group_context = GroupContext::new(
            ciphersuite,
            GroupId::from_slice(&group_id),
            GroupEpoch::from(epoch_ctr as u64),
            tree_hash,
            confirmed_transcript_hash,
            Extensions::empty(),
        );
        // Serialized once: feeds the joiner secret, the context check and `add_context`.
        let group_context_serialized = group_context
            .tls_serialize_detached()
            .expect("An unexpected error occurred.");

        let joiner_secret = JoinerSecret::new(
            provider.crypto(),
            ciphersuite,
            commit_secret,
            &init_secret,
            &group_context_serialized,
        )
        .expect("Could not create JoinerSecret.");
        ensure_kat_bytes(
            &epoch.joiner_secret,
            joiner_secret.as_slice(),
            "Joiner secret mismatch",
            KsTestVectorError::JoinerSecretMismatch,
        )?;

        let psk_secret_inner = Secret::from_slice(&hex_to_bytes(&epoch.psk_secret));
        let psk_secret = PskSecret::from(psk_secret_inner);

        let mut key_schedule =
            KeySchedule::init(ciphersuite, provider.crypto(), &joiner_secret, psk_secret)
                .expect("Could not create KeySchedule.");
        let welcome_secret = key_schedule
            .welcome(provider.crypto(), ciphersuite)
            .expect("An unexpected error occurred.");
        ensure_kat_bytes(
            &epoch.welcome_secret,
            welcome_secret.as_slice(),
            "Welcome secret mismatch",
            KsTestVectorError::WelcomeSecretMismatch,
        )?;

        let expected_group_context = hex_to_bytes(&epoch.group_context);
        if group_context_serialized != expected_group_context {
            log::error!(" Group context mismatch");
            log::debug!(" Computed: {group_context_serialized:x?}");
            log::debug!(" Expected: {expected_group_context:x?}");
            return Err(kat_mismatch(
                "Group context mismatch",
                KsTestVectorError::GroupContextMismatch,
            ));
        }

        key_schedule
            .add_context(provider.crypto(), &group_context_serialized)
            .expect("An unexpected error occurred.");

        let EpochSecretsResult { epoch_secrets, .. } = key_schedule
            .epoch_secrets(provider.crypto(), ciphersuite)
            .expect("An unexpected error occurred.");

        // Carry the init secret forward before checking it; the next epoch starts from it.
        init_secret = epoch_secrets.init_secret().clone();
        if hex_to_bytes(&epoch.init_secret) != init_secret.as_slice() {
            log_crypto!(
                debug,
                " Epoch secret mismatch: {:x?} != {:x?}",
                hex_to_bytes(&epoch.init_secret),
                init_secret.as_slice()
            );
            return Err(kat_mismatch(
                "Init secret mismatch",
                KsTestVectorError::InitSecretMismatch,
            ));
        }

        ensure_kat_bytes(
            &epoch.sender_data_secret,
            epoch_secrets.sender_data_secret().as_slice(),
            "Sender data secret mismatch",
            KsTestVectorError::SenderDataSecretMismatch,
        )?;
        ensure_kat_bytes(
            &epoch.encryption_secret,
            epoch_secrets.encryption_secret().as_slice(),
            "Encryption secret mismatch",
            KsTestVectorError::EncryptionSecretMismatch,
        )?;
        ensure_kat_bytes(
            &epoch.exporter_secret,
            epoch_secrets.exporter_secret().as_slice(),
            "Exporter secret mismatch",
            KsTestVectorError::ExporterSecretMismatch,
        )?;
        ensure_kat_bytes(
            &epoch.epoch_authenticator,
            epoch_secrets.epoch_authenticator().as_slice(),
            "Epoch authenticator mismatch",
            KsTestVectorError::EpochAuthenticatorMismatch,
        )?;
        ensure_kat_bytes(
            &epoch.external_secret,
            epoch_secrets.external_secret().as_slice(),
            "External secret mismatch",
            KsTestVectorError::ExternalSecretMismatch,
        )?;
        ensure_kat_bytes(
            &epoch.confirmation_key,
            epoch_secrets.confirmation_key().as_slice(),
            "Confirmation key mismatch",
            KsTestVectorError::ConfirmationKeyMismatch,
        )?;
        ensure_kat_bytes(
            &epoch.membership_key,
            epoch_secrets.membership_key().as_slice(),
            "Membership key mismatch",
            KsTestVectorError::MembershipKeyMismatch,
        )?;
        ensure_kat_bytes(
            &epoch.resumption_psk,
            epoch_secrets.resumption_psk().as_slice(),
            "Resumption psk mismatch",
            KsTestVectorError::ResumptionPskMismatch,
        )?;

        let external_key_pair = epoch_secrets
            .external_secret()
            .derive_external_keypair(provider.crypto(), ciphersuite)
            .expect("an unexpected crypto error occurred");
        if hex_to_bytes(&epoch.external_pub) != external_key_pair.public {
            log::error!(" External public key mismatch");
            log::debug!(
                " Computed: {:x?}",
                HpkePublicKey::from(external_key_pair.public)
                    .tls_serialize_detached()
                    .expect("An unexpected error occurred.")
            );
            log::debug!(" Expected: {:x?}", hex_to_bytes(&epoch.external_pub));
            return Err(kat_mismatch(
                "External pub mismatch",
                KsTestVectorError::ExternalPubMismatch,
            ));
        }

        let exported = epoch_secrets
            .exporter_secret()
            .derive_exported_secret(
                ciphersuite,
                provider.crypto(),
                &epoch.exporter.label,
                &hex_to_bytes(&epoch.exporter.context),
                epoch.exporter.length as usize,
            )
            .unwrap();
        ensure_kat_bytes(
            &epoch.exporter.secret,
            &exported,
            "Exporter mismatch",
            KsTestVectorError::ExporterMismatch,
        )?;
    }
    Ok(())
}