1use log::info;
9use openmls_traits::{random::OpenMlsRand, types::HpkeKeyPair, OpenMlsProvider};
10use serde::{self, Deserialize, Serialize};
11use tls_codec::Serialize as TlsSerializeTrait;
12
13#[cfg(test)]
14use crate::test_utils::write;
15use crate::{
16 ciphersuite::*,
17 extensions::Extensions,
18 group::*,
19 schedule::{errors::KsTestVectorError, CommitSecret, *},
20 test_utils::*,
21};
22
/// One exporter check for an epoch: the inputs handed to the exporter secret
/// and the exported value expected for them.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
struct Exporter {
    /// Exporter label, stored as a plain (not hex-encoded) string.
    label: String,
    /// Hex-encoded exporter context.
    context: String,
    /// Requested output length, passed to `derive_exported_secret`.
    length: u32,
    /// Hex-encoded expected exporter output.
    secret: String,
}
30
/// Per-epoch entry of a key schedule test vector.
///
/// Every byte value is stored as a hex-encoded string. Field order is kept
/// as-is because it is the serialized JSON field order.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
struct Epoch {
    // Inputs to this epoch's key schedule (drawn at random by the generator).
    tree_hash: String,
    commit_secret: String,
    psk_secret: String,
    confirmed_transcript_hash: String,

    // TLS serialization of the group context built from the inputs above.
    group_context: String,

    // Values derived by the key schedule for this epoch.
    joiner_secret: String,
    welcome_secret: String,
    // Init secret derived in this epoch; it is the input init secret of the
    // next epoch.
    init_secret: String,
    sender_data_secret: String,
    encryption_secret: String,
    exporter_secret: String,
    epoch_authenticator: String,
    external_secret: String,
    confirmation_key: String,
    membership_key: String,
    resumption_psk: String,

