1use super::*;
31use crate::{
32 extensions::errors::InvalidExtensionError,
33 key_packages::Lifetime,
34 tree::sender_ratchet::SenderRatchetConfiguration,
35 treesync::{errors::LeafNodeValidationError, node::leaf_node::Capabilities},
36};
37use serde::{Deserialize, Serialize};
38
/// Group configuration that is also embedded in [`MlsGroupCreateConfig`]
/// (as its `join_config` field).
///
/// Build one with [`MlsGroupJoinConfig::builder`]. The derived `Default`
/// gives every field its type's default value.
// NOTE: field order is part of the serde representation; do not reorder.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct MlsGroupJoinConfig {
    /// Which wire formats are used for outgoing messages and accepted for
    /// incoming ones. Defaults to `PURE_CIPHERTEXT_WIRE_FORMAT_POLICY`.
    pub(crate) wire_format_policy: WireFormatPolicy,
    /// Padding size; defaults to 0.
    /// NOTE(review): presumably the block size that encrypted messages are
    /// padded to — confirm at the use site.
    pub(crate) padding_size: usize,
    /// Maximum number of past epochs; defaults to 0.
    /// NOTE(review): presumably how many past epochs' secrets are retained to
    /// process late messages — confirm at the use site.
    pub(crate) max_past_epochs: usize,
    /// Number of resumption PSKs; defaults to 0.
    /// NOTE(review): presumably how many resumption PSKs are kept — confirm.
    pub(crate) number_of_resumption_psks: usize,
    /// Whether the ratchet tree extension is used; defaults to `false`.
    /// NOTE(review): presumably controls whether the ratchet tree is shipped
    /// as an extension to new members — confirm at the use site.
    pub(crate) use_ratchet_tree_extension: bool,
    /// Sender ratchet settings; defaults to `SenderRatchetConfiguration::default()`.
    pub(crate) sender_ratchet_configuration: SenderRatchetConfiguration,
}
60
61impl MlsGroupJoinConfig {
62 pub fn builder() -> MlsGroupJoinConfigBuilder {
64 MlsGroupJoinConfigBuilder::new()
65 }
66
67 pub fn wire_format_policy(&self) -> WireFormatPolicy {
69 self.wire_format_policy
70 }
71
72 pub fn padding_size(&self) -> usize {
74 self.padding_size
75 }
76
77 pub fn sender_ratchet_configuration(&self) -> &SenderRatchetConfiguration {
79 &self.sender_ratchet_configuration
80 }
81}
82
/// Configuration used when creating a group.
///
/// Wraps an [`MlsGroupJoinConfig`] and adds capabilities, lifetime,
/// ciphersuite and extension settings. Build one with
/// [`MlsGroupCreateConfig::builder`].
// NOTE: field order is part of the serde representation; do not reorder.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct MlsGroupCreateConfig {
    /// Capabilities; `with_leaf_node_extensions` checks extensions against
    /// these. NOTE(review): presumably advertised in the creator's leaf node.
    pub(crate) capabilities: Capabilities,
    /// Lifetime. NOTE(review): presumably applied to the creator's leaf node —
    /// confirm at the use site.
    pub(crate) lifetime: Lifetime,
    /// Ciphersuite; defaults to `MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519`.
    pub(crate) ciphersuite: Ciphersuite,
    /// Settings shared with [`MlsGroupJoinConfig`].
    pub(crate) join_config: MlsGroupJoinConfig,
    /// Group context extensions. The builder rejects application-id,
    /// ratchet-tree and external-pub extensions here.
    pub(crate) group_context_extensions: Extensions,
    /// Leaf node extensions. The builder only accepts `ExtensionType::Unknown`
    /// extensions that are contained in `capabilities`.
    pub(crate) leaf_node_extensions: Extensions,
}
101
102impl Default for MlsGroupCreateConfig {
103 fn default() -> Self {
104 Self {
105 capabilities: Capabilities::default(),
106 lifetime: Lifetime::default(),
107 ciphersuite: Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
108 join_config: MlsGroupJoinConfig::default(),
109 group_context_extensions: Extensions::default(),
110 leaf_node_extensions: Extensions::default(),
111 }
112 }
113}
114
115#[derive(Default)]
117pub struct MlsGroupJoinConfigBuilder {
118 join_config: MlsGroupJoinConfig,
119}
120
121impl MlsGroupJoinConfigBuilder {
122 fn new() -> Self {
124 Self {
125 join_config: MlsGroupJoinConfig::default(),
126 }
127 }
128
129 pub fn wire_format_policy(mut self, wire_format_policy: WireFormatPolicy) -> Self {
131 self.join_config.wire_format_policy = wire_format_policy;
132 self
133 }
134
135 pub fn padding_size(mut self, padding_size: usize) -> Self {
137 self.join_config.padding_size = padding_size;
138 self
139 }
140
141 pub fn max_past_epochs(mut self, max_past_epochs: usize) -> Self {
143 self.join_config.max_past_epochs = max_past_epochs;
144 self
145 }
146
147 pub fn number_of_resumption_psks(mut self, number_of_resumption_psks: usize) -> Self {
149 self.join_config.number_of_resumption_psks = number_of_resumption_psks;
150 self
151 }
152
153 pub fn use_ratchet_tree_extension(mut self, use_ratchet_tree_extension: bool) -> Self {
155 self.join_config.use_ratchet_tree_extension = use_ratchet_tree_extension;
156 self
157 }
158
159 pub fn sender_ratchet_configuration(
161 mut self,
162 sender_ratchet_configuration: SenderRatchetConfiguration,
163 ) -> Self {
164 self.join_config.sender_ratchet_configuration = sender_ratchet_configuration;
165 self
166 }
167
168 pub fn build(self) -> MlsGroupJoinConfig {
170 self.join_config
171 }
172}
173
174impl MlsGroupCreateConfig {
175 pub fn builder() -> MlsGroupCreateConfigBuilder {
177 MlsGroupCreateConfigBuilder::new()
178 }
179
180 pub fn wire_format_policy(&self) -> WireFormatPolicy {
182 self.join_config.wire_format_policy
183 }
184
185 pub fn padding_size(&self) -> usize {
187 self.join_config.padding_size
188 }
189
190 pub fn max_past_epochs(&self) -> usize {
192 self.join_config.max_past_epochs
193 }
194
195 pub fn number_of_resumption_psks(&self) -> usize {
197 self.join_config.number_of_resumption_psks
198 }
199
200 pub fn use_ratchet_tree_extension(&self) -> bool {
202 self.join_config.use_ratchet_tree_extension
203 }
204
205 pub fn sender_ratchet_configuration(&self) -> &SenderRatchetConfiguration {
207 &self.join_config.sender_ratchet_configuration
208 }
209
210 pub fn group_context_extensions(&self) -> &Extensions {
214 &self.group_context_extensions
215 }
216
217 pub fn lifetime(&self) -> &Lifetime {
219 &self.lifetime
220 }
221
222 pub fn ciphersuite(&self) -> Ciphersuite {
224 self.ciphersuite
225 }
226
227 #[cfg(any(feature = "test-utils", test))]
228 pub fn test_default(ciphersuite: Ciphersuite) -> Self {
229 Self::builder()
230 .wire_format_policy(WireFormatPolicy::new(
231 OutgoingWireFormatPolicy::AlwaysPlaintext,
232 IncomingWireFormatPolicy::Mixed,
233 ))
234 .ciphersuite(ciphersuite)
235 .build()
236 }
237
238 pub fn join_config(&self) -> &MlsGroupJoinConfig {
240 &self.join_config
241 }
242}
243
244#[derive(Default, Debug)]
246pub struct MlsGroupCreateConfigBuilder {
247 config: MlsGroupCreateConfig,
248}
249
250impl MlsGroupCreateConfigBuilder {
251 fn new() -> Self {
253 MlsGroupCreateConfigBuilder {
254 config: MlsGroupCreateConfig::default(),
255 }
256 }
257
258 pub fn wire_format_policy(mut self, wire_format_policy: WireFormatPolicy) -> Self {
260 self.config.join_config.wire_format_policy = wire_format_policy;
261 self
262 }
263
264 pub fn padding_size(mut self, padding_size: usize) -> Self {
266 self.config.join_config.padding_size = padding_size;
267 self
268 }
269
270 pub fn max_past_epochs(mut self, max_past_epochs: usize) -> Self {
281 self.config.join_config.max_past_epochs = max_past_epochs;
282 self
283 }
284
285 pub fn number_of_resumption_psks(mut self, number_of_resumption_psks: usize) -> Self {
287 self.config.join_config.number_of_resumption_psks = number_of_resumption_psks;
288 self
289 }
290
291 pub fn use_ratchet_tree_extension(mut self, use_ratchet_tree_extension: bool) -> Self {
293 self.config.join_config.use_ratchet_tree_extension = use_ratchet_tree_extension;
294 self
295 }
296
297 pub fn capabilities(mut self, capabilities: Capabilities) -> Self {
299 self.config.capabilities = capabilities;
300 self
301 }
302
303 pub fn sender_ratchet_configuration(
306 mut self,
307 sender_ratchet_configuration: SenderRatchetConfiguration,
308 ) -> Self {
309 self.config.join_config.sender_ratchet_configuration = sender_ratchet_configuration;
310 self
311 }
312
313 pub fn lifetime(mut self, lifetime: Lifetime) -> Self {
315 self.config.lifetime = lifetime;
316 self
317 }
318
319 pub fn ciphersuite(mut self, ciphersuite: Ciphersuite) -> Self {
321 self.config.ciphersuite = ciphersuite;
322 self
323 }
324
325 pub fn with_group_context_extensions(
327 mut self,
328 extensions: Extensions,
329 ) -> Result<Self, InvalidExtensionError> {
330 let is_valid_in_group_context = extensions.application_id().is_none()
331 && extensions.ratchet_tree().is_none()
332 && extensions.external_pub().is_none();
333 if !is_valid_in_group_context {
334 return Err(InvalidExtensionError::IllegalInGroupContext);
335 }
336 self.config.group_context_extensions = extensions;
337 Ok(self)
338 }
339
340 pub fn with_leaf_node_extensions(
342 mut self,
343 extensions: Extensions,
344 ) -> Result<Self, LeafNodeValidationError> {
345 let is_valid_in_leaf_node = extensions
348 .iter()
349 .all(|e| matches!(e.extension_type(), ExtensionType::Unknown(_)));
350 if !is_valid_in_leaf_node {
351 log::error!("Leaf node extensions must be unknown extensions.");
352 return Err(LeafNodeValidationError::UnsupportedExtensions);
353 }
354
355 if !self.config.capabilities.contains_extensions(&extensions) {
359 return Err(LeafNodeValidationError::ExtensionsNotInCapabilities);
360 }
361
362 self.config.leaf_node_extensions = extensions;
363 Ok(self)
364 }
365
366 pub fn build(self) -> MlsGroupCreateConfig {
368 self.config
369 }
370}
371
/// Which wire formats are accepted for incoming messages.
// NOTE: variant order may matter for non-self-describing serde formats; do not reorder.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum IncomingWireFormatPolicy {
    /// Accept only `WireFormat::PrivateMessage`.
    AlwaysCiphertext,
    /// Accept only `WireFormat::PublicMessage`.
    AlwaysPlaintext,
    /// Accept both `WireFormat::PrivateMessage` and `WireFormat::PublicMessage`.
    Mixed,
}
383
384impl IncomingWireFormatPolicy {
385 pub(crate) fn is_compatible_with(&self, wire_format: WireFormat) -> bool {
386 match self {
387 IncomingWireFormatPolicy::AlwaysCiphertext => wire_format == WireFormat::PrivateMessage,
388 IncomingWireFormatPolicy::AlwaysPlaintext => wire_format == WireFormat::PublicMessage,
389 IncomingWireFormatPolicy::Mixed => {
390 wire_format == WireFormat::PrivateMessage
391 || wire_format == WireFormat::PublicMessage
392 }
393 }
394 }
395}
396
/// Which wire format is used for outgoing messages.
// NOTE: variant order may matter for non-self-describing serde formats; do not reorder.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum OutgoingWireFormatPolicy {
    /// Send as `WireFormat::PrivateMessage`.
    AlwaysCiphertext,
    /// Send as `WireFormat::PublicMessage`.
    AlwaysPlaintext,
}
406
/// Pairs an [`OutgoingWireFormatPolicy`] with an [`IncomingWireFormatPolicy`].
///
/// Fields are private. Use the predefined `*_WIRE_FORMAT_POLICY` constants, or
/// `Default` (pure ciphertext).
// NOTE: field order is part of the serde representation; do not reorder.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct WireFormatPolicy {
    /// Policy for messages this member sends.
    outgoing: OutgoingWireFormatPolicy,
    /// Policy for messages this member receives.
    incoming: IncomingWireFormatPolicy,
}
414
415impl WireFormatPolicy {
416 #[cfg(any(feature = "test-utils", test))]
419 pub(crate) fn new(
420 outgoing: OutgoingWireFormatPolicy,
421 incoming: IncomingWireFormatPolicy,
422 ) -> Self {
423 Self { outgoing, incoming }
424 }
425
426 pub fn outgoing(&self) -> OutgoingWireFormatPolicy {
428 self.outgoing
429 }
430
431 pub fn incoming(&self) -> IncomingWireFormatPolicy {
433 self.incoming
434 }
435}
436
437impl Default for WireFormatPolicy {
438 fn default() -> Self {
439 PURE_CIPHERTEXT_WIRE_FORMAT_POLICY
440 }
441}
442
443impl From<OutgoingWireFormatPolicy> for WireFormat {
444 fn from(outgoing: OutgoingWireFormatPolicy) -> Self {
445 match outgoing {
446 OutgoingWireFormatPolicy::AlwaysCiphertext => WireFormat::PrivateMessage,
447 OutgoingWireFormatPolicy::AlwaysPlaintext => WireFormat::PublicMessage,
448 }
449 }
450}
451
452pub const WIRE_FORMAT_POLICIES: [WireFormatPolicy; 4] = [
458 PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
459 PURE_CIPHERTEXT_WIRE_FORMAT_POLICY,
460 MIXED_PLAINTEXT_WIRE_FORMAT_POLICY,
461 MIXED_CIPHERTEXT_WIRE_FORMAT_POLICY,
462];
463
464pub const PURE_PLAINTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
466 outgoing: OutgoingWireFormatPolicy::AlwaysPlaintext,
467 incoming: IncomingWireFormatPolicy::AlwaysPlaintext,
468};
469
470pub const PURE_CIPHERTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
472 outgoing: OutgoingWireFormatPolicy::AlwaysCiphertext,
473 incoming: IncomingWireFormatPolicy::AlwaysCiphertext,
474};
475
476pub const MIXED_PLAINTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
479 outgoing: OutgoingWireFormatPolicy::AlwaysPlaintext,
480 incoming: IncomingWireFormatPolicy::Mixed,
481};
482
483pub const MIXED_CIPHERTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
486 outgoing: OutgoingWireFormatPolicy::AlwaysCiphertext,
487 incoming: IncomingWireFormatPolicy::Mixed,
488};