1use std::io::{Read, Write};
28
29use openmls_traits::signatures::Signer;
30use serde::{Deserialize, Serialize};
31use tls_codec::{
32 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
33 Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
34};
35
36#[cfg(test)]
37mod tests;
38
39use crate::{ciphersuite::SignaturePublicKey, group::Member, treesync::LeafNode};
40use errors::*;
41
42#[cfg(doc)]
43use crate::group::MlsGroup;
44
45pub mod errors;
47
/// The type tag of a [`Credential`], encoded on the wire as a big-endian `u16`.
///
/// `Basic` and `X509` have the fixed codepoints 1 and 2. Every other codepoint
/// is kept verbatim in `Other`, so decoding never fails on an unknown type.
///
/// NOTE(review): `Other(1)` and `Other(2)` can be constructed directly. They
/// encode to the same bytes as `Basic` / `X509` but do not compare equal to
/// them, and decoding those bytes yields the named variant instead.
/// The derived ordering compares the enum discriminant first, so every
/// `Other(_)` sorts after `X509` regardless of its codepoint.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[repr(u16)]
pub enum CredentialType {
    /// Basic credential (codepoint 1).
    Basic = 1,
    /// X.509 certificate credential (codepoint 2).
    X509 = 2,
    /// Any other credential type, carrying its raw `u16` codepoint.
    Other(u16),
}
90
91impl Size for CredentialType {
92 fn tls_serialized_len(&self) -> usize {
93 2
94 }
95}
96
97impl TlsDeserializeTrait for CredentialType {
98 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
99 where
100 Self: Sized,
101 {
102 let mut extension_type = [0u8; 2];
103 bytes.read_exact(&mut extension_type)?;
104
105 Ok(CredentialType::from(u16::from_be_bytes(extension_type)))
106 }
107}
108
109impl TlsSerializeTrait for CredentialType {
110 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
111 writer.write_all(&u16::from(*self).to_be_bytes())?;
112
113 Ok(2)
114 }
115}
116
117impl DeserializeBytes for CredentialType {
118 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
119 where
120 Self: Sized,
121 {
122 let mut bytes_ref = bytes;
123 let credential_type = CredentialType::tls_deserialize(&mut bytes_ref)?;
124 let remainder = &bytes[credential_type.tls_serialized_len()..];
125 Ok((credential_type, remainder))
126 }
127}
128
129impl From<u16> for CredentialType {
130 fn from(value: u16) -> Self {
131 match value {
132 1 => CredentialType::Basic,
133 2 => CredentialType::X509,
134 other => CredentialType::Other(other),
135 }
136 }
137}
138
139impl From<CredentialType> for u16 {
140 fn from(value: CredentialType) -> Self {
141 match value {
142 CredentialType::Basic => 1,
143 CredentialType::X509 => 2,
144 CredentialType::Other(other) => other,
145 }
146 }
147}
148
/// An opaque certificate blob.
///
/// NOTE(review): nothing in this section constructs or reads `cert_data`.
/// Presumably it holds the encoded bytes of a single certificate (e.g. DER);
/// confirm this before relying on it.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct Certificate {
    /// Raw certificate bytes.
    cert_data: Vec<u8>,
}
163
/// A credential: a [`CredentialType`] tag plus the credential's serialized body.
///
/// The body is kept as opaque variable-length bytes ([`VLBytes`]). Its TLS
/// encoding is therefore the two-byte type followed by a length-prefixed byte
/// string. Interpreting the body is left to the caller, either generically via
/// [`Credential::deserialized`] or for basic credentials via
/// `BasicCredential::try_from`.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsSize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
)]
pub struct Credential {
    /// Tag saying how `serialized_credential_content` should be interpreted.
    credential_type: CredentialType,
    /// The credential body, already serialized by whoever produced it.
    serialized_credential_content: VLBytes,
}
209
210impl Credential {
211 pub fn credential_type(&self) -> CredentialType {
213 self.credential_type
214 }
215
216 pub fn new(credential_type: CredentialType, serialized_credential: Vec<u8>) -> Self {
219 Self {
220 credential_type,
221 serialized_credential_content: serialized_credential.into(),
222 }
223 }
224
225 pub fn serialized_content(&self) -> &[u8] {
230 self.serialized_credential_content.as_slice()
231 }
232
233 pub fn deserialized<T: tls_codec::Size + tls_codec::Deserialize>(
235 &self,
236 ) -> Result<T, tls_codec::Error> {
237 T::tls_deserialize_exact(&self.serialized_credential_content)
238 }
239}
240
/// A basic credential: an opaque identity byte string and nothing else.
///
/// It converts into a [`Credential`] tagged [`CredentialType::Basic`] whose
/// body is the identity bytes. It is recovered from one via
/// `TryFrom<Credential>`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BasicCredential {
    /// The identity bytes. Storing them as [`VLBytes`] lets the conversion
    /// into [`Credential`] move them directly into the credential body.
    identity: VLBytes,
}
255
256impl BasicCredential {
257 pub fn new(identity: Vec<u8>) -> Self {
264 Self {
265 identity: identity.into(),
266 }
267 }
268
269 pub fn identity(&self) -> &[u8] {
271 self.identity.as_slice()
272 }
273}
274
275impl From<BasicCredential> for Credential {
276 fn from(credential: BasicCredential) -> Self {
277 Credential {
278 credential_type: CredentialType::Basic,
279 serialized_credential_content: credential.identity,
280 }
281 }
282}
283
284impl TryFrom<Credential> for BasicCredential {
285 type Error = BasicCredentialError;
286
287 fn try_from(credential: Credential) -> Result<Self, Self::Error> {
288 match credential.credential_type {
289 CredentialType::Basic => Ok(BasicCredential::new(
290 credential.serialized_credential_content.into(),
291 )),
292 _ => Err(errors::BasicCredentialError::WrongCredentialType),
293 }
294 }
295}
296
/// Bundles a borrowed signer with the credential and public key that go with it.
///
/// NOTE(review): presumably this is used when switching a member to a new
/// signature key; confirm against callers.
#[derive(Debug, Clone)]
pub struct NewSignerBundle<'a, S: Signer> {
    /// The signer to use. Presumably it holds the private key matching
    /// `credential_with_key.signature_key`; verify at the call site.
    pub signer: &'a S,
    /// The credential and signature public key to pair with `signer`.
    pub credential_with_key: CredentialWithKey,
}
307
/// A [`Credential`] paired with a signature public key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CredentialWithKey {
    /// The credential.
    pub credential: Credential,
    /// The signature public key associated with `credential`.
    pub signature_key: SignaturePublicKey,
}
316
317impl From<&LeafNode> for CredentialWithKey {
318 fn from(leaf_node: &LeafNode) -> Self {
319 Self {
320 credential: leaf_node.credential().clone(),
321 signature_key: leaf_node.signature_key().clone(),
322 }
323 }
324}
325
326impl From<&Member> for CredentialWithKey {
327 fn from(member: &Member) -> Self {
328 Self {
329 credential: member.credential.clone(),
330 signature_key: member.signature_key.clone().into(),
331 }
332 }
333}
334
#[cfg(test)]
impl CredentialWithKey {
    /// Test helper that builds a [`CredentialWithKey`] from a credential and
    /// raw signature public key bytes.
    pub fn from_parts(credential: Credential, key: &[u8]) -> Self {
        let signature_key: SignaturePublicKey = key.into();
        Self {
            credential,
            signature_key,
        }
    }
}
344
#[cfg(any(test, feature = "test-utils"))]
pub mod test_utils {
    use openmls_basic_credential::SignatureKeyPair;
    use openmls_traits::{types::SignatureScheme, OpenMlsProvider};