    // Public key of the HPKE key pair derived from `external_secret`.
    external_pub: String,
    // One exporter invocation and its expected output.
    exporter: Exporter,
}
56
/// A key schedule test vector: one ciphersuite, one group, and a chain of
/// consecutive epochs.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct KeyScheduleTestVector {
    /// Numeric ciphersuite identifier.
    pub cipher_suite: u16,
    /// Hex-encoded group ID.
    group_id: String,
    /// Hex-encoded init secret used as input to the first epoch.
    initial_init_secret: String,
    /// Epochs in order; each epoch's `init_secret` seeds the following epoch.
    epochs: Vec<Epoch>,
}
64
65#[allow(clippy::type_complexity)]
67fn generate(
68 ciphersuite: Ciphersuite,
69 init_secret: &InitSecret,
70 group_id: &[u8],
71 epoch: u64,
72) -> (
73 Vec<u8>,
74 CommitSecret,
75 PskSecret,
76 JoinerSecret,
77 WelcomeSecret,
78 EpochSecrets,
79 Vec<u8>,
80 GroupContext,
81 HpkeKeyPair,
82) {
83 let crypto = OpenMlsRustCrypto::default();
84 let tree_hash = crypto
85 .rand()
86 .random_vec(ciphersuite.hash_length())
87 .expect("An unexpected error occurred.");
88 let commit_secret = CommitSecret::random(ciphersuite, crypto.rand());
89
90 let confirmed_transcript_hash = crypto
91 .rand()
92 .random_vec(ciphersuite.hash_length())
93 .expect("An unexpected error occurred.");
94
95 let a: [u8; 1] = crypto.rand().random_array().unwrap();
97 let psk_secret = if a[0] > 127 {
98 PskSecret::from(Secret::random(ciphersuite, crypto.rand()).unwrap())
99 } else {
100 PskSecret::from(Secret::zero(ciphersuite))
101 };
102
103 let group_context = GroupContext::new(
104 ciphersuite,
105 GroupId::from_slice(group_id),
106 epoch,
107 tree_hash.to_vec(),
108 confirmed_transcript_hash.clone(),
109 Extensions::empty(),
110 );
111
112 let joiner_secret = JoinerSecret::new(
113 crypto.crypto(),
114 ciphersuite,
115 commit_secret.clone(),
116 init_secret,
117 &group_context.tls_serialize_detached().unwrap(),
118 )
119 .expect("Could not create JoinerSecret.");
120 let mut key_schedule = KeySchedule::init(
121 ciphersuite,
122 crypto.crypto(),
123 &joiner_secret,
124 psk_secret.clone(),
125 )
126 .expect("Could not create KeySchedule.");
127 let welcome_secret = key_schedule
128 .welcome(crypto.crypto(), ciphersuite)
129 .expect("An unexpected error occurred.");
130
131 let serialized_group_context = group_context
132 .tls_serialize_detached()
133 .expect("Could not serialize group context.");
134
135 key_schedule
136 .add_context(crypto.crypto(), &serialized_group_context)
137 .expect("An unexpected error occurred.");
138 let EpochSecretsResult { epoch_secrets, .. } = key_schedule
139 .epoch_secrets(crypto.crypto(), ciphersuite)
140 .expect("An unexpected error occurred.");
141
142 let external_key_pair = epoch_secrets
144 .external_secret()
145 .derive_external_keypair(crypto.crypto(), ciphersuite)
146 .expect("An unexpected crypto error occurred.");
147
148 (
149 confirmed_transcript_hash,
150 commit_secret,
151 psk_secret,
152 joiner_secret,
153 welcome_secret,
154 epoch_secrets,
155 tree_hash,
156 group_context,
157 external_key_pair,
158 )
159}
160
/// Generates a key schedule test vector covering `n_epochs` consecutive epochs.
///
/// A random initial init secret and a random 16-byte group ID are drawn from
/// `provider`. Each epoch is produced by `generate`, and the init secret
/// derived in one epoch is chained into the next. For every epoch one exporter
/// value is also derived, with label `"exporter label"`, context
/// `"exporter context"` and length 32, using `provider`'s crypto.
///
/// # Panics
///
/// Panics if randomness is unavailable or any derivation fails.
#[cfg(any(feature = "test-utils", test))]
pub fn generate_test_vector(
    n_epochs: u64,
    ciphersuite: Ciphersuite,
    provider: &impl OpenMlsProvider,
) -> KeyScheduleTestVector {
    let mut init_secret =
        InitSecret::random(ciphersuite, provider.rand()).expect("Not enough randomness.");
    let initial_init_secret = init_secret.clone();
    let group_id = provider
        .rand()
        .random_vec(16)
        .expect("An unexpected error occurred.");

    let mut epochs = Vec::new();

    for epoch in 0..n_epochs {
        // Progress goes through the logger like the rest of this module,
        // instead of writing to stdout from library code.
        info!("Generating epoch: {epoch:?}");
        let (
            confirmed_transcript_hash,
            commit_secret,
            psk_secret,
            joiner_secret,
            welcome_secret,
            epoch_secrets,
            tree_hash,
            group_context,
            external_key_pair,
        ) = generate(ciphersuite, &init_secret, &group_id, epoch);

        let exporter_label = "exporter label";
        let exporter_length = 32u32;
        let exporter_context = b"exporter context";
        let exported = epoch_secrets
            .exporter_secret()
            .derive_exported_secret(
                ciphersuite,
                provider.crypto(),
                exporter_label,
                exporter_context,
                exporter_length as usize,
            )
            .expect("Could not derive exported secret.");

        let epoch_info = Epoch {
            tree_hash: bytes_to_hex(&tree_hash),
            commit_secret: bytes_to_hex(commit_secret.as_slice()),
            psk_secret: bytes_to_hex(psk_secret.as_slice()),
            confirmed_transcript_hash: bytes_to_hex(&confirmed_transcript_hash),
            group_context: bytes_to_hex(
                &group_context
                    .tls_serialize_detached()
                    .expect("An unexpected error occurred."),
            ),
            joiner_secret: bytes_to_hex(joiner_secret.as_slice()),
            welcome_secret: bytes_to_hex(welcome_secret.as_slice()),
            init_secret: bytes_to_hex(epoch_secrets.init_secret().as_slice()),
            sender_data_secret: bytes_to_hex(epoch_secrets.sender_data_secret().as_slice()),
            encryption_secret: bytes_to_hex(epoch_secrets.encryption_secret().as_slice()),
            exporter_secret: bytes_to_hex(epoch_secrets.exporter_secret().as_slice()),
            epoch_authenticator: bytes_to_hex(epoch_secrets.epoch_authenticator().as_slice()),
            external_secret: bytes_to_hex(epoch_secrets.external_secret().as_slice()),
            confirmation_key: bytes_to_hex(epoch_secrets.confirmation_key().as_slice()),
            membership_key: bytes_to_hex(epoch_secrets.membership_key().as_slice()),
            resumption_psk: bytes_to_hex(epoch_secrets.resumption_psk().as_slice()),
            external_pub: bytes_to_hex(&external_key_pair.public),
            exporter: Exporter {
                label: exporter_label.into(),
                context: bytes_to_hex(exporter_context),
                length: exporter_length,
                secret: bytes_to_hex(&exported),
            },
        };
        epochs.push(epoch_info);
        // Chain: this epoch's derived init secret is the next epoch's input.
        init_secret = epoch_secrets.init_secret().clone();
    }

    KeyScheduleTestVector {
        cipher_suite: ciphersuite as u16,
        group_id: bytes_to_hex(&group_id),
        initial_init_secret: bytes_to_hex(initial_init_secret.as_slice()),
        epochs,
    }
}
250
/// Regenerates key schedule test vectors for every ciphersuite supported by
/// the default provider and writes them to
/// `test_vectors/key-schedule-new.json`.
#[test]
fn write_test_vectors() {
    const NUM_EPOCHS: u64 = 2;
    let provider = OpenMlsRustCrypto::default();
    let tests: Vec<KeyScheduleTestVector> = provider
        .crypto()
        .supported_ciphersuites()
        .iter()
        .map(|&suite| generate_test_vector(NUM_EPOCHS, suite, &provider))
        .collect();
    write("test_vectors/key-schedule-new.json", &tests);
}
261
262#[openmls_test::openmls_test]
263fn read_test_vectors_key_schedule() {
264 let provider = &Provider::default();
265
266 let _ = pretty_env_logger::try_init();
267
268 let tests: Vec<KeyScheduleTestVector> =
269 read_json!("../../../../test_vectors/key-schedule.json");
270
271 for test_vector in tests {
272 match run_test_vector(test_vector, provider) {
273 Ok(_) => {}
274 Err(e) => panic!("Error while checking key schedule test vector.\n{e:?}"),
275 }
276 }
277}
278
/// Checks a single key schedule test vector against this implementation.
///
/// Vectors for ciphersuites the provider does not support are skipped and
/// reported as `Ok(())`. For each epoch the group context is rebuilt from the
/// vector's inputs and compared first, because it feeds every later
/// derivation. Then the joiner secret, welcome secret, all epoch secrets, the
/// external public key and one exporter output are derived and compared. The
/// init secret derived in each epoch is chained into the next one.
///
/// # Errors
///
/// Returns the `KsTestVectorError` variant naming the first value that did not
/// match. In `cfg(test)` builds a mismatch panics instead, so the failing field
/// is visible directly in the test output.
///
/// # Panics
///
/// Panics if the ciphersuite identifier is unknown, or if a key schedule step
/// itself fails (as opposed to producing a different value).
#[cfg(any(feature = "test-utils", test))]
pub fn run_test_vector(
    test_vector: KeyScheduleTestVector,
    provider: &impl OpenMlsProvider,
) -> Result<(), KsTestVectorError> {
    let ciphersuite = Ciphersuite::try_from(test_vector.cipher_suite).expect("Invalid ciphersuite");
    log::trace!("  {test_vector:?}");

    if !provider
        .crypto()
        .supported_ciphersuites()
        .contains(&ciphersuite)
    {
        info!("Skipping unsupported ciphersuite `{ciphersuite:?}`.");
        return Ok(());
    }

    let group_id = hex_to_bytes(&test_vector.group_id);
    log::trace!(
        "  InitSecret from tve: {:?}",
        test_vector.initial_init_secret
    );
    let mut init_secret = InitSecret::from(Secret::from_slice(&hex_to_bytes(
        &test_vector.initial_init_secret,
    )));

    for (epoch_ctr, epoch) in test_vector.epochs.iter().enumerate() {
        // The group context is an input to the joiner secret and the key
        // schedule. Verify it first so a bad context is reported as such
        // instead of surfacing as a joiner secret mismatch.
        let group_context = GroupContext::new(
            ciphersuite,
            GroupId::from_slice(&group_id),
            GroupEpoch::from(epoch_ctr as u64),
            hex_to_bytes(&epoch.tree_hash),
            hex_to_bytes(&epoch.confirmed_transcript_hash),
            Extensions::empty(),
        );
        let group_context_serialized = group_context
            .tls_serialize_detached()
            .expect("An unexpected error occurred.");
        check_public_value(
            &epoch.group_context,
            &group_context_serialized,
            "Group context",
            KsTestVectorError::GroupContextMismatch,
        )?;

        log::trace!("  CommitSecret from tve {:?}", epoch.commit_secret);
        let commit_secret = CommitSecret::from(PathSecret::from(Secret::from_slice(
            &hex_to_bytes(&epoch.commit_secret),
        )));

        let joiner_secret = JoinerSecret::new(
            provider.crypto(),
            ciphersuite,
            commit_secret,
            &init_secret,
            &group_context_serialized,
        )
        .expect("Could not create JoinerSecret.");
        check_secret_value(
            &epoch.joiner_secret,
            joiner_secret.as_slice(),
            "Joiner secret",
            KsTestVectorError::JoinerSecretMismatch,
        )?;

        let psk_secret = PskSecret::from(Secret::from_slice(&hex_to_bytes(&epoch.psk_secret)));
        let mut key_schedule =
            KeySchedule::init(ciphersuite, provider.crypto(), &joiner_secret, psk_secret)
                .expect("Could not create KeySchedule.");
        let welcome_secret = key_schedule
            .welcome(provider.crypto(), ciphersuite)
            .expect("An unexpected error occurred.");
        check_secret_value(
            &epoch.welcome_secret,
            welcome_secret.as_slice(),
            "Welcome secret",
            KsTestVectorError::WelcomeSecretMismatch,
        )?;

        key_schedule
            .add_context(provider.crypto(), &group_context_serialized)
            .expect("An unexpected error occurred.");
        let EpochSecretsResult { epoch_secrets, .. } = key_schedule
            .epoch_secrets(provider.crypto(), ciphersuite)
            .expect("An unexpected error occurred.");

        // Chain: this epoch's derived init secret is the next epoch's input.
        init_secret = epoch_secrets.init_secret().clone();
        check_secret_value(
            &epoch.init_secret,
            init_secret.as_slice(),
            "Init secret",
            KsTestVectorError::InitSecretMismatch,
        )?;
        check_secret_value(
            &epoch.sender_data_secret,
            epoch_secrets.sender_data_secret().as_slice(),
            "Sender data secret",
            KsTestVectorError::SenderDataSecretMismatch,
        )?;
        check_secret_value(
            &epoch.encryption_secret,
            epoch_secrets.encryption_secret().as_slice(),
            "Encryption secret",
            KsTestVectorError::EncryptionSecretMismatch,
        )?;
        check_secret_value(
            &epoch.exporter_secret,
            epoch_secrets.exporter_secret().as_slice(),
            "Exporter secret",
            KsTestVectorError::ExporterSecretMismatch,
        )?;
        check_secret_value(
            &epoch.epoch_authenticator,
            epoch_secrets.epoch_authenticator().as_slice(),
            "Epoch authenticator",
            KsTestVectorError::EpochAuthenticatorMismatch,
        )?;
        check_secret_value(
            &epoch.external_secret,
            epoch_secrets.external_secret().as_slice(),
            "External secret",
            KsTestVectorError::ExternalSecretMismatch,
        )?;
        check_secret_value(
            &epoch.confirmation_key,
            epoch_secrets.confirmation_key().as_slice(),
            "Confirmation key",
            KsTestVectorError::ConfirmationKeyMismatch,
        )?;
        check_secret_value(
            &epoch.membership_key,
            epoch_secrets.membership_key().as_slice(),
            "Membership key",
            KsTestVectorError::MembershipKeyMismatch,
        )?;
        check_secret_value(
            &epoch.resumption_psk,
            epoch_secrets.resumption_psk().as_slice(),
            "Resumption psk",
            KsTestVectorError::ResumptionPskMismatch,
        )?;

        let external_key_pair = epoch_secrets
            .external_secret()
            .derive_external_keypair(provider.crypto(), ciphersuite)
            .expect("an unexpected crypto error occurred");
        // Compare and log the raw public key bytes on both sides (the vector
        // stores the raw key, not a TLS-serialized one).
        check_public_value(
            &epoch.external_pub,
            &external_key_pair.public,
            "External pub",
            KsTestVectorError::ExternalPubMismatch,
        )?;

        let exported = epoch_secrets
            .exporter_secret()
            .derive_exported_secret(
                ciphersuite,
                provider.crypto(),
                &epoch.exporter.label,
                &hex_to_bytes(&epoch.exporter.context),
                epoch.exporter.length as usize,
            )
            .expect("Could not derive exported secret.");
        check_secret_value(
            &epoch.exporter.secret,
            &exported,
            "Exporter",
            KsTestVectorError::ExporterMismatch,
        )?;
    }
    Ok(())
}

