1use std::io::{Read, Write};
28
29use openmls_traits::signatures::Signer;
30use serde::{Deserialize, Serialize};
31use tls_codec::{
32 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
33 Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
34};
35
36#[cfg(test)]
37mod tests;
38
39use crate::{ciphersuite::SignaturePublicKey, group::Member, treesync::LeafNode};
40use errors::*;
41
42#[cfg(doc)]
43use crate::group::MlsGroup;
44
45pub mod errors;
47
/// The type of a [`Credential`].
///
/// In its TLS encoding this is a big-endian `u16` (see the `tls_codec` trait
/// impls for this type in this module). The raw values `1` and `2` map to
/// [`CredentialType::Basic`] and [`CredentialType::X509`]. Values for which
/// `crate::grease::is_grease_value` returns `true` map to
/// [`CredentialType::Grease`]. Every other value is preserved verbatim in
/// [`CredentialType::Other`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(
    feature = "0-8-1-storage-format",
    derive(serde::Serialize, serde::Deserialize)
)]
#[cfg_attr(
    not(feature = "0-8-1-storage-format"),
    derive(
        openmls_serialization_helpers::Serialize,
        openmls_serialization_helpers::Deserialize,
    )
)]
#[repr(u16)]
pub enum CredentialType {
    /// A basic credential (raw value `1`).
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 0)]
    Basic = 1,
    /// An X.509 credential (raw value `2`).
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 1)]
    X509 = 2,
    /// A GREASE value, i.e. one accepted by `crate::grease::is_grease_value`.
    // NOTE(review): storage tags are deliberately NOT in declaration order
    // (Grease = 3, Other = 2). Presumably Grease was added after Other and got
    // the next free tag so previously stored data stays readable. Confirm
    // before renumbering, because changing a tag changes the storage format.
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 3)]
    Grease(u16),
    /// Any other credential type; holds the raw `u16` value unchanged.
    #[cfg_attr(not(feature = "0-8-1-storage-format"), storage_tag = 2)]
    Other(u16),
}
107
108impl CredentialType {
109 pub fn is_grease(&self) -> bool {
114 matches!(self, CredentialType::Grease(_))
115 }
116}
117
118impl Size for CredentialType {
119 fn tls_serialized_len(&self) -> usize {
120 2
121 }
122}
123
124impl TlsDeserializeTrait for CredentialType {
125 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
126 where
127 Self: Sized,
128 {
129 let mut extension_type = [0u8; 2];
130 bytes.read_exact(&mut extension_type)?;
131
132 Ok(CredentialType::from(u16::from_be_bytes(extension_type)))
133 }
134}
135
136impl TlsSerializeTrait for CredentialType {
137 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
138 writer.write_all(&u16::from(*self).to_be_bytes())?;
139
140 Ok(2)
141 }
142}
143
144impl DeserializeBytes for CredentialType {
145 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
146 where
147 Self: Sized,
148 {
149 let mut bytes_ref = bytes;
150 let credential_type = CredentialType::tls_deserialize(&mut bytes_ref)?;
151 let remainder = &bytes[credential_type.tls_serialized_len()..];
152 Ok((credential_type, remainder))
153 }
154}
155
156impl From<u16> for CredentialType {
157 fn from(value: u16) -> Self {
158 match value {
159 1 => CredentialType::Basic,
160 2 => CredentialType::X509,
161 other if crate::grease::is_grease_value(other) => CredentialType::Grease(other),
162 other => CredentialType::Other(other),
163 }
164 }
165}
166
167impl From<CredentialType> for u16 {
168 fn from(value: CredentialType) -> Self {
169 match value {
170 CredentialType::Basic => 1,
171 CredentialType::X509 => 2,
172 CredentialType::Grease(value) => value,
173 CredentialType::Other(other) => other,
174 }
175 }
176}
177
/// A certificate held as an opaque byte string.
///
/// NOTE(review): presumably an X.509 certificate. This type stores only the
/// raw bytes; nothing in this file parses or validates them. Confirm the
/// expected encoding (e.g. DER) with the code that produces it.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct Certificate {
    /// The raw certificate bytes.
    cert_data: Vec<u8>,
}
192
/// A credential: a [`CredentialType`] tag plus the credential's serialized
/// content.
///
/// The content is kept as opaque bytes. Use [`Credential::deserialized`] to
/// decode it into a concrete type, or `TryFrom` to obtain a
/// [`BasicCredential`] when the type is [`CredentialType::Basic`].
///
/// The TLS encoding is derived from the fields in declaration order: the
/// type tag first, then the content encoded as `VLBytes`.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsSize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
)]
pub struct Credential {
    /// The type tag of this credential.
    credential_type: CredentialType,
    /// The serialized credential body. How to interpret it depends on
    /// `credential_type`.
    serialized_credential_content: VLBytes,
}
238
239impl Credential {
240 pub fn credential_type(&self) -> CredentialType {
242 self.credential_type
243 }
244
245 pub fn new(credential_type: CredentialType, serialized_credential: Vec<u8>) -> Self {
248 Self {
249 credential_type,
250 serialized_credential_content: serialized_credential.into(),
251 }
252 }
253
254 pub fn serialized_content(&self) -> &[u8] {
259 self.serialized_credential_content.as_slice()
260 }
261
262 pub fn deserialized<T: tls_codec::Size + tls_codec::Deserialize>(
264 &self,
265 ) -> Result<T, tls_codec::Error> {
266 T::tls_deserialize_exact(&self.serialized_credential_content)
267 }
268}
269
/// A basic credential, consisting only of an identity byte string.
///
/// Converting it into a [`Credential`] yields type
/// [`CredentialType::Basic`] whose content is exactly the identity bytes.
/// `TryFrom<Credential>` performs the reverse conversion.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BasicCredential {
    /// The identity, stored as opaque bytes.
    identity: VLBytes,
}
284
285impl BasicCredential {
286 pub fn new(identity: Vec<u8>) -> Self {
293 Self {
294 identity: identity.into(),
295 }
296 }
297
298 pub fn identity(&self) -> &[u8] {
300 self.identity.as_slice()
301 }
302}
303
304impl From<BasicCredential> for Credential {
305 fn from(credential: BasicCredential) -> Self {
306 Credential {
307 credential_type: CredentialType::Basic,
308 serialized_credential_content: credential.identity,
309 }
310 }
311}
312
313impl TryFrom<Credential> for BasicCredential {
314 type Error = BasicCredentialError;
315
316 fn try_from(credential: Credential) -> Result<Self, Self::Error> {
317 match credential.credential_type {
318 CredentialType::Basic => Ok(BasicCredential::new(
319 credential.serialized_credential_content.into(),
320 )),
321 _ => Err(errors::BasicCredentialError::WrongCredentialType),
322 }
323 }
324}
325
/// Bundles a signer with the credential and signature public key that belong
/// to it.
///
/// NOTE(review): presumably `signer` is expected to sign with the private
/// key matching `credential_with_key.signature_key`. Nothing in this type
/// enforces that, so confirm the expectation with callers.
#[derive(Debug, Clone)]
pub struct NewSignerBundle<'a, S: Signer> {
    /// The signer to use.
    pub signer: &'a S,
    /// The credential and signature public key associated with `signer`.
    pub credential_with_key: CredentialWithKey,
}
336
/// A [`Credential`] paired with a signature public key.
///
/// Can be built from a [`LeafNode`] or a [`Member`] through the `From` impls
/// in this module.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CredentialWithKey {
    /// The credential.
    pub credential: Credential,
    /// The signature public key associated with `credential`.
    pub signature_key: SignaturePublicKey,
}
345
346impl From<&LeafNode> for CredentialWithKey {
347 fn from(leaf_node: &LeafNode) -> Self {
348 Self {
349 credential: leaf_node.credential().clone(),
350 signature_key: leaf_node.signature_key().clone(),
351 }
352 }
353}
354
355impl From<&Member> for CredentialWithKey {
356 fn from(member: &Member) -> Self {
357 Self {
358 credential: member.credential.clone(),
359 signature_key: member.signature_key.clone().into(),
360 }
361 }
362}
363
#[cfg(test)]
impl CredentialWithKey {
    /// Test helper: builds a [`CredentialWithKey`] from a credential and the
    /// raw bytes of a signature public key.
    pub fn from_parts(credential: Credential, key: &[u8]) -> Self {
        let signature_key: SignaturePublicKey = key.into();
        Self {
            credential,
            signature_key,
        }
    }
}
373
#[cfg(any(test, feature = "test-utils"))]
pub mod test_utils {
    use openmls_basic_credential::SignatureKeyPair;
    use openmls_traits::{types::SignatureScheme, OpenMlsProvider};

