1use crate::{
5 ciphersuite::{signable::*, *},
6 credentials::*,
7 extensions::Extensions,
8 treesync::node::leaf_node::{LeafNodeIn, VerifiableLeafNode},
9 versions::ProtocolVersion,
10};
11use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
12use serde::{Deserialize, Serialize};
13use tls_codec::{
14 Serialize as TlsSerializeTrait, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
15};
16
17use super::{
18 errors::KeyPackageVerifyError, InitKey, KeyPackage, KeyPackageTbs, SIGNATURE_KEY_PACKAGE_LABEL,
19};
20
21#[cfg(any(feature = "test-utils", test))]
22use super::KeyPackageBundle;
23
24struct VerifiableKeyPackage {
26 payload: KeyPackageTbs,
27 signature: Signature,
28}
29
30impl VerifiableKeyPackage {
31 fn new(payload: KeyPackageTbs, signature: Signature) -> Self {
32 Self { payload, signature }
33 }
34}
35
36impl Verifiable for VerifiableKeyPackage {
37 type VerifiedStruct = KeyPackage;
38
39 fn unsigned_payload(&self) -> Result<Vec<u8>, tls_codec::Error> {
40 self.payload.tls_serialize_detached()
41 }
42
43 fn signature(&self) -> &Signature {
44 &self.signature
45 }
46
47 fn label(&self) -> &str {
48 SIGNATURE_KEY_PACKAGE_LABEL
49 }
50
51 fn verify(
52 self,
53 crypto: &impl OpenMlsCrypto,
54 pk: &OpenMlsSignaturePublicKey,
55 ) -> Result<Self::VerifiedStruct, SignatureError> {
56 self.verify_no_out(crypto, pk)?;
57
58 Ok(KeyPackage {
59 payload: self.payload,
60 signature: self.signature,
61 })
62 }
63}
64
65impl VerifiedStruct for KeyPackage {}
66
/// The to-be-signed part of a [`KeyPackageIn`], in its decoded, not yet
/// validated form.
///
/// It carries the same fields as [`KeyPackageTbs`] (see the `From` impls in
/// this module), except that the leaf node is the unverified [`LeafNodeIn`].
///
/// The TLS derives encode fields in declaration order, so do not reorder them.
#[derive(
    Debug,
    Clone,
    PartialEq,
    TlsSize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    Serialize,
    Deserialize,
)]
struct KeyPackageTbsIn {
    protocol_version: ProtocolVersion,
    ciphersuite: Ciphersuite,
    init_key: InitKey,
    // Unverified leaf node; `KeyPackageIn::validate` verifies it.
    leaf_node: LeafNodeIn,
    extensions: Extensions,
}
96
/// A key package in its decoded, not yet validated form.
///
/// Nothing in it has been checked. Call [`KeyPackageIn::validate`] to verify it
/// and obtain a [`KeyPackage`].
///
/// The TLS derives encode fields in declaration order (payload, then
/// signature), so do not reorder them.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSize,
)]
pub struct KeyPackageIn {
    payload: KeyPackageTbsIn,
    // Checked by `validate` against the leaf node's signature key.
    signature: Signature,
}
113
114impl KeyPackageIn {
115 pub fn unverified_credential(&self) -> CredentialWithKey {
117 let credential = self.payload.leaf_node.credential().clone();
118 let signature_key = self.payload.leaf_node.signature_key().clone();
119 CredentialWithKey {
120 credential,
121 signature_key,
122 }
123 }
124
125 pub fn validate(
136 self,
137 crypto: &impl OpenMlsCrypto,
138 protocol_version: ProtocolVersion,
139 ) -> Result<KeyPackage, KeyPackageVerifyError> {
140 let leaf_node = self.payload.leaf_node.clone().into_verifiable_leaf_node();
142
143 let signature_key = &OpenMlsSignaturePublicKey::from_signature_key(
144 self.payload.leaf_node.signature_key().clone(),
145 self.payload.ciphersuite.signature_algorithm(),
146 );
147
148 let leaf_node = match leaf_node {
150 VerifiableLeafNode::KeyPackage(leaf_node) => leaf_node
151 .verify(crypto, signature_key)
152 .map_err(|_| KeyPackageVerifyError::InvalidLeafNodeSignature)?,
153 _ => return Err(KeyPackageVerifyError::InvalidLeafNodeSourceType),
154 };
155
156 if !self.version_is_supported(protocol_version) {
159 return Err(KeyPackageVerifyError::InvalidProtocolVersion);
160 }
161
162 if leaf_node.encryption_key().key() == self.payload.init_key.key() {
165 return Err(KeyPackageVerifyError::InitKeyEqualsEncryptionKey);
166 }
167
168 let key_package_tbs = KeyPackageTbs {
169 protocol_version: self.payload.protocol_version,
170 ciphersuite: self.payload.ciphersuite,
171 init_key: self.payload.init_key,
172 leaf_node,
173 extensions: self.payload.extensions,
174 };
175
176 let key_package = VerifiableKeyPackage::new(key_package_tbs, self.signature)
179 .verify(crypto, signature_key)
180 .map_err(|_| KeyPackageVerifyError::InvalidSignature)?;
181
182 for extension in key_package.payload.extensions.iter() {
185 if !key_package
186 .payload
187 .leaf_node
188 .supports_extension(&extension.extension_type())
189 {
190 return Err(KeyPackageVerifyError::UnsupportedExtension);
191 }
192 }
193
194 if let Some(life_time) = key_package.payload.leaf_node.life_time() {
196 if !life_time.is_valid() {
197 return Err(KeyPackageVerifyError::InvalidLifetime);
198 }
199 } else {
200 return Err(KeyPackageVerifyError::MissingLifetime);
203 }
204
205 Ok(key_package)
206 }
207
208 pub(crate) fn version_is_supported(&self, protocol_version: ProtocolVersion) -> bool {
211 self.payload.protocol_version == protocol_version
212 }
213}
214
#[cfg(any(feature = "test-utils", test))]
impl From<KeyPackageTbsIn> for KeyPackageTbs {
    /// Moves every field across unchanged and converts the leaf node with its
    /// `Into` impl.
    ///
    /// No verification is performed here, which is why this conversion is
    /// only compiled for tests and the `test-utils` feature.
    fn from(value: KeyPackageTbsIn) -> Self {
        let KeyPackageTbsIn {
            protocol_version,
            ciphersuite,
            init_key,
            leaf_node,
            extensions,
        } = value;

        Self {
            protocol_version,
            ciphersuite,
            init_key,
            leaf_node: leaf_node.into(),
            extensions,
        }
    }
}
227
228impl From<KeyPackageTbs> for KeyPackageTbsIn {
229 fn from(value: KeyPackageTbs) -> Self {
230 Self {
231 protocol_version: value.protocol_version,
232 ciphersuite: value.ciphersuite,
233 init_key: value.init_key,
234 leaf_node: value.leaf_node.into(),
235 extensions: value.extensions,
236 }
237 }
238}
239
240impl From<KeyPackage> for KeyPackageIn {
241 fn from(value: KeyPackage) -> Self {
242 Self {
243 payload: value.payload.into(),
244 signature: value.signature,
245 }
246 }
247}
248
#[cfg(any(feature = "test-utils", test))]
impl From<KeyPackageBundle> for KeyPackageIn {
    /// Uses only the bundle's key package and converts it exactly like
    /// `From<KeyPackage> for KeyPackageIn`.
    fn from(value: KeyPackageBundle) -> Self {
        KeyPackageIn::from(value.key_package)
    }
}
258
#[cfg(any(feature = "test-utils", test))]
impl From<KeyPackageIn> for KeyPackage {
    /// Builds a [`KeyPackage`] directly from an unvalidated [`KeyPackageIn`],
    /// skipping [`KeyPackageIn::validate`]. It is therefore test-only.
    fn from(value: KeyPackageIn) -> Self {
        let KeyPackageIn { payload, signature } = value;
        KeyPackage {
            payload: KeyPackageTbs::from(payload),
            signature,
        }
    }
}