1use std::io::{Read, Write};
28
29use serde::{Deserialize, Serialize};
30use tls_codec::{
31 Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
32 Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
33};
34
35#[cfg(test)]
36mod tests;
37
38use crate::{ciphersuite::SignaturePublicKey, group::Member, treesync::LeafNode};
39use errors::*;
40
41pub mod errors;
43
/// The type of a [`Credential`], encoded on the wire as a big-endian `u16`.
///
/// The code points `1` (`Basic`) and `2` (`X509`) are recognised explicitly;
/// every other value is carried through verbatim in `Other`.
///
/// NOTE(review): `Other(1)` / `Other(2)` can be constructed directly. They
/// serialize to the same bytes as `Basic` / `X509` but compare unequal to them.
/// Values produced by `From<u16>` are always canonical.
///
/// The derived `PartialOrd`/`Ord` order variants by discriminant
/// (`Basic = 1`, `X509 = 2`, `Other` implicitly `3`). They do not order by the
/// code point stored in `Other`, so e.g. `Other(0) > X509`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[repr(u16)]
pub enum CredentialType {
    /// Basic credential (code point 1).
    Basic = 1,
    /// X.509 credential (code point 2).
    X509 = 2,
    /// Any other code point, kept as-is.
    Other(u16),
}
86
87impl Size for CredentialType {
88 fn tls_serialized_len(&self) -> usize {
89 2
90 }
91}
92
93impl TlsDeserializeTrait for CredentialType {
94 fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
95 where
96 Self: Sized,
97 {
98 let mut extension_type = [0u8; 2];
99 bytes.read_exact(&mut extension_type)?;
100
101 Ok(CredentialType::from(u16::from_be_bytes(extension_type)))
102 }
103}
104
105impl TlsSerializeTrait for CredentialType {
106 fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
107 writer.write_all(&u16::from(*self).to_be_bytes())?;
108
109 Ok(2)
110 }
111}
112
113impl DeserializeBytes for CredentialType {
114 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
115 where
116 Self: Sized,
117 {
118 let mut bytes_ref = bytes;
119 let credential_type = CredentialType::tls_deserialize(&mut bytes_ref)?;
120 let remainder = &bytes[credential_type.tls_serialized_len()..];
121 Ok((credential_type, remainder))
122 }
123}
124
125impl From<u16> for CredentialType {
126 fn from(value: u16) -> Self {
127 match value {
128 1 => CredentialType::Basic,
129 2 => CredentialType::X509,
130 other => CredentialType::Other(other),
131 }
132 }
133}
134
135impl From<CredentialType> for u16 {
136 fn from(value: CredentialType) -> Self {
137 match value {
138 CredentialType::Basic => 1,
139 CredentialType::X509 => 2,
140 CredentialType::Other(other) => other,
141 }
142 }
143}
144
/// A certificate held as raw bytes.
///
/// Only serde traits are derived here; unlike [`Credential`], this type has no
/// TLS codec derives.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct Certificate {
    // NOTE(review): presumably DER-encoded X.509 bytes; the encoding is not
    // established in this file — confirm.
    cert_data: Vec<u8>,
}
159
/// A credential as carried on the wire.
///
/// It consists of a [`CredentialType`] followed by the type-specific content,
/// stored as a length-prefixed byte vector ([`VLBytes`]).
///
/// The derived TLS codec impls encode fields in declaration order. Reordering
/// the fields below changes the wire format.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsSize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
)]
pub struct Credential {
    // Discriminates how `serialized_credential_content` should be interpreted.
    credential_type: CredentialType,
    // Opaque, already-serialized content of the concrete credential type.
    serialized_credential_content: VLBytes,
}
205
206impl Credential {
207 pub fn credential_type(&self) -> CredentialType {
209 self.credential_type
210 }
211
212 pub fn new(credential_type: CredentialType, serialized_credential: Vec<u8>) -> Self {
215 Self {
216 credential_type,
217 serialized_credential_content: serialized_credential.into(),
218 }
219 }
220
221 pub fn serialized_content(&self) -> &[u8] {
226 self.serialized_credential_content.as_slice()
227 }
228
229 pub fn deserialized<T: tls_codec::Size + tls_codec::Deserialize>(
231 &self,
232 ) -> Result<T, tls_codec::Error> {
233 T::tls_deserialize_exact(&self.serialized_credential_content)
234 }
235}
236
/// A basic credential, consisting solely of an identity byte string.
///
/// When converted into a [`Credential`], the identity bytes become the
/// credential's serialized content unchanged.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BasicCredential {
    // Raw identity bytes; their meaning is application-defined.
    identity: VLBytes,
}
251
252impl BasicCredential {
253 pub fn new(identity: Vec<u8>) -> Self {
260 Self {
261 identity: identity.into(),
262 }
263 }
264
265 pub fn identity(&self) -> &[u8] {
267 self.identity.as_slice()
268 }
269}
270
271impl From<BasicCredential> for Credential {
272 fn from(credential: BasicCredential) -> Self {
273 Credential {
274 credential_type: CredentialType::Basic,
275 serialized_credential_content: credential.identity,
276 }
277 }
278}
279
280impl TryFrom<Credential> for BasicCredential {
281 type Error = BasicCredentialError;
282
283 fn try_from(credential: Credential) -> Result<Self, Self::Error> {
284 match credential.credential_type {
285 CredentialType::Basic => Ok(BasicCredential::new(
286 credential.serialized_credential_content.into(),
287 )),
288 _ => Err(errors::BasicCredentialError::WrongCredentialType),
289 }
290 }
291}
292
/// A [`Credential`] bundled with a signature public key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CredentialWithKey {
    /// The credential.
    pub credential: Credential,
    /// The signature public key paired with `credential`.
    pub signature_key: SignaturePublicKey,
}
301
302impl From<&LeafNode> for CredentialWithKey {
303 fn from(leaf_node: &LeafNode) -> Self {
304 Self {
305 credential: leaf_node.credential().clone(),
306 signature_key: leaf_node.signature_key().clone(),
307 }
308 }
309}
310
311impl From<&Member> for CredentialWithKey {
312 fn from(member: &Member) -> Self {
313 Self {
314 credential: member.credential.clone(),
315 signature_key: member.signature_key.clone().into(),
316 }
317 }
318}
319
#[cfg(test)]
impl CredentialWithKey {
    /// Test helper: pairs `credential` with a signature key built from raw `key` bytes.
    pub fn from_parts(credential: Credential, key: &[u8]) -> Self {
        let signature_key: SignaturePublicKey = key.into();
        Self {
            credential,
            signature_key,
        }
    }
}
329
#[cfg(any(test, feature = "test-utils"))]
pub mod test_utils {
    use openmls_basic_credential::SignatureKeyPair;
    use openmls_traits::{types::SignatureScheme, OpenMlsProvider};