    use super::{BasicCredential, CredentialWithKey};

    /// Creates a basic credential for `identity` together with a freshly
    /// generated signature key pair for `signature_scheme`.
    ///
    /// The key pair is written to `provider`'s storage before being returned.
    ///
    /// # Panics
    /// Panics if generating or storing the key pair fails. This is a test
    /// helper, so failures are treated as bugs.
    pub fn new_credential(
        provider: &impl OpenMlsProvider,
        identity: &[u8],
        signature_scheme: SignatureScheme,
    ) -> (CredentialWithKey, SignatureKeyPair) {
        let basic_credential = BasicCredential::new(identity.to_vec());
        let key_pair = SignatureKeyPair::new(signature_scheme).unwrap();
        key_pair.store(provider.storage()).unwrap();

        let credential_with_key = CredentialWithKey {
            credential: basic_credential.into(),
            signature_key: key_pair.public().into(),
        };
        (credential_with_key, key_pair)
    }
}
406
#[cfg(test)]
mod unit_tests {
    use tls_codec::{
        DeserializeBytes, Serialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
    };

    use super::{errors::BasicCredentialError, BasicCredential, Credential, CredentialType};

    /// A basic credential survives a round trip through `Credential` and its
    /// TLS encoding.
    #[test]
    fn basic_credential_identity_and_codec() {
        const IDENTITY: &str = "identity";
        let basic_credential = BasicCredential::new(IDENTITY.into());
        assert_eq!(basic_credential.identity(), IDENTITY.as_bytes());

        let credential = Credential::from(basic_credential.clone());
        let serialized = credential.tls_serialize_detached().unwrap();

        let deserialized = Credential::tls_deserialize_exact_bytes(&serialized).unwrap();
        assert_eq!(credential.credential_type(), deserialized.credential_type());
        assert_eq!(
            credential.serialized_content(),
            deserialized.serialized_content()
        );

        let deserialized_basic_credential = BasicCredential::try_from(deserialized).unwrap();
        assert_eq!(
            deserialized_basic_credential.identity(),
            IDENTITY.as_bytes()
        );
        assert_eq!(basic_credential, deserialized_basic_credential);
    }

