Revision control
Copy as Markdown
Other Tools
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use core::ops::{Deref, DerefMut};
use alloc::format;
use rand::RngCore;
use super::*;
use crate::{
client::{
test_utils::{
test_client_with_key_pkg, test_client_with_key_pkg_custom, TEST_CIPHER_SUITE,
TEST_PROTOCOL_VERSION,
},
MlsError,
},
client_builder::test_utils::{TestClientBuilder, TestClientConfig},
crypto::test_utils::test_cipher_suite_provider,
extension::ExtensionType,
identity::basic::BasicIdentityProvider,
identity::test_utils::get_test_signing_identity,
key_package::{KeyPackageGeneration, KeyPackageGenerator},
mls_rules::{CommitOptions, DefaultMlsRules},
tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
};
use crate::extension::RequiredCapabilitiesExt;
#[cfg(not(feature = "by_ref_proposal"))]
use crate::crypto::HpkePublicKey;
/// Group identifier used by the test groups and group contexts built in this module.
pub const TEST_GROUP: &[u8] = b"group";
/// Thin test wrapper around a `Group` that uses `TestClientConfig`.
#[derive(Clone)]
pub(crate) struct TestGroup {
/// The wrapped group, exposed so tests can access it directly.
pub group: Group<TestClientConfig>,
}
/// Helpers that let tests drive the wrapped group: sending proposals, adding
/// members, and processing commits and incoming messages.
impl TestGroup {
    /// Builds a proposal message carrying `proposal`, passing an empty vector as
    /// the second argument of `proposal_message`.
    ///
    /// Panics if the message cannot be produced.
    #[cfg(feature = "external_client")]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn propose(&mut self, proposal: Proposal) -> MlsMessage {
        let message = self.group.proposal_message(proposal, vec![]).await;
        message.unwrap()
    }

    /// Creates an update proposal for this member with both optional arguments
    /// left as `None`.
    ///
    /// Panics if the proposal cannot be created.
    #[cfg(feature = "by_ref_proposal")]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn update_proposal(&mut self) -> Proposal {
        let proposal = self.group.update_proposal(None, None).await;
        proposal.unwrap()
    }

    /// Adds a new client called `name` to this group and joins the group from
    /// that client's side.
    ///
    /// When `custom_kp` is true the client and its key package come from
    /// `test_client_with_key_pkg_custom`, which also receives `config`;
    /// otherwise `test_client_with_key_pkg` is used. In both cases `config` is
    /// applied to the new client's configuration right before joining.
    ///
    /// Returns the new member's group together with the commit message that
    /// added it. A failure while joining is returned as `Err`; failures while
    /// building or applying the commit panic.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn join_with_custom_config<F>(
        &mut self,
        name: &str,
        custom_kp: bool,
        mut config: F,
    ) -> Result<(TestGroup, MlsMessage), MlsError>
    where
        F: FnMut(&mut TestClientConfig),
    {
        let protocol_version = self.group.protocol_version();
        let cipher_suite = self.group.cipher_suite();

        let (mut joiner, joiner_key_package) = if custom_kp {
            test_client_with_key_pkg_custom(protocol_version, cipher_suite, name, &mut config)
                .await
        } else {
            test_client_with_key_pkg(protocol_version, cipher_suite, name).await
        };

        // Commit the addition of the new member from the existing group.
        let CommitOutput {
            welcome_messages: welcomes,
            ratchet_tree: exported_tree,
            commit_message: commit,
            ..
        } = self
            .group
            .commit_builder()
            .add_member(joiner_key_package)
            .unwrap()
            .build()
            .await
            .unwrap();

        // Advance the existing group to the epoch created by that commit.
        self.group.apply_pending_commit().await.unwrap();

        config(&mut joiner.config);

        // Build the group as seen by the new member from the first welcome.
        let welcome = &welcomes[0];
        let joiner_config = joiner.config.clone();
        let joiner_signer = joiner.signer.clone().unwrap();

        let (joined_group, _) =
            Group::join(welcome, exported_tree, joiner_config, joiner_signer).await?;

        Ok((TestGroup { group: joined_group }, commit))
    }

    /// Adds a new client called `name` using the default key package and a
    /// no-op configuration hook.
    ///
    /// Panics if `join_with_custom_config` returns an error.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn join(&mut self, name: &str) -> (TestGroup, MlsMessage) {
        let joined = self.join_with_custom_config(name, false, |_| ()).await;
        joined.unwrap()
    }

    /// Applies this group's pending commit and returns its description.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn process_pending_commit(
        &mut self,
    ) -> Result<CommitMessageDescription, MlsError> {
        let applied = self.group.apply_pending_commit().await;
        applied
    }

    /// Feeds `message` to the wrapped group as an incoming message.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn process_message(
        &mut self,
        message: MlsMessage,
    ) -> Result<ReceivedMessage, MlsError> {
        let received = self.group.process_incoming_message(message).await;
        received
    }

    /// Signs `content` as this member with the `PublicMessage` wire format and
    /// an empty final argument, then formats the result for the wire.
    ///
    /// Panics if signing or formatting fails.
    #[cfg(feature = "private_message")]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn make_plaintext(&mut self, content: Content) -> MlsMessage {
        let sender = Sender::Member(*self.group.private_tree.self_index);

        let signed_content = AuthenticatedContent::new_signed(
            &self.group.cipher_suite_provider,
            &self.group.state.context,
            sender,
            content,
            &self.group.signer,
            WireFormat::PublicMessage,
            Vec::new(),
        )
        .await
        .unwrap();

        self.group.format_for_wire(signed_content).await.unwrap()
    }
}
/// Builds a `GroupContext` for tests using the fixed `TEST_GROUP` id.
///
/// The tree hash and confirmed transcript hash are placeholder digests of the
/// byte strings `[1, 2, 3]` and `[3, 2, 1]`, computed with the test cipher
/// suite provider for `cipher_suite`. The extension list is empty.
///
/// Panics if either hash computation fails.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn get_test_group_context(epoch: u64, cipher_suite: CipherSuite) -> GroupContext {
    let provider = test_cipher_suite_provider(cipher_suite);
    let tree_hash = provider.hash(&[1, 2, 3]).await.unwrap();
    let transcript_digest = provider.hash(&[3, 2, 1]).await.unwrap();

    GroupContext {
        protocol_version: TEST_PROTOCOL_VERSION,
        cipher_suite,
        group_id: TEST_GROUP.to_vec(),
        epoch,
        tree_hash,
        confirmed_transcript_hash: transcript_digest.into(),
        extensions: ExtensionList::from(vec![]),
    }
}
/// Builds a `GroupContext` for tests with a caller-chosen `group_id`.
///
/// The tree hash, confirmed transcript hash and extension list are all empty.
#[cfg(feature = "prior_epoch")]
pub(crate) fn get_test_group_context_with_id(
    group_id: Vec<u8>,
    epoch: u64,
    cipher_suite: CipherSuite,
) -> GroupContext {
    let empty_transcript_hash = ConfirmedTranscriptHash::from(Vec::new());
    let no_extensions = ExtensionList::from(Vec::new());

    GroupContext {
        protocol_version: TEST_PROTOCOL_VERSION,
        cipher_suite,
        group_id,
        epoch,
        tree_hash: Vec::new(),
        confirmed_transcript_hash: empty_transcript_hash,
        extensions: no_extensions,
    }
}
/// Returns an extension list containing only a default
/// `RequiredCapabilitiesExt`.
///
/// Panics if the extension cannot be stored in the list.
pub(crate) fn group_extensions() -> ExtensionList {
    let mut list = ExtensionList::new();
    list.set_from(RequiredCapabilitiesExt::default()).unwrap();
    list
}
/// Returns the `Lifetime` produced by `Lifetime::years(1)`.
///
/// Panics if that constructor fails.
pub(crate) fn lifetime() -> Lifetime {
    let one_year = Lifetime::years(1);
    one_year.unwrap()
}
/// Generates a key package for a fresh test identity named `identifier`.
///
/// The key package uses the `lifetime()` helper, the test capabilities and two
/// default (empty) extension lists. Returns the generated key package together
/// with the signing key of the identity.
///
/// Panics if key package generation fails.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn test_member(
    protocol_version: ProtocolVersion,
    cipher_suite: CipherSuite,
    identifier: &[u8],
) -> (KeyPackageGeneration, SignatureSecretKey) {
    let (identity, secret_key) = get_test_signing_identity(cipher_suite, identifier).await;
    let provider = test_cipher_suite_provider(cipher_suite);

    let generator = KeyPackageGenerator {
        protocol_version,
        cipher_suite_provider: &provider,
        signing_identity: &identity,
        signing_key: &secret_key,
        identity_provider: &BasicIdentityProvider,
    };

    let generation = generator
        .generate(
            lifetime(),
            get_test_capabilities(),
            ExtensionList::default(),
            ExtensionList::default(),
        )
        .await
        .unwrap();

    (generation, secret_key)
}
/// Creates a single-member test group with id `TEST_GROUP` and the extensions
/// returned by `group_extensions()`.
///
/// * `protocol_version` - version the client uses to create the group.
/// * `cipher_suite` - cipher suite for the creator's signing identity.
/// * `extension_types` - extension types registered on the client builder.
/// * `leaf_extensions` - leaf node extensions; defaults to an empty list.
/// * `commit_options` - commit options applied to `DefaultMlsRules`; defaults
///   to `CommitOptions::default()`.
///
/// The creator's identity is named `member`. Panics if group creation fails.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn test_group_custom(
    protocol_version: ProtocolVersion,
    cipher_suite: CipherSuite,
    extension_types: Vec<ExtensionType>,
    leaf_extensions: Option<ExtensionList>,
    commit_options: Option<CommitOptions>,
) -> TestGroup {
    let leaf_extensions = leaf_extensions.unwrap_or_default();
    let commit_options = commit_options.unwrap_or_default();

    let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;

    let group = TestClientBuilder::new_for_test()
        .leaf_node_extensions(leaf_extensions)
        .mls_rules(DefaultMlsRules::default().with_commit_options(commit_options))
        .extension_types(extension_types)
        .protocol_versions(ProtocolVersion::all())
        .used_protocol_version(protocol_version)
        // The identity is not needed afterwards, so it is moved instead of cloned.
        .signing_identity(signing_identity, secret_key, cipher_suite)
        .build()
        .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
        .await
        .unwrap();

    TestGroup { group }
}
/// Creates a single-member test group with no extra extension types and the
/// default leaf extensions and commit options of `test_group_custom`.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn test_group(
    protocol_version: ProtocolVersion,
    cipher_suite: CipherSuite,
) -> TestGroup {
    let extension_types = Vec::new();
    let leaf_extensions = None;
    let commit_options = None;

    test_group_custom(
        protocol_version,
        cipher_suite,
        extension_types,
        leaf_extensions,
        commit_options,
    )
    .await
}
/// Creates a single-member test group whose client builder is customized by
/// the caller.
///
/// `custom` receives a test builder already set to `protocol_version` and
/// returns the builder to use. The signing identity (named `member`) is set
/// after `custom` runs. The group is created with id `TEST_GROUP` and the
/// extensions from `group_extensions()`.
///
/// Panics if group creation fails.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn test_group_custom_config<F>(
    protocol_version: ProtocolVersion,
    cipher_suite: CipherSuite,
    custom: F,
) -> TestGroup
where
    F: FnOnce(TestClientBuilder) -> TestClientBuilder,
{
    let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;

    let client_builder = TestClientBuilder::new_for_test().used_protocol_version(protocol_version);

    let group = custom(client_builder)
        // The identity is not needed afterwards, so it is moved instead of cloned.
        .signing_identity(signing_identity, secret_key, cipher_suite)
        .build()
        .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
        .await
        .unwrap();

    TestGroup { group }
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn test_n_member_group(
protocol_version: ProtocolVersion,
cipher_suite: CipherSuite,
num_members: usize,
) -> Vec<TestGroup> {
let group = test_group(protocol_version, cipher_suite).await;
let mut groups = vec![group];
for i in 1..num_members {
let (new_group, commit) = groups.get_mut(0).unwrap().join(&format!("name {i}")).await;
process_commit(&mut groups, commit, 0).await;
groups.push(new_group);
}
groups
}
/// Processes `commit` in every group of `groups` except those whose current
/// member index equals `excluded`.
///
/// Panics if any group fails to process the message.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn process_commit(groups: &mut [TestGroup], commit: MlsMessage, excluded: u32) {
    for member in groups.iter_mut() {
        if member.group.current_member_index() == excluded {
            continue;
        }

        member.process_message(commit.clone()).await.unwrap();
    }
}
/// Builds an `HpkePublicKey` from 32 bytes that all equal `key_byte`.
pub(crate) fn get_test_25519_key(key_byte: u8) -> HpkePublicKey {
    let repeated = vec![key_byte; 32];
    repeated.into()
}
/// Creates `n` clients and puts them all in one group, returning each member's
/// view of it (creator first).
///
/// Every client is named `member{i}`, registers extension type 999 and uses
/// `leaf_extensions` for its leaf node. The first client creates the group
/// `TEST GROUP` with `extensions`; it then adds the others one by one, and each
/// add commit is processed by the members that joined earlier.
///
/// Panics if `n` is 0 (there is no first client) or if any group operation
/// fails.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn get_test_groups_with_features(
    n: usize,
    extensions: ExtensionList,
    leaf_extensions: ExtensionList,
) -> Vec<Group<TestClientConfig>> {
    let mut clients = Vec::new();

    for member_number in 0..n {
        let member_name = format!("member{member_number}");
        let (identity, secret_key) =
            get_test_signing_identity(TEST_CIPHER_SUITE, member_name.as_bytes()).await;

        let client = TestClientBuilder::new_for_test()
            .extension_type(999.into())
            .leaf_node_extensions(leaf_extensions.clone())
            .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
            .build();

        clients.push(client);
    }

    let creator_group = clients[0]
        .create_group_with_id(b"TEST GROUP".to_vec(), extensions)
        .await
        .unwrap();

    let mut groups = vec![creator_group];

    for joiner in &clients[1..] {
        let key_package = joiner.generate_key_package_message().await.unwrap();

        // The creator commits the addition and moves to the new epoch.
        let commit_output = groups[0]
            .commit_builder()
            .add_member(key_package)
            .unwrap()
            .build()
            .await
            .unwrap();

        groups[0].apply_pending_commit().await.unwrap();

        // Members added earlier catch up by processing the commit.
        for existing in &mut groups[1..] {
            existing
                .process_incoming_message(commit_output.commit_message.clone())
                .await
                .unwrap();
        }

        let joined = joiner
            .join_group(None, &commit_output.welcome_messages[0])
            .await
            .unwrap()
            .0;

        groups.push(joined);
    }

    groups
}
/// Returns `count` bytes filled by `rand::thread_rng()`.
pub fn random_bytes(count: usize) -> Vec<u8> {
    let mut bytes = vec![0u8; count];
    let mut rng = rand::thread_rng();
    rng.fill_bytes(&mut bytes);
    bytes
}
/// Test wrapper around a `Group` whose `update_key_schedule` implementation
/// only stores the values it is given instead of updating any key schedule.
pub(crate) struct GroupWithoutKeySchedule {
// The wrapped group that every other operation is forwarded to.
inner: Group<TestClientConfig>,
/// The `secrets` argument of the most recent `update_key_schedule` call.
pub secrets: Option<(TreeKemPrivate, PathSecret)>,
/// The provisional state passed to the most recent `update_key_schedule` call.
pub provisional_public_state: Option<ProvisionalState>,
}
/// Gives shared access to the wrapped group through the wrapper.
impl Deref for GroupWithoutKeySchedule {
type Target = Group<TestClientConfig>;
// Excluded from coverage measurements on nightly coverage builds.
#[cfg_attr(coverage_nightly, coverage(off))]
fn deref(&self) -> &Self::Target {
&self.inner
}
}
/// Gives mutable access to the wrapped group through the wrapper.
impl DerefMut for GroupWithoutKeySchedule {
// Excluded from coverage measurements on nightly coverage builds.
#[cfg_attr(coverage_nightly, coverage(off))]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
#[cfg(feature = "rfc_compliant")]
impl GroupWithoutKeySchedule {
    /// Wraps a fresh single-member test group for cipher suite `cs`, with no
    /// recorded secrets or provisional state yet.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn new(cs: CipherSuite) -> Self {
        let TestGroup { group } = test_group(TEST_PROTOCOL_VERSION, cs).await;

