1use crate::{
5 ciphersuite::{signable::*, *},
6 credentials::*,
7 extensions::{AnyObject, Extensions},
8 treesync::node::leaf_node::{LeafNodeIn, VerifiableLeafNode},
9 versions::ProtocolVersion,
10};
11use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
12use serde::{Deserialize, Serialize};
13use tls_codec::{
14 Serialize as TlsSerializeTrait, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
15};
16
17use super::{
18 errors::KeyPackageVerifyError, InitKey, KeyPackage, KeyPackageTbs, SIGNATURE_KEY_PACKAGE_LABEL,
19};
20
21#[cfg(any(feature = "test-utils", test))]
22use super::KeyPackageBundle;
23
24struct VerifiableKeyPackage {
26 payload: KeyPackageTbs,
27 signature: Signature,
28}
29
30impl VerifiableKeyPackage {
31 fn new(payload: KeyPackageTbs, signature: Signature) -> Self {
32 Self { payload, signature }
33 }
34}
35
36impl Verifiable for VerifiableKeyPackage {
37 type VerifiedStruct = KeyPackage;
38
39 fn unsigned_payload(&self) -> Result<Vec<u8>, tls_codec::Error> {
40 self.payload.tls_serialize_detached()
41 }
42
43 fn signature(&self) -> &Signature {
44 &self.signature
45 }
46
47 fn label(&self) -> &str {
48 SIGNATURE_KEY_PACKAGE_LABEL
49 }
50
51 fn verify(
52 self,
53 crypto: &impl OpenMlsCrypto,
54 pk: &OpenMlsSignaturePublicKey,
55 ) -> Result<Self::VerifiedStruct, SignatureError> {
56 self.verify_no_out(crypto, pk)?;
57
58 Ok(KeyPackage {
59 payload: self.payload,
60 signature: self.signature,
61 serialized_payload: None,
62 })
63 }
64}
65
66impl VerifiedStruct for KeyPackage {}
67
68#[derive(
80 Debug,
81 Clone,
82 PartialEq,
83 TlsSize,
84 TlsSerialize,
85 TlsDeserialize,
86 TlsDeserializeBytes,
87 Serialize,
88 Deserialize,
89)]
90struct KeyPackageTbsIn {
91 protocol_version: ProtocolVersion,
92 ciphersuite: Ciphersuite,
93 init_key: InitKey,
94 leaf_node: LeafNodeIn,
95 extensions: Extensions<AnyObject>,
96}
97
98#[derive(
100 Debug,
101 PartialEq,
102 Clone,
103 Serialize,
104 Deserialize,
105 TlsSerialize,
106 TlsDeserialize,
107 TlsDeserializeBytes,
108 TlsSize,
109)]
110pub struct KeyPackageIn {
111 payload: KeyPackageTbsIn,
112 signature: Signature,
113}
114
115impl KeyPackageIn {
116 pub fn unverified_credential(&self) -> CredentialWithKey {
118 let credential = self.payload.leaf_node.credential().clone();
119 let signature_key = self.payload.leaf_node.signature_key().clone();
120 CredentialWithKey {
121 credential,
122 signature_key,
123 }
124 }
125
126 pub fn validate(
137 self,
138 crypto: &impl OpenMlsCrypto,
139 protocol_version: ProtocolVersion,
140 ) -> Result<KeyPackage, KeyPackageVerifyError> {
141 let leaf_node = self.payload.leaf_node.clone().into_verifiable_leaf_node();
143
144 let signature_key = &OpenMlsSignaturePublicKey::from_signature_key(
145 self.payload.leaf_node.signature_key().clone(),
146 self.payload.ciphersuite.signature_algorithm(),
147 );
148
149 let leaf_node = match leaf_node {
151 VerifiableLeafNode::KeyPackage(leaf_node) => leaf_node
152 .verify(crypto, signature_key)
153 .map_err(|_| KeyPackageVerifyError::InvalidLeafNodeSignature)?,
154 _ => return Err(KeyPackageVerifyError::InvalidLeafNodeSourceType),
155 };
156
157 if !self.version_is_supported(protocol_version) {
160 return Err(KeyPackageVerifyError::InvalidProtocolVersion);
161 }
162
163 if leaf_node.encryption_key().key() == self.payload.init_key.key() {
166 return Err(KeyPackageVerifyError::InitKeyEqualsEncryptionKey);
167 }
168
169 let key_package_tbs = KeyPackageTbs {
170 protocol_version: self.payload.protocol_version,
171 ciphersuite: self.payload.ciphersuite,
172 init_key: self.payload.init_key,
173 leaf_node,
174 extensions: self.payload.extensions.try_into()?,
175 };
176
177 let key_package = VerifiableKeyPackage::new(key_package_tbs, self.signature)
180 .verify(crypto, signature_key)
181 .map_err(|_| KeyPackageVerifyError::InvalidSignature)?;
182
183 for extension in key_package.payload.extensions.iter() {
186 if !key_package
187 .payload
188 .leaf_node
189 .supports_extension(&extension.extension_type())
190 {
191 return Err(KeyPackageVerifyError::UnsupportedExtension);
192 }
193 }
194
195 if let Some(life_time) = key_package.payload.leaf_node.life_time() {
197 life_time.validate()?;
198 } else {
199 return Err(KeyPackageVerifyError::MissingLifetime);
202 }
203
204 Ok(key_package)
205 }
206
207 pub(crate) fn version_is_supported(&self, protocol_version: ProtocolVersion) -> bool {
210 self.payload.protocol_version == protocol_version
211 }
212}
213
/// Test helper: converts the unvalidated TBS into a `KeyPackageTbs`
/// without performing any of the checks done by `KeyPackageIn::validate`.
#[cfg(any(feature = "test-utils", test))]
impl From<KeyPackageTbsIn> for KeyPackageTbs {
    fn from(value: KeyPackageTbsIn) -> Self {
        let KeyPackageTbsIn {
            protocol_version,
            ciphersuite,
            init_key,
            leaf_node,
            extensions,
        } = value;

        Self {
            protocol_version,
            ciphersuite,
            init_key,
            leaf_node: leaf_node.into(),
            extensions: extensions.coerce(),
        }
    }
}
226
227impl From<KeyPackageTbs> for KeyPackageTbsIn {
228 fn from(value: KeyPackageTbs) -> Self {
229 Self {
230 protocol_version: value.protocol_version,
231 ciphersuite: value.ciphersuite,
232 init_key: value.init_key,
233 leaf_node: value.leaf_node.into(),
234 extensions: value.extensions.into(),
235 }
236 }
237}
238
239impl From<KeyPackage> for KeyPackageIn {
240 fn from(value: KeyPackage) -> Self {
241 Self {
242 payload: value.payload.into(),
243 signature: value.signature,
244 }
245 }
246}
247
/// Test helper: takes the key package out of a bundle and converts it into
/// its input form. The bundle's other contents are discarded.
#[cfg(any(feature = "test-utils", test))]
impl From<KeyPackageBundle> for KeyPackageIn {
    fn from(value: KeyPackageBundle) -> Self {
        // Identical to the `KeyPackage` conversion: keep payload and signature.
        Self::from(value.key_package)
    }
}
257
/// Test helper: reinterprets an unvalidated key package as a `KeyPackage`
/// without verifying any signature or running `KeyPackageIn::validate`.
#[cfg(any(feature = "test-utils", test))]
impl From<KeyPackageIn> for KeyPackage {
    fn from(value: KeyPackageIn) -> Self {
        let KeyPackageIn { payload, signature } = value;

        Self {
            payload: payload.into(),
            signature,
            serialized_payload: None,
        }
    }
}
267}