    use super::{BasicCredential, CredentialWithKey};

    /// Creates a basic credential for `identity` together with a freshly
    /// generated signature key pair for `signature_scheme`.
    ///
    /// The key pair is written to `provider`'s storage before it is returned.
    ///
    /// # Panics
    /// Panics if key generation or storing the key pair fails.
    pub fn new_credential(
        provider: &impl OpenMlsProvider,
        identity: &[u8],
        signature_scheme: SignatureScheme,
    ) -> (CredentialWithKey, SignatureKeyPair) {
        let basic = BasicCredential::new(identity.to_vec());
        let key_pair = SignatureKeyPair::new(signature_scheme).unwrap();
        key_pair.store(provider.storage()).unwrap();

        let credential_with_key = CredentialWithKey {
            credential: basic.into(),
            signature_key: key_pair.public().into(),
        };
        (credential_with_key, key_pair)
    }
}
377
#[cfg(test)]
mod unit_tests {
    use tls_codec::{
        DeserializeBytes, Serialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
    };

    use super::{BasicCredential, Credential, CredentialType};

    /// A basic credential survives conversion to `Credential`, TLS
    /// encoding/decoding, and conversion back.
    #[test]
    fn basic_credential_identity_and_codec() {
        const IDENTITY: &str = "identity";
        let basic_credential = BasicCredential::new(IDENTITY.into());
        assert_eq!(basic_credential.identity(), IDENTITY.as_bytes());

        let credential = Credential::from(basic_credential.clone());
        let serialized = credential.tls_serialize_detached().unwrap();

        let deserialized = Credential::tls_deserialize_exact_bytes(&serialized).unwrap();
        assert_eq!(credential.credential_type(), deserialized.credential_type());
        assert_eq!(
            credential.serialized_content(),
            deserialized.serialized_content()
        );

        let deserialized_basic_credential = BasicCredential::try_from(deserialized).unwrap();
        assert_eq!(
            deserialized_basic_credential.identity(),
            IDENTITY.as_bytes()
        );
        assert_eq!(basic_credential, deserialized_basic_credential);
    }

