// openmls/test_utils/frankenstein/extensions.rs

1use tls_codec::*;
2
3use crate::{
4    extensions::{
5        ApplicationIdExtension, Extension, RatchetTreeExtension, RequiredCapabilitiesExtension,
6    },
7    treesync::{node::NodeIn, Node, ParentNode},
8};
9
10use super::{FrankenCredential, FrankenLeafNode};
11
/// Extension type code points as they appear on the wire.
///
/// Known code points get a dedicated variant; every other value is carried
/// verbatim in [`FrankenExtensionType::Unknown`]. `Unknown` is not normalized:
/// `Unknown(1)` encodes to the same `u16` as `ApplicationId`, and decoding that
/// value yields `ApplicationId` again.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrankenExtensionType {
    /// Code point 1.
    ApplicationId,
    /// Code point 2.
    RatchetTree,
    /// Code point 3.
    RequiredCapabilities,
    /// Code point 4.
    ExternalPub,
    /// Code point 5.
    ExternalSenders,
    /// Code point 10.
    LastResort,
    /// Any other code point, kept as-is.
    Unknown(u16),
}

impl From<u16> for FrankenExtensionType {
    /// Decodes a raw code point; values without a dedicated variant become
    /// `Unknown`.
    fn from(a: u16) -> Self {
        match a {
            1 => Self::ApplicationId,
            2 => Self::RatchetTree,
            3 => Self::RequiredCapabilities,
            4 => Self::ExternalPub,
            5 => Self::ExternalSenders,
            10 => Self::LastResort,
            other => Self::Unknown(other),
        }
    }
}

impl From<FrankenExtensionType> for u16 {
    /// Encodes an extension type back into its raw code point.
    fn from(value: FrankenExtensionType) -> Self {
        use FrankenExtensionType as Ty;

        match value {
            Ty::ApplicationId => 1,
            Ty::RatchetTree => 2,
            Ty::RequiredCapabilities => 3,
            Ty::ExternalPub => 4,
            Ty::ExternalSenders => 5,
            Ty::LastResort => 10,
            Ty::Unknown(raw) => raw,
        }
    }
}
50
51#[derive(Debug, Clone, PartialEq, Eq)]
52#[repr(u16)]
53pub enum FrankenExtension {
54    ApplicationId(FrankenApplicationIdExtension),
55    RatchetTree(FrankenRatchetTreeExtension),
56    RequiredCapabilities(FrankenRequiredCapabilitiesExtension),
57    ExternalPub(FrankenExternalPubExtension),
58    ExternalSenders(FrankenExternalSendersExtension),
59    LastResort,
60    Unknown(u16, VLBytes),
61}
62
63impl FrankenExtension {
64    pub const fn extension_type(&self) -> FrankenExtensionType {
65        match self {
66            FrankenExtension::ApplicationId(_) => FrankenExtensionType::ApplicationId,
67            FrankenExtension::RatchetTree(_) => FrankenExtensionType::RatchetTree,
68            FrankenExtension::RequiredCapabilities(_) => FrankenExtensionType::RequiredCapabilities,
69            FrankenExtension::ExternalPub(_) => FrankenExtensionType::ExternalPub,
70            FrankenExtension::ExternalSenders(_) => FrankenExtensionType::ExternalSenders,
71            FrankenExtension::LastResort => FrankenExtensionType::LastResort,
72            FrankenExtension::Unknown(kind, _) => FrankenExtensionType::Unknown(*kind),
73        }
74    }
75}
76
77impl From<Extension> for FrankenExtension {
78    fn from(value: Extension) -> Self {
79        let bytes = value.tls_serialize_detached().unwrap();
80        FrankenExtension::tls_deserialize(&mut bytes.as_slice()).unwrap()
81    }
82}
83
84#[derive(
85    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
86)]
87pub struct FrankenApplicationIdExtension {
88    pub key_id: VLBytes,
89}
90
91#[derive(
92    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
93)]
94pub struct FrankenRatchetTreeExtension {
95    pub ratchet_tree: Vec<Option<FrankenNode>>,
96}
97
98#[derive(
99    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
100)]
101#[repr(u8)]
102#[allow(clippy::large_enum_variant)]
103pub enum FrankenNode {
104    #[tls_codec(discriminant = 1)]
105    LeafNode(FrankenLeafNode),
106    #[tls_codec(discriminant = 2)]
107    ParentNode(FrankenParentNode),
108}
109
110impl From<Node> for FrankenNode {
111    fn from(value: Node) -> Self {
112        let bytes = value.tls_serialize_detached().unwrap();
113        FrankenNode::tls_deserialize(&mut bytes.as_slice()).unwrap()
114    }
115}
116
117impl From<NodeIn> for FrankenNode {
118    fn from(value: NodeIn) -> Self {
119        let bytes = value.tls_serialize_detached().unwrap();
120        FrankenNode::tls_deserialize(&mut bytes.as_slice()).unwrap()
121    }
122}
123
124#[derive(
125    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
126)]
127pub struct FrankenParentNode {
128    pub encryption_key: VLBytes,
129    pub parent_hash: VLBytes,
130    pub unmerged_leaves: Vec<u32>,
131}
132
133impl From<ParentNode> for FrankenParentNode {
134    fn from(value: ParentNode) -> Self {
135        let bytes = value.tls_serialize_detached().unwrap();
136        Self::tls_deserialize(&mut bytes.as_slice()).unwrap()
137    }
138}
139
140#[derive(
141    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
142)]
143pub struct FrankenRequiredCapabilitiesExtension {
144    pub extension_types: Vec<u16>,
145    pub proposal_types: Vec<u16>,
146    pub credential_types: Vec<u16>,
147}
148
149impl From<RequiredCapabilitiesExtension> for FrankenRequiredCapabilitiesExtension {
150    fn from(value: RequiredCapabilitiesExtension) -> Self {
151        let bytes = value.tls_serialize_detached().unwrap();
152        Self::tls_deserialize(&mut bytes.as_slice()).unwrap()
153    }
154}
155
156#[derive(
157    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
158)]
159pub struct FrankenExternalPubExtension {
160    external_pub: VLBytes,
161}
162
163#[derive(
164    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
165)]
166pub struct FrankenExternalSendersExtension {
167    external_senders: Vec<FrankenExternalSender>,
168}
169
170#[derive(
171    Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize,
172)]
173pub struct FrankenExternalSender {
174    pub signature_key: VLBytes,
175    pub credential: FrankenCredential,
176}