1use super::*;
31use crate::{
32 extensions::Extensions,
33 key_packages::Lifetime,
34 tree::sender_ratchet::SenderRatchetConfiguration,
35 treesync::{errors::LeafNodeValidationError, node::leaf_node::Capabilities},
36};
37use serde::{Deserialize, Serialize};
38
39#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
44pub struct MlsGroupJoinConfig {
45 pub(crate) wire_format_policy: WireFormatPolicy,
48 pub(crate) padding_size: usize,
50 pub(crate) max_past_epochs: usize,
53 pub(crate) number_of_resumption_psks: usize,
55 pub(crate) use_ratchet_tree_extension: bool,
57 pub(crate) sender_ratchet_configuration: SenderRatchetConfiguration,
59}
60
61impl MlsGroupJoinConfig {
62 pub fn builder() -> MlsGroupJoinConfigBuilder {
64 MlsGroupJoinConfigBuilder::new()
65 }
66
67 pub fn wire_format_policy(&self) -> WireFormatPolicy {
69 self.wire_format_policy
70 }
71
72 pub fn padding_size(&self) -> usize {
74 self.padding_size
75 }
76
77 pub fn sender_ratchet_configuration(&self) -> &SenderRatchetConfiguration {
79 &self.sender_ratchet_configuration
80 }
81}
82
83#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
87pub struct MlsGroupCreateConfig {
88 pub(crate) capabilities: Capabilities,
90 pub(crate) lifetime: Lifetime,
92 pub(crate) ciphersuite: Ciphersuite,
94 pub(crate) join_config: MlsGroupJoinConfig,
96 pub(crate) group_context_extensions: Extensions<GroupContext>,
98 pub(crate) leaf_node_extensions: Extensions<LeafNode>,
100}
101
102impl Default for MlsGroupCreateConfig {
103 fn default() -> Self {
104 Self {
105 capabilities: Capabilities::default(),
106 lifetime: Lifetime::default(),
107 ciphersuite: Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
108 join_config: MlsGroupJoinConfig::default(),
109 group_context_extensions: Extensions::default(),
110 leaf_node_extensions: Extensions::default(),
111 }
112 }
113}
114
115#[derive(Default)]
117pub struct MlsGroupJoinConfigBuilder {
118 join_config: MlsGroupJoinConfig,
119}
120
121impl MlsGroupJoinConfigBuilder {
122 fn new() -> Self {
124 Self {
125 join_config: MlsGroupJoinConfig::default(),
126 }
127 }
128
129 pub fn wire_format_policy(mut self, wire_format_policy: WireFormatPolicy) -> Self {
131 self.join_config.wire_format_policy = wire_format_policy;
132 self
133 }
134
135 pub fn padding_size(mut self, padding_size: usize) -> Self {
137 self.join_config.padding_size = padding_size;
138 self
139 }
140
141 pub fn max_past_epochs(mut self, max_past_epochs: usize) -> Self {
143 self.join_config.max_past_epochs = max_past_epochs;
144 self
145 }
146
147 pub fn number_of_resumption_psks(mut self, number_of_resumption_psks: usize) -> Self {
149 self.join_config.number_of_resumption_psks = number_of_resumption_psks;
150 self
151 }
152
153 pub fn use_ratchet_tree_extension(mut self, use_ratchet_tree_extension: bool) -> Self {
155 self.join_config.use_ratchet_tree_extension = use_ratchet_tree_extension;
156 self
157 }
158
159 pub fn sender_ratchet_configuration(
161 mut self,
162 sender_ratchet_configuration: SenderRatchetConfiguration,
163 ) -> Self {
164 self.join_config.sender_ratchet_configuration = sender_ratchet_configuration;
165 self
166 }
167
168 pub fn build(self) -> MlsGroupJoinConfig {
170 self.join_config
171 }
172}
173
174impl MlsGroupCreateConfig {
175 pub fn builder() -> MlsGroupCreateConfigBuilder {
177 MlsGroupCreateConfigBuilder::new()
178 }
179
180 pub fn wire_format_policy(&self) -> WireFormatPolicy {
182 self.join_config.wire_format_policy
183 }
184
185 pub fn padding_size(&self) -> usize {
187 self.join_config.padding_size
188 }
189
190 pub fn max_past_epochs(&self) -> usize {
192 self.join_config.max_past_epochs
193 }
194
195 pub fn number_of_resumption_psks(&self) -> usize {
197 self.join_config.number_of_resumption_psks
198 }
199
200 pub fn use_ratchet_tree_extension(&self) -> bool {
202 self.join_config.use_ratchet_tree_extension
203 }
204
205 pub fn sender_ratchet_configuration(&self) -> &SenderRatchetConfiguration {
207 &self.join_config.sender_ratchet_configuration
208 }
209
210 pub fn group_context_extensions(&self) -> &Extensions<GroupContext> {
214 &self.group_context_extensions
215 }
216
217 pub fn lifetime(&self) -> &Lifetime {
219 &self.lifetime
220 }
221
222 pub fn ciphersuite(&self) -> Ciphersuite {
224 self.ciphersuite
225 }
226
227 #[cfg(any(feature = "test-utils", test))]
228 pub fn test_default(ciphersuite: Ciphersuite) -> Self {
229 Self::builder()
230 .wire_format_policy(WireFormatPolicy::new(
231 OutgoingWireFormatPolicy::AlwaysPlaintext,
232 IncomingWireFormatPolicy::Mixed,
233 ))
234 .ciphersuite(ciphersuite)
235 .build()
236 }
237
238 pub fn join_config(&self) -> &MlsGroupJoinConfig {
240 &self.join_config
241 }
242}
243
244#[derive(Default, Debug)]
246pub struct MlsGroupCreateConfigBuilder {
247 config: MlsGroupCreateConfig,
248}
249
250impl MlsGroupCreateConfigBuilder {
251 fn new() -> Self {
253 MlsGroupCreateConfigBuilder {
254 config: MlsGroupCreateConfig::default(),
255 }
256 }
257
258 pub fn wire_format_policy(mut self, wire_format_policy: WireFormatPolicy) -> Self {
260 self.config.join_config.wire_format_policy = wire_format_policy;
261 self
262 }
263
264 pub fn padding_size(mut self, padding_size: usize) -> Self {
266 self.config.join_config.padding_size = padding_size;
267 self
268 }
269
270 pub fn max_past_epochs(mut self, max_past_epochs: usize) -> Self {
281 self.config.join_config.max_past_epochs = max_past_epochs;
282 self
283 }
284
285 pub fn number_of_resumption_psks(mut self, number_of_resumption_psks: usize) -> Self {
287 self.config.join_config.number_of_resumption_psks = number_of_resumption_psks;
288 self
289 }
290
291 pub fn use_ratchet_tree_extension(mut self, use_ratchet_tree_extension: bool) -> Self {
293 self.config.join_config.use_ratchet_tree_extension = use_ratchet_tree_extension;
294 self
295 }
296
297 pub fn capabilities(mut self, capabilities: Capabilities) -> Self {
299 self.config.capabilities = capabilities;
300 self
301 }
302
303 pub fn sender_ratchet_configuration(
306 mut self,
307 sender_ratchet_configuration: SenderRatchetConfiguration,
308 ) -> Self {
309 self.config.join_config.sender_ratchet_configuration = sender_ratchet_configuration;
310 self
311 }
312
313 pub fn lifetime(mut self, lifetime: Lifetime) -> Self {
315 self.config.lifetime = lifetime;
316 self
317 }
318
319 pub fn ciphersuite(mut self, ciphersuite: Ciphersuite) -> Self {
321 self.config.ciphersuite = ciphersuite;
322 self
323 }
324
325 pub fn with_group_context_extensions(mut self, extensions: Extensions<GroupContext>) -> Self {
327 self.config.group_context_extensions = extensions;
328 self
329 }
330
331 pub fn with_leaf_node_extensions(
335 mut self,
336 extensions: Extensions<LeafNode>,
337 ) -> Result<Self, LeafNodeValidationError> {
338 if !self.config.capabilities.contains_extensions(&extensions) {
342 return Err(LeafNodeValidationError::ExtensionsNotInCapabilities);
343 }
344
345 self.config.leaf_node_extensions = extensions;
347 Ok(self)
348 }
349
350 pub fn build(self) -> MlsGroupCreateConfig {
352 self.config
353 }
354}
355
356#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
359pub enum IncomingWireFormatPolicy {
360 AlwaysCiphertext,
362 AlwaysPlaintext,
364 Mixed,
366}
367
368impl IncomingWireFormatPolicy {
369 pub(crate) fn is_compatible_with(&self, wire_format: WireFormat) -> bool {
370 match self {
371 IncomingWireFormatPolicy::AlwaysCiphertext => wire_format == WireFormat::PrivateMessage,
372 IncomingWireFormatPolicy::AlwaysPlaintext => wire_format == WireFormat::PublicMessage,
373 IncomingWireFormatPolicy::Mixed => {
374 wire_format == WireFormat::PrivateMessage
375 || wire_format == WireFormat::PublicMessage
376 }
377 }
378 }
379}
380
381#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
384pub enum OutgoingWireFormatPolicy {
385 AlwaysCiphertext,
387 AlwaysPlaintext,
389}
390
391#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
394pub struct WireFormatPolicy {
395 outgoing: OutgoingWireFormatPolicy,
396 incoming: IncomingWireFormatPolicy,
397}
398
399impl WireFormatPolicy {
400 #[cfg(any(feature = "test-utils", test))]
403 pub(crate) fn new(
404 outgoing: OutgoingWireFormatPolicy,
405 incoming: IncomingWireFormatPolicy,
406 ) -> Self {
407 Self { outgoing, incoming }
408 }
409
410 pub fn outgoing(&self) -> OutgoingWireFormatPolicy {
412 self.outgoing
413 }
414
415 pub fn incoming(&self) -> IncomingWireFormatPolicy {
417 self.incoming
418 }
419}
420
421impl Default for WireFormatPolicy {
422 fn default() -> Self {
423 PURE_CIPHERTEXT_WIRE_FORMAT_POLICY
424 }
425}
426
427impl From<OutgoingWireFormatPolicy> for WireFormat {
428 fn from(outgoing: OutgoingWireFormatPolicy) -> Self {
429 match outgoing {
430 OutgoingWireFormatPolicy::AlwaysCiphertext => WireFormat::PrivateMessage,
431 OutgoingWireFormatPolicy::AlwaysPlaintext => WireFormat::PublicMessage,
432 }
433 }
434}
435
436pub const WIRE_FORMAT_POLICIES: [WireFormatPolicy; 4] = [
442 PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
443 PURE_CIPHERTEXT_WIRE_FORMAT_POLICY,
444 MIXED_PLAINTEXT_WIRE_FORMAT_POLICY,
445 MIXED_CIPHERTEXT_WIRE_FORMAT_POLICY,
446];
447
448pub const PURE_PLAINTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
450 outgoing: OutgoingWireFormatPolicy::AlwaysPlaintext,
451 incoming: IncomingWireFormatPolicy::AlwaysPlaintext,
452};
453
454pub const PURE_CIPHERTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
456 outgoing: OutgoingWireFormatPolicy::AlwaysCiphertext,
457 incoming: IncomingWireFormatPolicy::AlwaysCiphertext,
458};
459
460pub const MIXED_PLAINTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
463 outgoing: OutgoingWireFormatPolicy::AlwaysPlaintext,
464 incoming: IncomingWireFormatPolicy::Mixed,
465};
466
467pub const MIXED_CIPHERTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
470 outgoing: OutgoingWireFormatPolicy::AlwaysCiphertext,
471 incoming: IncomingWireFormatPolicy::Mixed,
472};