    /// An application-defined credential body round-trips through an `Other`
    /// credential type.
    #[test]
    fn custom_credential() {
        #[derive(
            Debug, Clone, PartialEq, Eq, TlsSize, TlsSerialize, TlsDeserialize, TlsDeserializeBytes,
        )]
        struct CustomCredential {
            custom_field1: u32,
            custom_field2: Vec<u8>,
            custom_field3: Option<u8>,
        }

        let custom_credential = CustomCredential {
            custom_field1: 42,
            custom_field2: vec![1, 2, 3],
            custom_field3: Some(2),
        };

        let credential = Credential::new(
            CredentialType::Other(1234),
            custom_credential.tls_serialize_detached().unwrap(),
        );

        let serialized = credential.tls_serialize_detached().unwrap();
        let deserialized = Credential::tls_deserialize_exact_bytes(&serialized).unwrap();
        assert_eq!(credential, deserialized);

        let deserialized_custom_credential =
            CustomCredential::tls_deserialize_exact_bytes(deserialized.serialized_content())
                .unwrap();

        assert_eq!(custom_credential, deserialized_custom_credential);
    }

    /// `CredentialType` maps codepoints both ways, encodes as two big-endian
    /// bytes, and `tls_deserialize_bytes` consumes exactly those two bytes.
    #[test]
    fn credential_type_codec() {
        let cases = [
            (1u16, CredentialType::Basic),
            (2u16, CredentialType::X509),
            (0xF000u16, CredentialType::Other(0xF000)),
        ];

        for (codepoint, credential_type) in cases {
            assert_eq!(CredentialType::from(codepoint), credential_type);
            assert_eq!(u16::from(credential_type), codepoint);

            let encoded = credential_type.tls_serialize_detached().unwrap();
            assert_eq!(encoded, codepoint.to_be_bytes());

            // Trailing data must be handed back untouched as the remainder.
            let mut input = encoded.clone();
            input.push(0xAAu8);
            let (decoded, remainder) = CredentialType::tls_deserialize_bytes(&input).unwrap();
            assert_eq!(decoded, credential_type);
            assert_eq!(remainder, &[0xAAu8]);
        }

        // A single byte is not enough to hold a codepoint.
        assert!(CredentialType::tls_deserialize_bytes(&[0x00u8]).is_err());
    }
}