        Self {
            inner: group,
            secrets: None,
            provisional_public_state: None,
        }
    }
}
/// `MessageProcessor` implementation that forwards every operation to the
/// wrapped group, except `update_key_schedule`, which only records its inputs
/// on the wrapper.
// The attributes below select the sync or async expansion of this impl
// depending on `mls_build_async` and on whether the target is wasm32.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
#[cfg_attr(
all(not(target_arch = "wasm32"), mls_build_async),
maybe_async::must_be_async
)]
impl MessageProcessor for GroupWithoutKeySchedule {
// All associated types are taken from the wrapped group's implementation.
type CipherSuiteProvider = <Group<TestClientConfig> as MessageProcessor>::CipherSuiteProvider;
type OutputType = <Group<TestClientConfig> as MessageProcessor>::OutputType;
type PreSharedKeyStorage = <Group<TestClientConfig> as MessageProcessor>::PreSharedKeyStorage;
type IdentityProvider = <Group<TestClientConfig> as MessageProcessor>::IdentityProvider;
type MlsRules = <Group<TestClientConfig> as MessageProcessor>::MlsRules;
// Forwards to the wrapped group.
fn group_state(&self) -> &GroupState {
self.inner.group_state()
}
// Forwards to the wrapped group.
#[cfg_attr(coverage_nightly, coverage(off))]
fn group_state_mut(&mut self) -> &mut GroupState {
self.inner.group_state_mut()
}
// Forwards to the wrapped group.
fn mls_rules(&self) -> Self::MlsRules {
self.inner.mls_rules()
}
// Forwards to the wrapped group.
fn identity_provider(&self) -> Self::IdentityProvider {
self.inner.identity_provider()
}
// Forwards to the wrapped group.
fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
self.inner.cipher_suite_provider()
}
// Forwards to the wrapped group.
fn psk_storage(&self) -> Self::PreSharedKeyStorage {
self.inner.psk_storage()
}
// Forwards to the wrapped group.
fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool {
self.inner.can_continue_processing(provisional_state)
}
// Forwards to the wrapped group; only compiled with the `private_message` feature.
#[cfg(feature = "private_message")]
#[cfg_attr(coverage_nightly, coverage(off))]
fn min_epoch_available(&self) -> Option<u64> {
self.inner.min_epoch_available()
}
// Forwards to the wrapped group and returns whatever it returns.
async fn apply_update_path(
&mut self,
sender: LeafIndex,
update_path: &ValidatedUpdatePath,
provisional_state: &mut ProvisionalState,
) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
self.inner
.apply_update_path(sender, update_path, provisional_state)
.await
}
// Forwards to the wrapped group; only compiled with the `private_message` feature.
#[cfg(feature = "private_message")]
#[cfg_attr(coverage_nightly, coverage(off))]
async fn process_ciphertext(
&mut self,
cipher_text: &PrivateMessage,
) -> Result<EventOrContent<Self::OutputType>, MlsError> {
self.inner.process_ciphertext(cipher_text).await
}
// Forwards to the wrapped group.
#[cfg_attr(coverage_nightly, coverage(off))]
async fn verify_plaintext_authentication(
&self,
message: PublicMessage,
) -> Result<EventOrContent<Self::OutputType>, MlsError> {
self.inner.verify_plaintext_authentication(message).await
}
// Does NOT forward: stores the provisional state and secrets on the wrapper
// and returns `Ok(())`. The interim transcript hash and confirmation tag
// are ignored.
async fn update_key_schedule(
&mut self,
secrets: Option<(TreeKemPrivate, PathSecret)>,
_interim_transcript_hash: InterimTranscriptHash,
_confirmation_tag: &ConfirmationTag,
provisional_public_state: ProvisionalState,
) -> Result<(), MlsError> {
self.provisional_public_state = Some(provisional_public_state);
self.secrets = secrets;
Ok(())
}
}