/// Reports a mismatch of the test-vector field `what`.
///
/// Panics with `"<what> mismatch"` in `cfg(test)` builds; otherwise returns
/// `error` for the caller to propagate.
#[cfg(any(feature = "test-utils", test))]
fn report_key_schedule_mismatch(
    what: &str,
    error: KsTestVectorError,
) -> Result<(), KsTestVectorError> {
    log::error!("  {what} mismatch");
    if cfg!(test) {
        panic!("{what} mismatch");
    }
    Err(error)
}

/// Compares a non-secret value against its hex-encoded expectation, logging
/// both byte strings at debug level on mismatch.
#[cfg(any(feature = "test-utils", test))]
fn check_public_value(
    expected_hex: &str,
    computed: &[u8],
    what: &str,
    error: KsTestVectorError,
) -> Result<(), KsTestVectorError> {
    let expected = hex_to_bytes(expected_hex);
    if expected == computed {
        return Ok(());
    }
    log::debug!("    Computed: {computed:x?}");
    log::debug!("    Expected: {expected:x?}");
    report_key_schedule_mismatch(what, error)
}

/// Compares secret material against its hex-encoded expectation. On mismatch
/// the byte strings are only logged through `log_crypto!`, so secrets are not
/// dumped by the regular logger.
#[cfg(any(feature = "test-utils", test))]
fn check_secret_value(
    expected_hex: &str,
    computed: &[u8],
    what: &str,
    error: KsTestVectorError,
) -> Result<(), KsTestVectorError> {
    let expected = hex_to_bytes(expected_hex);
    if expected == computed {
        return Ok(());
    }
    log_crypto!(
        debug,
        "  {} mismatch: {:x?} != {:x?}",
        what,
        expected,
        computed
    );
    report_key_schedule_mismatch(what, error)
}