1use super::*;
31use crate::{
32 extensions::errors::InvalidExtensionError,
33 key_packages::Lifetime,
34 tree::sender_ratchet::SenderRatchetConfiguration,
35 treesync::{errors::LeafNodeValidationError, node::leaf_node::Capabilities},
36};
37use serde::{Deserialize, Serialize};
38
/// Group settings that are independent of group creation; also embedded in
/// [`MlsGroupCreateConfig`] as its `join_config` field.
///
/// Construct via [`MlsGroupJoinConfig::builder`]. The derived `Default` yields the default
/// [`WireFormatPolicy`], `0` for every `usize` field, `false` for
/// `use_ratchet_tree_extension` and the default [`SenderRatchetConfiguration`].
///
/// NOTE(review): the struct derives `Serialize`/`Deserialize`, so field order is part of the
/// serialized representation for order-sensitive formats — do not reorder fields.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct MlsGroupJoinConfig {
    /// Policy for which wire formats are sent and accepted.
    pub(crate) wire_format_policy: WireFormatPolicy,
    /// Padding size setting (presumably in bytes for encrypted messages — confirm).
    pub(crate) padding_size: usize,
    /// Maximum number of past epochs setting (presumably how many past epochs' secrets are
    /// kept for decrypting late messages — confirm).
    pub(crate) max_past_epochs: usize,
    /// Number of resumption PSKs setting (presumably how many are retained — confirm).
    pub(crate) number_of_resumption_psks: usize,
    /// Whether the ratchet tree extension is used (presumably: ship the tree in Welcome
    /// messages instead of out-of-band — confirm).
    pub(crate) use_ratchet_tree_extension: bool,
    /// Configuration handed to the sender ratchet.
    pub(crate) sender_ratchet_configuration: SenderRatchetConfiguration,
}
60
61impl MlsGroupJoinConfig {
62 pub fn builder() -> MlsGroupJoinConfigBuilder {
64 MlsGroupJoinConfigBuilder::new()
65 }
66
67 pub fn wire_format_policy(&self) -> WireFormatPolicy {
69 self.wire_format_policy
70 }
71
72 pub fn padding_size(&self) -> usize {
74 self.padding_size
75 }
76
77 pub fn sender_ratchet_configuration(&self) -> &SenderRatchetConfiguration {
79 &self.sender_ratchet_configuration
80 }
81}
82
/// Configuration used when creating a group: the shared [`MlsGroupJoinConfig`] settings plus
/// capabilities, lifetime, ciphersuite and extension lists.
///
/// `Default` is implemented by hand (not derived) because the ciphersuite needs an explicit
/// default value. Construct via [`MlsGroupCreateConfig::builder`].
///
/// NOTE(review): serde-derived — field order is part of the serialized form for
/// order-sensitive formats; do not reorder.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct MlsGroupCreateConfig {
    /// Capabilities; `with_leaf_node_extensions` checks leaf node extensions against these.
    pub(crate) capabilities: Capabilities,
    /// Lifetime setting (presumably for the creator's own leaf node — confirm).
    pub(crate) lifetime: Lifetime,
    /// Ciphersuite of the group to be created.
    pub(crate) ciphersuite: Ciphersuite,
    /// Settings shared with [`MlsGroupJoinConfig`]; the join-related getters delegate here.
    pub(crate) join_config: MlsGroupJoinConfig,
    /// Group context extensions; the builder rejects application id, ratchet tree and
    /// external pub extensions here.
    pub(crate) group_context_extensions: Extensions,
    /// Leaf node extensions; the builder checks their types and that they are covered by
    /// `capabilities`.
    pub(crate) leaf_node_extensions: Extensions,
}
101
102impl Default for MlsGroupCreateConfig {
103 fn default() -> Self {
104 Self {
105 capabilities: Capabilities::default(),
106 lifetime: Lifetime::default(),
107 ciphersuite: Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
108 join_config: MlsGroupJoinConfig::default(),
109 group_context_extensions: Extensions::default(),
110 leaf_node_extensions: Extensions::default(),
111 }
112 }
113}
114
115#[derive(Default)]
117pub struct MlsGroupJoinConfigBuilder {
118 join_config: MlsGroupJoinConfig,
119}
120
121impl MlsGroupJoinConfigBuilder {
122 fn new() -> Self {
124 Self {
125 join_config: MlsGroupJoinConfig::default(),
126 }
127 }
128
129 pub fn wire_format_policy(mut self, wire_format_policy: WireFormatPolicy) -> Self {
131 self.join_config.wire_format_policy = wire_format_policy;
132 self
133 }
134
135 pub fn padding_size(mut self, padding_size: usize) -> Self {
137 self.join_config.padding_size = padding_size;
138 self
139 }
140
141 pub fn max_past_epochs(mut self, max_past_epochs: usize) -> Self {
143 self.join_config.max_past_epochs = max_past_epochs;
144 self
145 }
146
147 pub fn number_of_resumption_psks(mut self, number_of_resumption_psks: usize) -> Self {
149 self.join_config.number_of_resumption_psks = number_of_resumption_psks;
150 self
151 }
152
153 pub fn use_ratchet_tree_extension(mut self, use_ratchet_tree_extension: bool) -> Self {
155 self.join_config.use_ratchet_tree_extension = use_ratchet_tree_extension;
156 self
157 }
158
159 pub fn sender_ratchet_configuration(
161 mut self,
162 sender_ratchet_configuration: SenderRatchetConfiguration,
163 ) -> Self {
164 self.join_config.sender_ratchet_configuration = sender_ratchet_configuration;
165 self
166 }
167
168 pub fn build(self) -> MlsGroupJoinConfig {
170 self.join_config
171 }
172}
173
174impl MlsGroupCreateConfig {
175 pub fn builder() -> MlsGroupCreateConfigBuilder {
177 MlsGroupCreateConfigBuilder::new()
178 }
179
180 pub fn wire_format_policy(&self) -> WireFormatPolicy {
182 self.join_config.wire_format_policy
183 }
184
185 pub fn padding_size(&self) -> usize {
187 self.join_config.padding_size
188 }
189
190 pub fn max_past_epochs(&self) -> usize {
192 self.join_config.max_past_epochs
193 }
194
195 pub fn number_of_resumption_psks(&self) -> usize {
197 self.join_config.number_of_resumption_psks
198 }
199
200 pub fn use_ratchet_tree_extension(&self) -> bool {
202 self.join_config.use_ratchet_tree_extension
203 }
204
205 pub fn sender_ratchet_configuration(&self) -> &SenderRatchetConfiguration {
207 &self.join_config.sender_ratchet_configuration
208 }
209
210 pub fn group_context_extensions(&self) -> &Extensions {
214 &self.group_context_extensions
215 }
216
217 pub fn lifetime(&self) -> &Lifetime {
219 &self.lifetime
220 }
221
222 pub fn ciphersuite(&self) -> Ciphersuite {
224 self.ciphersuite
225 }
226
227 #[cfg(any(feature = "test-utils", test))]
228 pub fn test_default(ciphersuite: Ciphersuite) -> Self {
229 Self::builder()
230 .wire_format_policy(WireFormatPolicy::new(
231 OutgoingWireFormatPolicy::AlwaysPlaintext,
232 IncomingWireFormatPolicy::Mixed,
233 ))
234 .ciphersuite(ciphersuite)
235 .build()
236 }
237
238 pub fn join_config(&self) -> &MlsGroupJoinConfig {
240 &self.join_config
241 }
242}
243
/// Builder for [`MlsGroupCreateConfig`].
///
/// Starts from [`MlsGroupCreateConfig::default`]. Plain setters return the builder; the
/// extension setters validate their input and return a `Result`.
#[derive(Default, Debug)]
pub struct MlsGroupCreateConfigBuilder {
    // The configuration under construction; handed out unchanged by `build`.
    config: MlsGroupCreateConfig,
}
249
250impl MlsGroupCreateConfigBuilder {
251 fn new() -> Self {
253 MlsGroupCreateConfigBuilder {
254 config: MlsGroupCreateConfig::default(),
255 }
256 }
257
258 pub fn wire_format_policy(mut self, wire_format_policy: WireFormatPolicy) -> Self {
260 self.config.join_config.wire_format_policy = wire_format_policy;
261 self
262 }
263
264 pub fn padding_size(mut self, padding_size: usize) -> Self {
266 self.config.join_config.padding_size = padding_size;
267 self
268 }
269
270 pub fn max_past_epochs(mut self, max_past_epochs: usize) -> Self {
281 self.config.join_config.max_past_epochs = max_past_epochs;
282 self
283 }
284
285 pub fn number_of_resumption_psks(mut self, number_of_resumption_psks: usize) -> Self {
287 self.config.join_config.number_of_resumption_psks = number_of_resumption_psks;
288 self
289 }
290
291 pub fn use_ratchet_tree_extension(mut self, use_ratchet_tree_extension: bool) -> Self {
293 self.config.join_config.use_ratchet_tree_extension = use_ratchet_tree_extension;
294 self
295 }
296
297 pub fn capabilities(mut self, capabilities: Capabilities) -> Self {
299 self.config.capabilities = capabilities;
300 self
301 }
302
303 pub fn sender_ratchet_configuration(
306 mut self,
307 sender_ratchet_configuration: SenderRatchetConfiguration,
308 ) -> Self {
309 self.config.join_config.sender_ratchet_configuration = sender_ratchet_configuration;
310 self
311 }
312
313 pub fn lifetime(mut self, lifetime: Lifetime) -> Self {
315 self.config.lifetime = lifetime;
316 self
317 }
318
319 pub fn ciphersuite(mut self, ciphersuite: Ciphersuite) -> Self {
321 self.config.ciphersuite = ciphersuite;
322 self
323 }
324
325 pub fn with_group_context_extensions(
327 mut self,
328 extensions: Extensions,
329 ) -> Result<Self, InvalidExtensionError> {
330 let is_valid_in_group_context = extensions.application_id().is_none()
331 && extensions.ratchet_tree().is_none()
332 && extensions.external_pub().is_none();
333 if !is_valid_in_group_context {
334 return Err(InvalidExtensionError::IllegalInGroupContext);
335 }
336 self.config.group_context_extensions = extensions;
337 Ok(self)
338 }
339
340 pub fn with_leaf_node_extensions(
344 mut self,
345 extensions: Extensions,
346 ) -> Result<Self, LeafNodeValidationError> {
347 if extensions.validate_extension_types_for_leaf_node().is_err() {
350 log::error!("Invalid leaf node extension.");
351 return Err(LeafNodeValidationError::UnsupportedExtensions);
352 }
353
354 if !self.config.capabilities.contains_extensions(&extensions) {
358 return Err(LeafNodeValidationError::ExtensionsNotInCapabilities);
359 }
360
361 self.config.leaf_node_extensions = extensions;
362 Ok(self)
363 }
364
365 pub fn build(self) -> MlsGroupCreateConfig {
367 self.config
368 }
369}
370
/// Which wire formats are accepted for incoming messages (checked by `is_compatible_with`).
///
/// NOTE(review): serde-derived — variant order determines the variant index used by
/// index-based serde formats; do not reorder.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum IncomingWireFormatPolicy {
    /// Only [`WireFormat::PrivateMessage`] is accepted.
    AlwaysCiphertext,
    /// Only [`WireFormat::PublicMessage`] is accepted.
    AlwaysPlaintext,
    /// Both [`WireFormat::PrivateMessage`] and [`WireFormat::PublicMessage`] are accepted.
    Mixed,
}
382
383impl IncomingWireFormatPolicy {
384 pub(crate) fn is_compatible_with(&self, wire_format: WireFormat) -> bool {
385 match self {
386 IncomingWireFormatPolicy::AlwaysCiphertext => wire_format == WireFormat::PrivateMessage,
387 IncomingWireFormatPolicy::AlwaysPlaintext => wire_format == WireFormat::PublicMessage,
388 IncomingWireFormatPolicy::Mixed => {
389 wire_format == WireFormat::PrivateMessage
390 || wire_format == WireFormat::PublicMessage
391 }
392 }
393 }
394}
395
/// Wire format used for outgoing messages; converts into [`WireFormat`] via `From`.
///
/// NOTE(review): serde-derived — do not reorder variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum OutgoingWireFormatPolicy {
    /// Maps to [`WireFormat::PrivateMessage`].
    AlwaysCiphertext,
    /// Maps to [`WireFormat::PublicMessage`].
    AlwaysPlaintext,
}
405
/// Pair of an outgoing and an incoming wire format policy.
///
/// Fields are private and `WireFormatPolicy::new` is only compiled for tests / `test-utils`,
/// so outside this module values come from the predefined `*_WIRE_FORMAT_POLICY` constants
/// or `Default` (which is `PURE_CIPHERTEXT_WIRE_FORMAT_POLICY`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct WireFormatPolicy {
    // Wire format chosen for messages that are sent.
    outgoing: OutgoingWireFormatPolicy,
    // Wire formats accepted for messages that are received.
    incoming: IncomingWireFormatPolicy,
}
413
414impl WireFormatPolicy {
415 #[cfg(any(feature = "test-utils", test))]
418 pub(crate) fn new(
419 outgoing: OutgoingWireFormatPolicy,
420 incoming: IncomingWireFormatPolicy,
421 ) -> Self {
422 Self { outgoing, incoming }
423 }
424
425 pub fn outgoing(&self) -> OutgoingWireFormatPolicy {
427 self.outgoing
428 }
429
430 pub fn incoming(&self) -> IncomingWireFormatPolicy {
432 self.incoming
433 }
434}
435
436impl Default for WireFormatPolicy {
437 fn default() -> Self {
438 PURE_CIPHERTEXT_WIRE_FORMAT_POLICY
439 }
440}
441
442impl From<OutgoingWireFormatPolicy> for WireFormat {
443 fn from(outgoing: OutgoingWireFormatPolicy) -> Self {
444 match outgoing {
445 OutgoingWireFormatPolicy::AlwaysCiphertext => WireFormat::PrivateMessage,
446 OutgoingWireFormatPolicy::AlwaysPlaintext => WireFormat::PublicMessage,
447 }
448 }
449}
450
/// All four predefined wire format policies, in the order pure plaintext, pure ciphertext,
/// mixed plaintext, mixed ciphertext (presumably for iterating over every policy, e.g. in
/// tests — confirm usage).
pub const WIRE_FORMAT_POLICIES: [WireFormatPolicy; 4] = [
    PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
    PURE_CIPHERTEXT_WIRE_FORMAT_POLICY,
    MIXED_PLAINTEXT_WIRE_FORMAT_POLICY,
    MIXED_CIPHERTEXT_WIRE_FORMAT_POLICY,
];
462
/// Send `PublicMessage`s and accept only `PublicMessage`s.
pub const PURE_PLAINTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
    outgoing: OutgoingWireFormatPolicy::AlwaysPlaintext,
    incoming: IncomingWireFormatPolicy::AlwaysPlaintext,
};

/// Send `PrivateMessage`s and accept only `PrivateMessage`s. This is the `Default` policy.
pub const PURE_CIPHERTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
    outgoing: OutgoingWireFormatPolicy::AlwaysCiphertext,
    incoming: IncomingWireFormatPolicy::AlwaysCiphertext,
};

/// Send `PublicMessage`s and accept both `PublicMessage`s and `PrivateMessage`s.
pub const MIXED_PLAINTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
    outgoing: OutgoingWireFormatPolicy::AlwaysPlaintext,
    incoming: IncomingWireFormatPolicy::Mixed,
};

/// Send `PrivateMessage`s and accept both `PublicMessage`s and `PrivateMessage`s.
pub const MIXED_CIPHERTEXT_WIRE_FORMAT_POLICY: WireFormatPolicy = WireFormatPolicy {
    outgoing: OutgoingWireFormatPolicy::AlwaysCiphertext,
    incoming: IncomingWireFormatPolicy::Mixed,
};