1use std::io::{Read, Write};
28
29use openmls_traits::signatures::Signer;
30use serde::{Deserialize, Serialize};
31use tls_codec::{
32 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
33 Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
34};
35
36#[cfg(test)]
37mod tests;
38
39use crate::{ciphersuite::SignaturePublicKey, group::Member, treesync::LeafNode};
40use errors::*;
41
42#[cfg(doc)]
43use crate::group::MlsGroup;
44
45pub mod errors;
47
48#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
81#[repr(u16)]
82pub enum CredentialType {
83 Basic = 1,
85 X509 = 2,
87 Grease(u16),
89 Other(u16),
91}
92
93impl CredentialType {
94 pub fn is_grease(&self) -> bool {
99 matches!(self, CredentialType::Grease(_))
100 }
101}
102
103impl Size for CredentialType {
104 fn tls_serialized_len(&self) -> usize {
105 2
106 }
107}
108
109impl TlsDeserializeTrait for CredentialType {
110 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
111 where
112 Self: Sized,
113 {
114 let mut extension_type = [0u8; 2];
115 bytes.read_exact(&mut extension_type)?;
116
117 Ok(CredentialType::from(u16::from_be_bytes(extension_type)))
118 }
119}
120
121impl TlsSerializeTrait for CredentialType {
122 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
123 writer.write_all(&u16::from(*self).to_be_bytes())?;
124
125 Ok(2)
126 }
127}
128
129impl DeserializeBytes for CredentialType {
130 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
131 where
132 Self: Sized,
133 {
134 let mut bytes_ref = bytes;
135 let credential_type = CredentialType::tls_deserialize(&mut bytes_ref)?;
136 let remainder = &bytes[credential_type.tls_serialized_len()..];
137 Ok((credential_type, remainder))
138 }
139}
140
141impl From<u16> for CredentialType {
142 fn from(value: u16) -> Self {
143 match value {
144 1 => CredentialType::Basic,
145 2 => CredentialType::X509,
146 other if crate::grease::is_grease_value(other) => CredentialType::Grease(other),
147 other => CredentialType::Other(other),
148 }
149 }
150}
151
152impl From<CredentialType> for u16 {
153 fn from(value: CredentialType) -> Self {
154 match value {
155 CredentialType::Basic => 1,
156 CredentialType::X509 => 2,
157 CredentialType::Grease(value) => value,
158 CredentialType::Other(other) => other,
159 }
160 }
161}
162
163#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
174pub struct Certificate {
175 cert_data: Vec<u8>,
176}
177
178#[derive(
208 Debug,
209 PartialEq,
210 Eq,
211 Clone,
212 Serialize,
213 Deserialize,
214 TlsSize,
215 TlsSerialize,
216 TlsDeserialize,
217 TlsDeserializeBytes,
218)]
219pub struct Credential {
220 credential_type: CredentialType,
221 serialized_credential_content: VLBytes,
222}
223
224impl Credential {
225 pub fn credential_type(&self) -> CredentialType {
227 self.credential_type
228 }
229
230 pub fn new(credential_type: CredentialType, serialized_credential: Vec<u8>) -> Self {
233 Self {
234 credential_type,
235 serialized_credential_content: serialized_credential.into(),
236 }
237 }
238
239 pub fn serialized_content(&self) -> &[u8] {
244 self.serialized_credential_content.as_slice()
245 }
246
247 pub fn deserialized<T: tls_codec::Size + tls_codec::Deserialize>(
249 &self,
250 ) -> Result<T, tls_codec::Error> {
251 T::tls_deserialize_exact(&self.serialized_credential_content)
252 }
253}
254
255#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
266pub struct BasicCredential {
267 identity: VLBytes,
268}
269
270impl BasicCredential {
271 pub fn new(identity: Vec<u8>) -> Self {
278 Self {
279 identity: identity.into(),
280 }
281 }
282
283 pub fn identity(&self) -> &[u8] {
285 self.identity.as_slice()
286 }
287}
288
289impl From<BasicCredential> for Credential {
290 fn from(credential: BasicCredential) -> Self {
291 Credential {
292 credential_type: CredentialType::Basic,
293 serialized_credential_content: credential.identity,
294 }
295 }
296}
297
298impl TryFrom<Credential> for BasicCredential {
299 type Error = BasicCredentialError;
300
301 fn try_from(credential: Credential) -> Result<Self, Self::Error> {
302 match credential.credential_type {
303 CredentialType::Basic => Ok(BasicCredential::new(
304 credential.serialized_credential_content.into(),
305 )),
306 _ => Err(errors::BasicCredentialError::WrongCredentialType),
307 }
308 }
309}
310
311#[derive(Debug, Clone)]
315pub struct NewSignerBundle<'a, S: Signer> {
316 pub signer: &'a S,
318 pub credential_with_key: CredentialWithKey,
320}
321
322#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
323pub struct CredentialWithKey {
325 pub credential: Credential,
327 pub signature_key: SignaturePublicKey,
329}
330
331impl From<&LeafNode> for CredentialWithKey {
332 fn from(leaf_node: &LeafNode) -> Self {
333 Self {
334 credential: leaf_node.credential().clone(),
335 signature_key: leaf_node.signature_key().clone(),
336 }
337 }
338}
339
340impl From<&Member> for CredentialWithKey {
341 fn from(member: &Member) -> Self {
342 Self {
343 credential: member.credential.clone(),
344 signature_key: member.signature_key.clone().into(),
345 }
346 }
347}
348
#[cfg(test)]
impl CredentialWithKey {
    /// Test helper: builds a [`CredentialWithKey`] from a credential and raw
    /// public-key bytes.
    pub fn from_parts(credential: Credential, key: &[u8]) -> Self {
        let signature_key: SignaturePublicKey = key.into();
        Self {
            credential,
            signature_key,
        }
    }
}
358
#[cfg(any(test, feature = "test-utils"))]
pub mod test_utils {
    use openmls_basic_credential::SignatureKeyPair;
    use openmls_traits::{types::SignatureScheme, OpenMlsProvider};

