1use crate::{
5 ciphersuite::{signable::*, *},
6 credentials::*,
7 extensions::Extensions,
8 treesync::node::leaf_node::{LeafNodeIn, VerifiableLeafNode},
9 versions::ProtocolVersion,
10};
11use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
12use serde::{Deserialize, Serialize};
13use tls_codec::{
14 Serialize as TlsSerializeTrait, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
15};
16
17use super::{
18 errors::KeyPackageVerifyError, InitKey, KeyPackage, KeyPackageTbs, SIGNATURE_KEY_PACKAGE_LABEL,
19};
20
21#[cfg(any(feature = "test-utils", test))]
22use super::KeyPackageBundle;
23
24struct VerifiableKeyPackage {
26 payload: KeyPackageTbs,
27 signature: Signature,
28}
29
30impl VerifiableKeyPackage {
31 fn new(payload: KeyPackageTbs, signature: Signature) -> Self {
32 Self { payload, signature }
33 }
34}
35
36impl Verifiable for VerifiableKeyPackage {
37 type VerifiedStruct = KeyPackage;
38
39 fn unsigned_payload(&self) -> Result<Vec<u8>, tls_codec::Error> {
40 self.payload.tls_serialize_detached()
41 }
42
43 fn signature(&self) -> &Signature {
44 &self.signature
45 }
46
47 fn label(&self) -> &str {
48 SIGNATURE_KEY_PACKAGE_LABEL
49 }
50
51 fn verify(
52 self,
53 crypto: &impl OpenMlsCrypto,
54 pk: &OpenMlsSignaturePublicKey,
55 ) -> Result<Self::VerifiedStruct, SignatureError> {
56 self.verify_no_out(crypto, pk)?;
57
58 Ok(KeyPackage {
59 payload: self.payload,
60 signature: self.signature,
61 serialized_payload: None,
62 })
63 }
64}
65
66impl VerifiedStruct for KeyPackage {}
67
68#[derive(
80 Debug,
81 Clone,
82 PartialEq,
83 TlsSize,
84 TlsSerialize,
85 TlsDeserialize,
86 TlsDeserializeBytes,
87 Serialize,
88 Deserialize,
89)]
90struct KeyPackageTbsIn {
91 protocol_version: ProtocolVersion,
92 ciphersuite: Ciphersuite,
93 init_key: InitKey,
94 leaf_node: LeafNodeIn,
95 extensions: Extensions,
96}
97
98#[derive(
100 Debug,
101 PartialEq,
102 Clone,
103 Serialize,
104 Deserialize,
105 TlsSerialize,
106 TlsDeserialize,
107 TlsDeserializeBytes,
108 TlsSize,
109)]
110pub struct KeyPackageIn {
111 payload: KeyPackageTbsIn,
112 signature: Signature,
113}
114
115impl KeyPackageIn {
116 pub fn unverified_credential(&self) -> CredentialWithKey {
118 let credential = self.payload.leaf_node.credential().clone();
119 let signature_key = self.payload.leaf_node.signature_key().clone();
120 CredentialWithKey {
121 credential,
122 signature_key,
123 }
124 }
125
126 pub fn validate(
137 self,
138 crypto: &impl OpenMlsCrypto,
139 protocol_version: ProtocolVersion,
140 ) -> Result<KeyPackage, KeyPackageVerifyError> {
141 let leaf_node = self.payload.leaf_node.clone().into_verifiable_leaf_node();
143
144 let signature_key = &OpenMlsSignaturePublicKey::from_signature_key(
145 self.payload.leaf_node.signature_key().clone(),
146 self.payload.ciphersuite.signature_algorithm(),
147 );
148
149 let leaf_node = match leaf_node {
151 VerifiableLeafNode::KeyPackage(leaf_node) => leaf_node
152 .verify(crypto, signature_key)
153 .map_err(|_| KeyPackageVerifyError::InvalidLeafNodeSignature)?,
154 _ => return Err(KeyPackageVerifyError::InvalidLeafNodeSourceType),
155 };
156
157 if !self.version_is_supported(protocol_version) {
160 return Err(KeyPackageVerifyError::InvalidProtocolVersion);
161 }
162
163 if leaf_node.encryption_key().key() == self.payload.init_key.key() {
166 return Err(KeyPackageVerifyError::InitKeyEqualsEncryptionKey);
167 }
168
169 let key_package_tbs = KeyPackageTbs {
170 protocol_version: self.payload.protocol_version,
171 ciphersuite: self.payload.ciphersuite,
172 init_key: self.payload.init_key,
173 leaf_node,
174 extensions: self.payload.extensions,
175 };
176
177 let key_package = VerifiableKeyPackage::new(key_package_tbs, self.signature)
180 .verify(crypto, signature_key)
181 .map_err(|_| KeyPackageVerifyError::InvalidSignature)?;
182
183 for extension in key_package.payload.extensions.iter() {
186 if !key_package
187 .payload
188 .leaf_node
189 .supports_extension(&extension.extension_type())
190 {
191 return Err(KeyPackageVerifyError::UnsupportedExtension);
192 }
193 }
194
195 if let Some(life_time) = key_package.payload.leaf_node.life_time() {
197 if !life_time.is_valid() {
198 return Err(KeyPackageVerifyError::InvalidLifetime);
199 }
200 } else {
201 return Err(KeyPackageVerifyError::MissingLifetime);
204 }
205
206 Ok(key_package)
207 }
208
209 pub(crate) fn version_is_supported(&self, protocol_version: ProtocolVersion) -> bool {
212 self.payload.protocol_version == protocol_version
213 }
214}
215
#[cfg(any(feature = "test-utils", test))]
impl From<KeyPackageTbsIn> for KeyPackageTbs {
    /// Moves every field across and converts the leaf node with `into()`.
    ///
    /// This impl does no verification of its own. It exists only for test
    /// builds.
    fn from(value: KeyPackageTbsIn) -> Self {
        let KeyPackageTbsIn {
            protocol_version,
            ciphersuite,
            init_key,
            leaf_node,
            extensions,
        } = value;

        Self {
            protocol_version,
            ciphersuite,
            init_key,
            leaf_node: leaf_node.into(),
            extensions,
        }
    }
}
228
229impl From<KeyPackageTbs> for KeyPackageTbsIn {
230 fn from(value: KeyPackageTbs) -> Self {
231 Self {
232 protocol_version: value.protocol_version,
233 ciphersuite: value.ciphersuite,
234 init_key: value.init_key,
235 leaf_node: value.leaf_node.into(),
236 extensions: value.extensions,
237 }
238 }
239}
240
241impl From<KeyPackage> for KeyPackageIn {
242 fn from(value: KeyPackage) -> Self {
243 Self {
244 payload: value.payload.into(),
245 signature: value.signature,
246 }
247 }
248}
249
#[cfg(any(feature = "test-utils", test))]
impl From<KeyPackageBundle> for KeyPackageIn {
    /// Extracts the bundle's key package and converts it into its wire-input
    /// form. The other contents of the bundle are dropped. Test builds only.
    fn from(value: KeyPackageBundle) -> Self {
        Self::from(value.key_package)
    }
}
259
#[cfg(any(feature = "test-utils", test))]
impl From<KeyPackageIn> for KeyPackage {
    /// Converts an unverified key package into a [`KeyPackage`] without calling
    /// [`KeyPackageIn::validate`]. Test builds only.
    fn from(value: KeyPackageIn) -> Self {
        let KeyPackageIn { payload, signature } = value;
        KeyPackage {
            payload: KeyPackageTbs::from(payload),
            signature,
            serialized_payload: None,
        }
    }
}