    use super::{BasicCredential, CredentialWithKey};

    /// Creates a basic credential for `identity` together with a fresh
    /// signature key pair for `signature_scheme`.
    ///
    /// The key pair is stored in `provider`'s storage before returning.
    ///
    /// # Panics
    ///
    /// Panics if key generation or storing the key pair fails (test-only helper).
    pub fn new_credential(
        provider: &impl OpenMlsProvider,
        identity: &[u8],
        signature_scheme: SignatureScheme,
    ) -> (CredentialWithKey, SignatureKeyPair) {
        let basic_credential = BasicCredential::new(identity.to_vec());

        let key_pair = SignatureKeyPair::new(signature_scheme).unwrap();
        key_pair.store(provider.storage()).unwrap();

        let credential_with_key = CredentialWithKey {
            credential: basic_credential.into(),
            signature_key: key_pair.public().into(),
        };
        (credential_with_key, key_pair)
    }
}
362
#[cfg(test)]
mod unit_tests {
    use tls_codec::{
        DeserializeBytes, Serialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
    };

    use super::{BasicCredential, Credential, CredentialType};

    /// Round-trips a basic credential through `Credential` and the TLS codec.
    /// The identity must survive every conversion unchanged.
    #[test]
    fn basic_credential_identity_and_codec() {
        const IDENTITY: &str = "identity";
        let expected_identity = IDENTITY.as_bytes();

        let original = BasicCredential::new(IDENTITY.into());
        assert_eq!(original.identity(), expected_identity);

        let credential = Credential::from(original.clone());
        let wire = credential.tls_serialize_detached().unwrap();
        let decoded = Credential::tls_deserialize_exact_bytes(&wire).unwrap();

        assert_eq!(credential.credential_type(), decoded.credential_type());
        assert_eq!(credential.serialized_content(), decoded.serialized_content());

        let recovered = BasicCredential::try_from(decoded).unwrap();
        assert_eq!(recovered.identity(), expected_identity);
        assert_eq!(original, recovered);
    }

    /// A credential of an application-defined type can carry an arbitrary
    /// TLS-encodable payload, and both layers survive a round trip.
    #[test]
    fn custom_credential() {
        #[derive(
            Debug, Clone, PartialEq, Eq, TlsSize, TlsSerialize, TlsDeserialize, TlsDeserializeBytes,
        )]
        struct CustomCredential {
            custom_field1: u32,
            custom_field2: Vec<u8>,
            custom_field3: Option<u8>,
        }

        let payload = CustomCredential {
            custom_field1: 42,
            custom_field2: vec![1, 2, 3],
            custom_field3: Some(2),
        };
        let payload_bytes = payload.tls_serialize_detached().unwrap();

        let credential = Credential::new(CredentialType::Other(1234), payload_bytes);
        let wire = credential.tls_serialize_detached().unwrap();
        let decoded = Credential::tls_deserialize_exact_bytes(&wire).unwrap();
        assert_eq!(credential, decoded);

        let recovered_payload =
            CustomCredential::tls_deserialize_exact_bytes(decoded.serialized_content()).unwrap();
        assert_eq!(payload, recovered_payload);
    }
}
430}