    use super::{BasicCredential, CredentialWithKey};

    /// Creates a basic credential for `identity` and a freshly generated
    /// signature key pair for `signature_scheme`.
    ///
    /// The key pair is written to `provider`'s storage before it is returned.
    ///
    /// # Panics
    ///
    /// Panics if key generation or storing the key pair fails. This is
    /// acceptable because the helper is test-only.
    pub fn new_credential(
        provider: &impl OpenMlsProvider,
        identity: &[u8],
        signature_scheme: SignatureScheme,
    ) -> (CredentialWithKey, SignatureKeyPair) {
        let basic_credential = BasicCredential::new(identity.to_vec());
        let key_pair = SignatureKeyPair::new(signature_scheme).unwrap();
        key_pair.store(provider.storage()).unwrap();

        let credential_with_key = CredentialWithKey {
            credential: basic_credential.into(),
            signature_key: key_pair.public().into(),
        };
        (credential_with_key, key_pair)
    }
}
391
#[cfg(test)]
mod unit_tests {
    use tls_codec::{
        DeserializeBytes, Serialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
    };

    use super::{BasicCredential, Credential, CredentialType};

    /// Serializes `credential` and decodes it back.
    ///
    /// Decoding requires the whole encoding to be consumed.
    fn tls_round_trip(credential: &Credential) -> Credential {
        let encoded = credential.tls_serialize_detached().unwrap();
        Credential::tls_deserialize_exact_bytes(&encoded).unwrap()
    }

    /// A basic credential exposes its identity and survives a TLS round trip
    /// through the generic `Credential` type.
    #[test]
    fn basic_credential_identity_and_codec() {
        const IDENTITY: &str = "identity";
        let expected_identity = IDENTITY.as_bytes();

        let basic_credential = BasicCredential::new(IDENTITY.into());
        assert_eq!(basic_credential.identity(), expected_identity);

        let credential = Credential::from(basic_credential.clone());
        let decoded = tls_round_trip(&credential);
        assert_eq!(credential.credential_type(), decoded.credential_type());
        assert_eq!(credential.serialized_content(), decoded.serialized_content());

        let decoded_basic = BasicCredential::try_from(decoded).unwrap();
        assert_eq!(decoded_basic.identity(), expected_identity);
        assert_eq!(basic_credential, decoded_basic);
    }

    /// An application-defined credential can be carried as opaque content
    /// and decoded again after a TLS round trip.
    #[test]
    fn custom_credential() {
        #[derive(
            Debug, Clone, PartialEq, Eq, TlsSize, TlsSerialize, TlsDeserialize, TlsDeserializeBytes,
        )]
        struct CustomCredential {
            custom_field1: u32,
            custom_field2: Vec<u8>,
            custom_field3: Option<u8>,
        }

        let original = CustomCredential {
            custom_field1: 42,
            custom_field2: vec![1, 2, 3],
            custom_field3: Some(2),
        };
        let payload = original.tls_serialize_detached().unwrap();
        let credential = Credential::new(CredentialType::Other(1234), payload);

        let decoded = tls_round_trip(&credential);
        assert_eq!(credential, decoded);

        let decoded_custom =
            CustomCredential::tls_deserialize_exact_bytes(decoded.serialized_content()).unwrap();
        assert_eq!(original, decoded_custom);
    }
}