    /// Converting a non-basic credential into a `BasicCredential` must fail
    /// with `WrongCredentialType` rather than reinterpreting the content.
    #[test]
    fn basic_credential_rejects_other_credential_types() {
        for credential_type in [CredentialType::X509, CredentialType::Other(1234)] {
            let credential = Credential::new(credential_type, b"identity".to_vec());
            let result = BasicCredential::try_from(credential);
            assert!(matches!(
                result,
                Err(BasicCredentialError::WrongCredentialType)
            ));
        }
    }

    /// A custom, application-defined credential can be carried in a
    /// `Credential` and decoded again, both manually and through
    /// `Credential::deserialized`.
    #[test]
    fn custom_credential() {
        #[derive(
            Debug, Clone, PartialEq, Eq, TlsSize, TlsSerialize, TlsDeserialize, TlsDeserializeBytes,
        )]
        struct CustomCredential {
            custom_field1: u32,
            custom_field2: Vec<u8>,
            custom_field3: Option<u8>,
        }

        let custom_credential = CustomCredential {
            custom_field1: 42,
            custom_field2: vec![1, 2, 3],
            custom_field3: Some(2),
        };

        let credential = Credential::new(
            CredentialType::Other(1234),
            custom_credential.tls_serialize_detached().unwrap(),
        );

        let serialized = credential.tls_serialize_detached().unwrap();
        let deserialized = Credential::tls_deserialize_exact_bytes(&serialized).unwrap();
        assert_eq!(credential, deserialized);

        let deserialized_custom_credential =
            CustomCredential::tls_deserialize_exact_bytes(deserialized.serialized_content())
                .unwrap();
        assert_eq!(custom_credential, deserialized_custom_credential);

        // The convenience helper must agree with manual decoding.
        let via_helper: CustomCredential = deserialized.deserialized().unwrap();
        assert_eq!(custom_credential, via_helper);
    }
}