// openmls/test_utils/frankenstein/leaf_node.rs

1use std::ops::{Deref, DerefMut};
2
3use openmls_basic_credential::SignatureKeyPair;
4use openmls_traits::{signatures::Signer, types::Ciphersuite, OpenMlsProvider};
5use tls_codec::*;
6
7use super::{extensions::FrankenExtension, key_package::FrankenLifetime, FrankenCredential};
8use crate::{
9    binary_tree::{array_representation::tree, LeafNodeIndex},
10    ciphersuite::{
11        signable::{Signable, SignedStruct},
12        signature::Signature,
13    },
14    group::GroupId,
15    treesync::{
16        node::leaf_node::{LeafNodeIn, LeafNodeTbs, TreePosition},
17        LeafNode,
18    },
19};
20
21#[derive(
22    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
23)]
24pub struct FrankenLeafNode {
25    pub payload: FrankenLeafNodePayload,
26    pub signature: VLBytes,
27}
28
29impl FrankenLeafNode {
30    // Re-sign the LeafNode
31    pub fn resign(&mut self, tree_position: Option<FrankenTreePosition>, signer: &impl Signer) {
32        let tbs = FrankenLeafNodeTbs {
33            payload: self.payload.clone(),
34            tree_position,
35        };
36        let new_self = tbs.sign(signer).unwrap();
37        let _ = std::mem::replace(self, new_self);
38    }
39}
40
41impl Deref for FrankenLeafNode {
42    type Target = FrankenLeafNodePayload;
43
44    fn deref(&self) -> &Self::Target {
45        &self.payload
46    }
47}
48
49impl DerefMut for FrankenLeafNode {
50    fn deref_mut(&mut self) -> &mut Self::Target {
51        &mut self.payload
52    }
53}
54
55impl SignedStruct<FrankenLeafNodeTbs> for FrankenLeafNode {
56    fn from_payload(
57        tbs: FrankenLeafNodeTbs,
58        signature: Signature,
59        _serialized_payload: Vec<u8>,
60    ) -> Self {
61        Self {
62            payload: tbs.payload,
63            signature: signature.as_slice().to_owned().into(),
64        }
65    }
66}
67
68const LEAF_NODE_SIGNATURE_LABEL: &str = "LeafNodeTBS";
69
70impl Signable for FrankenLeafNodeTbs {
71    type SignedOutput = FrankenLeafNode;
72
73    fn unsigned_payload(&self) -> Result<Vec<u8>, tls_codec::Error> {
74        self.tls_serialize_detached()
75    }
76
77    fn label(&self) -> &str {
78        LEAF_NODE_SIGNATURE_LABEL
79    }
80}
81
82impl From<LeafNode> for FrankenLeafNode {
83    fn from(ln: LeafNode) -> Self {
84        FrankenLeafNode::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice())
85            .unwrap()
86    }
87}
88
89impl From<FrankenLeafNode> for LeafNode {
90    fn from(fln: FrankenLeafNode) -> Self {
91        LeafNodeIn::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice())
92            .unwrap()
93            .into()
94    }
95}
96
97impl From<FrankenLeafNode> for LeafNodeIn {
98    fn from(fln: FrankenLeafNode) -> Self {
99        LeafNodeIn::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice()).unwrap()
100    }
101}
102
103#[derive(
104    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
105)]
106pub struct FrankenLeafNodePayload {
107    pub encryption_key: VLBytes,
108    pub signature_key: VLBytes,
109    pub credential: FrankenCredential,
110    pub capabilities: FrankenCapabilities,
111    pub leaf_node_source: FrankenLeafNodeSource,
112    pub extensions: Vec<FrankenExtension>,
113}
114
115#[derive(
116    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
117)]
118pub struct FrankenTreePosition {
119    pub group_id: VLBytes,
120    pub leaf_index: u32,
121}
122
123impl From<TreePosition> for FrankenTreePosition {
124    fn from(tp: TreePosition) -> Self {
125        let (group_id, leaf_index) = tp.into_parts();
126        Self {
127            group_id: group_id.as_slice().to_owned().into(),
128            leaf_index: leaf_index.u32(),
129        }
130    }
131}
132
133impl From<FrankenTreePosition> for TreePosition {
134    fn from(ftp: FrankenTreePosition) -> Self {
135        Self::new(
136            GroupId::from_slice(ftp.group_id.as_slice()),
137            LeafNodeIndex::new(ftp.leaf_index),
138        )
139    }
140}
141
142#[derive(Debug, Clone, PartialEq, Eq, TlsSize)]
143pub struct FrankenLeafNodeTbs {
144    pub payload: FrankenLeafNodePayload,
145    pub tree_position: Option<FrankenTreePosition>,
146}
147
148impl FrankenLeafNodeTbs {
149    fn deserialize_without_treeposition<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error> {
150        let payload = FrankenLeafNodePayload::tls_deserialize(bytes)?;
151
152        Ok(Self {
153            payload,
154            tree_position: None,
155        })
156    }
157
158    fn deserialize_with_treeposition<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error> {
159        let payload = FrankenLeafNodePayload::tls_deserialize(bytes)?;
160        let tree_position = FrankenTreePosition::tls_deserialize(bytes)?;
161        Ok(Self {
162            payload,
163            tree_position: Some(tree_position),
164        })
165    }
166}
167
168impl Deserialize for FrankenLeafNodeTbs {
169    fn tls_deserialize<R: std::io::prelude::Read>(bytes: &mut R) -> Result<Self, Error>
170    where
171        Self: Sized,
172    {
173        let payload = FrankenLeafNodePayload::tls_deserialize(bytes)?;
174        let tree_position = match payload.leaf_node_source {
175            FrankenLeafNodeSource::KeyPackage(_) => None,
176            FrankenLeafNodeSource::Update | FrankenLeafNodeSource::Commit(_) => {
177                let tree_position = FrankenTreePosition::tls_deserialize(bytes)?;
178                Some(tree_position)
179            }
180        };
181
182        Ok(Self {
183            payload,
184            tree_position,
185        })
186    }
187}
188
189impl DeserializeBytes for FrankenLeafNodeTbs {
190    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
191    where
192        Self: Sized,
193    {
194        let (payload, rest) = FrankenLeafNodePayload::tls_deserialize_bytes(bytes)?;
195        let (tree_position, rest) = match payload.leaf_node_source {
196            FrankenLeafNodeSource::KeyPackage(_) => (None, rest),
197            FrankenLeafNodeSource::Update | FrankenLeafNodeSource::Commit(_) => {
198                let (tree_position, rest) = FrankenTreePosition::tls_deserialize_bytes(bytes)?;
199                (Some(tree_position), rest)
200            }
201        };
202
203        Ok((
204            Self {
205                payload,
206                tree_position,
207            },
208            rest,
209        ))
210    }
211}
212
213impl Serialize for FrankenLeafNodeTbs {
214    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
215        let mut written = self.payload.tls_serialize(writer)?;
216
217        if let Some(tree_info) = &self.tree_position {
218            written += tree_info.tls_serialize(writer)?
219        };
220
221        Ok(written)
222    }
223}
224
225#[derive(
226    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
227)]
228pub struct FrankenCapabilities {
229    pub versions: Vec<u16>,
230    pub ciphersuites: Vec<u16>,
231    pub extensions: Vec<u16>,
232    pub proposals: Vec<u16>,
233    pub credentials: Vec<u16>,
234}
235
236#[derive(
237    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
238)]
239#[repr(u8)]
240pub enum FrankenLeafNodeSource {
241    #[tls_codec(discriminant = 1)]
242    KeyPackage(FrankenLifetime),
243    Update,
244    Commit(VLBytes),
245}