openmls/test_utils/frankenstein/
leaf_node.rs1use std::ops::{Deref, DerefMut};
2
3use openmls_basic_credential::SignatureKeyPair;
4use openmls_traits::{signatures::Signer, types::Ciphersuite, OpenMlsProvider};
5use tls_codec::*;
6
7use super::{extensions::FrankenExtension, key_package::FrankenLifetime, FrankenCredential};
8use crate::{
9 binary_tree::{array_representation::tree, LeafNodeIndex},
10 ciphersuite::{
11 signable::{Signable, SignedStruct},
12 signature::Signature,
13 },
14 group::GroupId,
15 treesync::{
16 node::leaf_node::{LeafNodeIn, LeafNodeTbs, TreePosition},
17 LeafNode,
18 },
19};
20
21#[derive(
22 Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
23)]
24pub struct FrankenLeafNode {
25 pub payload: FrankenLeafNodePayload,
26 pub signature: VLBytes,
27}
28
29impl FrankenLeafNode {
30 pub fn resign(&mut self, tree_position: Option<FrankenTreePosition>, signer: &impl Signer) {
32 let tbs = FrankenLeafNodeTbs {
33 payload: self.payload.clone(),
34 tree_position,
35 };
36 let new_self = tbs.sign(signer).unwrap();
37 let _ = std::mem::replace(self, new_self);
38 }
39}
40
41impl Deref for FrankenLeafNode {
42 type Target = FrankenLeafNodePayload;
43
44 fn deref(&self) -> &Self::Target {
45 &self.payload
46 }
47}
48
49impl DerefMut for FrankenLeafNode {
50 fn deref_mut(&mut self) -> &mut Self::Target {
51 &mut self.payload
52 }
53}
54
55impl SignedStruct<FrankenLeafNodeTbs> for FrankenLeafNode {
56 fn from_payload(
57 tbs: FrankenLeafNodeTbs,
58 signature: Signature,
59 _serialized_payload: Vec<u8>,
60 ) -> Self {
61 Self {
62 payload: tbs.payload,
63 signature: signature.as_slice().to_owned().into(),
64 }
65 }
66}
67
68const LEAF_NODE_SIGNATURE_LABEL: &str = "LeafNodeTBS";
69
70impl Signable for FrankenLeafNodeTbs {
71 type SignedOutput = FrankenLeafNode;
72
73 fn unsigned_payload(&self) -> Result<Vec<u8>, tls_codec::Error> {
74 self.tls_serialize_detached()
75 }
76
77 fn label(&self) -> &str {
78 LEAF_NODE_SIGNATURE_LABEL
79 }
80}
81
82impl From<LeafNode> for FrankenLeafNode {
83 fn from(ln: LeafNode) -> Self {
84 FrankenLeafNode::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice())
85 .unwrap()
86 }
87}
88
89impl From<FrankenLeafNode> for LeafNode {
90 fn from(fln: FrankenLeafNode) -> Self {
91 LeafNodeIn::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice())
92 .unwrap()
93 .into()
94 }
95}
96
97impl From<FrankenLeafNode> for LeafNodeIn {
98 fn from(fln: FrankenLeafNode) -> Self {
99 LeafNodeIn::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice()).unwrap()
100 }
101}
102
103#[derive(
104 Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
105)]
106pub struct FrankenLeafNodePayload {
107 pub encryption_key: VLBytes,
108 pub signature_key: VLBytes,
109 pub credential: FrankenCredential,
110 pub capabilities: FrankenCapabilities,
111 pub leaf_node_source: FrankenLeafNodeSource,
112 pub extensions: Vec<FrankenExtension>,
113}
114
115#[derive(
116 Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
117)]
118pub struct FrankenTreePosition {
119 pub group_id: VLBytes,
120 pub leaf_index: u32,
121}
122
123impl From<TreePosition> for FrankenTreePosition {
124 fn from(tp: TreePosition) -> Self {
125 let (group_id, leaf_index) = tp.into_parts();
126 Self {
127 group_id: group_id.as_slice().to_owned().into(),
128 leaf_index: leaf_index.u32(),
129 }
130 }
131}
132
133impl From<FrankenTreePosition> for TreePosition {
134 fn from(ftp: FrankenTreePosition) -> Self {
135 Self::new(
136 GroupId::from_slice(ftp.group_id.as_slice()),
137 LeafNodeIndex::new(ftp.leaf_index),
138 )
139 }
140}
141
142#[derive(Debug, Clone, PartialEq, Eq, TlsSize)]
143pub struct FrankenLeafNodeTbs {
144 pub payload: FrankenLeafNodePayload,
145 pub tree_position: Option<FrankenTreePosition>,
146}
147
148impl FrankenLeafNodeTbs {
149 fn deserialize_without_treeposition<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error> {
150 let payload = FrankenLeafNodePayload::tls_deserialize(bytes)?;
151
152 Ok(Self {
153 payload,
154 tree_position: None,
155 })
156 }
157
158 fn deserialize_with_treeposition<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error> {
159 let payload = FrankenLeafNodePayload::tls_deserialize(bytes)?;
160 let tree_position = FrankenTreePosition::tls_deserialize(bytes)?;
161 Ok(Self {
162 payload,
163 tree_position: Some(tree_position),
164 })
165 }
166}
167
168impl Deserialize for FrankenLeafNodeTbs {
169 fn tls_deserialize<R: std::io::prelude::Read>(bytes: &mut R) -> Result<Self, Error>
170 where
171 Self: Sized,
172 {
173 let payload = FrankenLeafNodePayload::tls_deserialize(bytes)?;
174 let tree_position = match payload.leaf_node_source {
175 FrankenLeafNodeSource::KeyPackage(_) => None,
176 FrankenLeafNodeSource::Update | FrankenLeafNodeSource::Commit(_) => {
177 let tree_position = FrankenTreePosition::tls_deserialize(bytes)?;
178 Some(tree_position)
179 }
180 };
181
182 Ok(Self {
183 payload,
184 tree_position,
185 })
186 }
187}
188
189impl DeserializeBytes for FrankenLeafNodeTbs {
190 fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
191 where
192 Self: Sized,
193 {
194 let (payload, rest) = FrankenLeafNodePayload::tls_deserialize_bytes(bytes)?;
195 let (tree_position, rest) = match payload.leaf_node_source {
196 FrankenLeafNodeSource::KeyPackage(_) => (None, rest),
197 FrankenLeafNodeSource::Update | FrankenLeafNodeSource::Commit(_) => {
198 let (tree_position, rest) = FrankenTreePosition::tls_deserialize_bytes(bytes)?;
199 (Some(tree_position), rest)
200 }
201 };
202
203 Ok((
204 Self {
205 payload,
206 tree_position,
207 },
208 rest,
209 ))
210 }
211}
212
213impl Serialize for FrankenLeafNodeTbs {
214 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
215 let mut written = self.payload.tls_serialize(writer)?;
216
217 if let Some(tree_info) = &self.tree_position {
218 written += tree_info.tls_serialize(writer)?
219 };
220
221 Ok(written)
222 }
223}
224
225#[derive(
226 Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
227)]
228pub struct FrankenCapabilities {
229 pub versions: Vec<u16>,
230 pub ciphersuites: Vec<u16>,
231 pub extensions: Vec<u16>,
232 pub proposals: Vec<u16>,
233 pub credentials: Vec<u16>,
234}
235
236#[derive(
237 Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
238)]
239#[repr(u8)]
240pub enum FrankenLeafNodeSource {
241 #[tls_codec(discriminant = 1)]
242 KeyPackage(FrankenLifetime),
243 Update,
244 Commit(VLBytes),
245}