Revision control
Copy as Markdown
Other Tools
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use crate::client::MlsError;
use crate::extension::ExternalPubExt;
use crate::group::{GroupContext, MembershipTag};
use crate::psk::secret::PskSecret;
#[cfg(feature = "psk")]
use crate::psk::PreSharedKey;
use crate::tree_kem::path_secret::PathSecret;
use crate::CipherSuiteProvider;
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
use crate::group::SecretTree;
use alloc::vec;
use alloc::vec::Vec;
use core::fmt::{self, Debug};
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use mls_rs_core::error::IntoAnyError;
use zeroize::Zeroizing;
use crate::crypto::{HpkeContextR, HpkeContextS, HpkePublicKey, HpkeSecretKey};
use super::epoch::{EpochSecrets, SenderDataSecret};
use super::message_signature::AuthenticatedContent;
#[derive(Clone, PartialEq, Eq, Default, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KeySchedule {
#[mls_codec(with = "mls_rs_codec::byte_vec")]
#[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
exporter_secret: Zeroizing<Vec<u8>>,
#[mls_codec(with = "mls_rs_codec::byte_vec")]
#[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
pub authentication_secret: Zeroizing<Vec<u8>>,
#[mls_codec(with = "mls_rs_codec::byte_vec")]
#[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
external_secret: Zeroizing<Vec<u8>>,
#[mls_codec(with = "mls_rs_codec::byte_vec")]
#[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
membership_key: Zeroizing<Vec<u8>>,
init_secret: InitSecret,
}
impl Debug for KeySchedule {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("KeySchedule")
.field(
"exporter_secret",
&mls_rs_core::debug::pretty_bytes(&self.exporter_secret),
)
.field(
"authentication_secret",
&mls_rs_core::debug::pretty_bytes(&self.authentication_secret),
)
.field(
"external_secret",
&mls_rs_core::debug::pretty_bytes(&self.external_secret),
)
.field(
"membership_key",
&mls_rs_core::debug::pretty_bytes(&self.membership_key),
)
.field("init_secret", &self.init_secret)
.finish()
}
}
pub(crate) struct KeyScheduleDerivationResult {
pub(crate) key_schedule: KeySchedule,
pub(crate) confirmation_key: Zeroizing<Vec<u8>>,
pub(crate) joiner_secret: JoinerSecret,
pub(crate) epoch_secrets: EpochSecrets,
}
impl KeySchedule {
pub fn new(init_secret: InitSecret) -> Self {
KeySchedule {
init_secret,
..Default::default()
}
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn derive_for_external<P: CipherSuiteProvider>(
&self,
kem_output: &[u8],
cipher_suite: &P,
) -> Result<KeySchedule, MlsError> {
let (secret, public) = self.get_external_key_pair(cipher_suite).await?;
let init_secret =
InitSecret::decode_for_external(cipher_suite, kem_output, &secret, &public).await?;
Ok(KeySchedule::new(init_secret))
}
/// Returns the derived epoch as well as the joiner secret required for building welcome
/// messages
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn from_key_schedule<P: CipherSuiteProvider>(
last_key_schedule: &KeySchedule,
commit_secret: &PathSecret,
context: &GroupContext,
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
secret_tree_size: u32,
psk_secret: &PskSecret,
cipher_suite_provider: &P,
) -> Result<KeyScheduleDerivationResult, MlsError> {
let joiner_seed = cipher_suite_provider
.kdf_extract(&last_key_schedule.init_secret.0, commit_secret)
.await
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
let joiner_secret = kdf_expand_with_label(
cipher_suite_provider,
&joiner_seed,
b"joiner",
&context.mls_encode_to_vec()?,
None,
)
.await?
.into();
let key_schedule_result = Self::from_joiner(
cipher_suite_provider,
&joiner_secret,
context,
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
secret_tree_size,
psk_secret,
)
.await?;
Ok(KeyScheduleDerivationResult {
key_schedule: key_schedule_result.key_schedule,
confirmation_key: key_schedule_result.confirmation_key,
joiner_secret,
epoch_secrets: key_schedule_result.epoch_secrets,
})
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn from_joiner<P: CipherSuiteProvider>(
cipher_suite_provider: &P,
joiner_secret: &JoinerSecret,
context: &GroupContext,
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
secret_tree_size: u32,
psk_secret: &PskSecret,
) -> Result<KeyScheduleDerivationResult, MlsError> {
let epoch_seed =
get_pre_epoch_secret(cipher_suite_provider, psk_secret, joiner_secret).await?;
let context = context.mls_encode_to_vec()?;
let epoch_secret =
kdf_expand_with_label(cipher_suite_provider, &epoch_seed, b"epoch", &context, None)
.await?;
Self::from_epoch_secret(
cipher_suite_provider,
&epoch_secret,
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
secret_tree_size,
)
.await
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn from_random_epoch_secret<P: CipherSuiteProvider>(
cipher_suite_provider: &P,
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
secret_tree_size: u32,
) -> Result<KeyScheduleDerivationResult, MlsError> {
let epoch_secret = cipher_suite_provider
.random_bytes_vec(cipher_suite_provider.kdf_extract_size())
.map(Zeroizing::new)
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
Self::from_epoch_secret(
cipher_suite_provider,
&epoch_secret,
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
secret_tree_size,
)
.await
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn from_epoch_secret<P: CipherSuiteProvider>(
cipher_suite_provider: &P,
epoch_secret: &[u8],
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
secret_tree_size: u32,
) -> Result<KeyScheduleDerivationResult, MlsError> {
let secrets_producer = SecretsProducer::new(cipher_suite_provider, epoch_secret);
let epoch_secrets = EpochSecrets {
#[cfg(feature = "psk")]
resumption_secret: PreSharedKey::from(secrets_producer.derive(b"resumption").await?),
sender_data_secret: SenderDataSecret::from(
secrets_producer.derive(b"sender data").await?,
),
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
secret_tree: SecretTree::new(
secret_tree_size,
secrets_producer.derive(b"encryption").await?,
),
};
let key_schedule = Self {
exporter_secret: secrets_producer.derive(b"exporter").await?,
authentication_secret: secrets_producer.derive(b"authentication").await?,
external_secret: secrets_producer.derive(b"external").await?,
membership_key: secrets_producer.derive(b"membership").await?,
init_secret: InitSecret(secrets_producer.derive(b"init").await?),
};
Ok(KeyScheduleDerivationResult {
key_schedule,
confirmation_key: secrets_producer.derive(b"confirm").await?,
joiner_secret: Zeroizing::new(vec![]).into(),
epoch_secrets,
})
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn export_secret<P: CipherSuiteProvider>(
&self,
label: &[u8],
context: &[u8],
len: usize,
cipher_suite: &P,
) -> Result<Zeroizing<Vec<u8>>, MlsError> {
let secret = kdf_derive_secret(cipher_suite, &self.exporter_secret, label).await?;
let context_hash = cipher_suite
.hash(context)
.await
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
kdf_expand_with_label(cipher_suite, &secret, b"exported", &context_hash, Some(len)).await
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn get_membership_tag<P: CipherSuiteProvider>(
&self,
content: &AuthenticatedContent,
context: &GroupContext,
cipher_suite_provider: &P,
) -> Result<MembershipTag, MlsError> {
MembershipTag::create(
content,
context,
&self.membership_key,
cipher_suite_provider,
)
.await
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn get_external_key_pair<P: CipherSuiteProvider>(
&self,
cipher_suite: &P,
) -> Result<(HpkeSecretKey, HpkePublicKey), MlsError> {
cipher_suite
.kem_derive(&self.external_secret)
.await
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn get_external_key_pair_ext<P: CipherSuiteProvider>(
&self,
cipher_suite: &P,
) -> Result<ExternalPubExt, MlsError> {
let (_external_secret, external_pub) = self.get_external_key_pair(cipher_suite).await?;
Ok(ExternalPubExt { external_pub })
}
}
#[derive(MlsEncode, MlsSize)]
struct Label<'a> {
length: u16,
#[mls_codec(with = "mls_rs_codec::byte_vec")]
label: Vec<u8>,
#[mls_codec(with = "mls_rs_codec::byte_vec")]
context: &'a [u8],
}
impl<'a> Label<'a> {
fn new(length: u16, label: &'a [u8], context: &'a [u8]) -> Self {
Self {
length,
label: [b"MLS 1.0 ", label].concat(),
context,
}
}
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn kdf_expand_with_label<P: CipherSuiteProvider>(
cipher_suite_provider: &P,
secret: &[u8],
label: &[u8],
context: &[u8],
len: Option<usize>,
) -> Result<Zeroizing<Vec<u8>>, MlsError> {
let extract_size = cipher_suite_provider.kdf_extract_size();
let len = len.unwrap_or(extract_size);
let label = Label::new(len as u16, label, context);
cipher_suite_provider
.kdf_expand(secret, &label.mls_encode_to_vec()?, len)
.await
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn kdf_derive_secret<P: CipherSuiteProvider>(
cipher_suite_provider: &P,
secret: &[u8],
label: &[u8],
) -> Result<Zeroizing<Vec<u8>>, MlsError> {
kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None).await
}
#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
pub(crate) struct JoinerSecret(#[mls_codec(with = "mls_rs_codec::byte_vec")] Zeroizing<Vec<u8>>);
impl Debug for JoinerSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
mls_rs_core::debug::pretty_bytes(&self.0)
.named("JoinerSecret")
.fmt(f)
}
}
impl From<Zeroizing<Vec<u8>>> for JoinerSecret {
fn from(bytes: Zeroizing<Vec<u8>>) -> Self {
Self(bytes)
}
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn get_pre_epoch_secret<P: CipherSuiteProvider>(
cipher_suite_provider: &P,
psk_secret: &PskSecret,
joiner_secret: &JoinerSecret,
) -> Result<Zeroizing<Vec<u8>>, MlsError> {
cipher_suite_provider
.kdf_extract(&joiner_secret.0, psk_secret)
.await
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
}
struct SecretsProducer<'a, P: CipherSuiteProvider> {
cipher_suite_provider: &'a P,
epoch_secret: &'a [u8],
}
impl<'a, P: CipherSuiteProvider> SecretsProducer<'a, P> {
fn new(cipher_suite_provider: &'a P, epoch_secret: &'a [u8]) -> Self {
Self {
cipher_suite_provider,
epoch_secret,
}
}
// TODO document somewhere in the crypto provider that the RFC defines the length of all secrets as
// KDF extract size but then inputs secrets as MAC keys etc, therefore, we require that these
// lengths match in the crypto provider
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn derive(&self, label: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError> {
kdf_derive_secret(self.cipher_suite_provider, self.epoch_secret, label).await
}
}
const EXPORTER_CONTEXT: &[u8] = b"MLS 1.0 external init secret";
#[derive(Clone, Eq, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct InitSecret(
#[mls_codec(with = "mls_rs_codec::byte_vec")]
#[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
Zeroizing<Vec<u8>>,
);
impl Debug for InitSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
mls_rs_core::debug::pretty_bytes(&self.0)
.named("InitSecret")
.fmt(f)
}
}
impl InitSecret {
/// Returns init secret and KEM output to be used when creating an external commit.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn encode_for_external<P: CipherSuiteProvider>(
cipher_suite: &P,
external_pub: &HpkePublicKey,
) -> Result<(Self, Vec<u8>), MlsError> {
let (kem_output, context) = cipher_suite
.hpke_setup_s(external_pub, &[])
.await
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
let init_secret = context
.export(EXPORTER_CONTEXT, cipher_suite.kdf_extract_size())
.await
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
Ok((InitSecret(Zeroizing::new(init_secret)), kem_output))
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn decode_for_external<P: CipherSuiteProvider>(
cipher_suite: &P,
kem_output: &[u8],
external_secret: &HpkeSecretKey,
external_pub: &HpkePublicKey,
) -> Result<Self, MlsError> {
let context = cipher_suite
.hpke_setup_r(kem_output, external_secret, external_pub, &[])
.await
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
context
.export(EXPORTER_CONTEXT, cipher_suite.kdf_extract_size())
.await
.map(Zeroizing::new)
.map(InitSecret)
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
}
}
pub(crate) struct WelcomeSecret<'a, P: CipherSuiteProvider> {
cipher_suite: &'a P,
key: Zeroizing<Vec<u8>>,
nonce: Zeroizing<Vec<u8>>,
}
impl<'a, P: CipherSuiteProvider> WelcomeSecret<'a, P> {
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn from_joiner_secret(
cipher_suite: &'a P,
joiner_secret: &JoinerSecret,
psk_secret: &PskSecret,
) -> Result<WelcomeSecret<'a, P>, MlsError> {
let welcome_secret = get_welcome_secret(cipher_suite, joiner_secret, psk_secret).await?;
let key_len = cipher_suite.aead_key_size();
let key = kdf_expand_with_label(cipher_suite, &welcome_secret, b"key", &[], Some(key_len))
.await?;
let nonce_len = cipher_suite.aead_nonce_size();
let nonce = kdf_expand_with_label(
cipher_suite,
&welcome_secret,
b"nonce",
&[],
Some(nonce_len),
)
.await?;
Ok(Self {
cipher_suite,
key,
nonce,
})
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, MlsError> {
self.cipher_suite
.aead_seal(&self.key, plaintext, None, &self.nonce)
.await
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn decrypt(&self, ciphertext: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError> {
self.cipher_suite
.aead_open(&self.key, ciphertext, None, &self.nonce)
.await
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
}
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn get_welcome_secret<P: CipherSuiteProvider>(
cipher_suite: &P,
joiner_secret: &JoinerSecret,
psk_secret: &PskSecret,
) -> Result<Zeroizing<Vec<u8>>, MlsError> {
let epoch_seed = get_pre_epoch_secret(cipher_suite, psk_secret, joiner_secret).await?;
kdf_derive_secret(cipher_suite, &epoch_seed, b"welcome").await
}
#[cfg(test)]
pub(crate) mod test_utils {
use alloc::vec;
use alloc::vec::Vec;
use mls_rs_core::crypto::CipherSuiteProvider;
use zeroize::Zeroizing;
use crate::{cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider};
use super::{InitSecret, JoinerSecret, KeySchedule};
#[cfg(all(feature = "rfc_compliant", not(mls_build_async)))]
use mls_rs_core::error::IntoAnyError;
#[cfg(all(feature = "rfc_compliant", not(mls_build_async)))]
use super::MlsError;
impl From<JoinerSecret> for Vec<u8> {
fn from(mut value: JoinerSecret) -> Self {
core::mem::take(&mut value.0)
}
}
pub(crate) fn get_test_key_schedule(cipher_suite: CipherSuite) -> KeySchedule {
let key_size = test_cipher_suite_provider(cipher_suite).kdf_extract_size();
let fake_secret = Zeroizing::new(vec![1u8; key_size]);
KeySchedule {
exporter_secret: fake_secret.clone(),
authentication_secret: fake_secret.clone(),
external_secret: fake_secret.clone(),
membership_key: fake_secret,
init_secret: InitSecret::new(vec![0u8; key_size]),
}
}
impl InitSecret {
pub fn new(init_secret: Vec<u8>) -> Self {
InitSecret(Zeroizing::new(init_secret))
}
#[cfg(all(feature = "rfc_compliant", test, not(mls_build_async)))]
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn random<P: CipherSuiteProvider>(cipher_suite: &P) -> Result<Self, MlsError> {
cipher_suite
.random_bytes_vec(cipher_suite.kdf_extract_size())
.map(Zeroizing::new)
.map(InitSecret)
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
}
}
#[cfg(feature = "rfc_compliant")]
impl KeySchedule {
pub fn set_membership_key(&mut self, key: Vec<u8>) {
self.membership_key = Zeroizing::new(key)
}
}
}
#[cfg(test)]
mod tests {
use crate::client::test_utils::TEST_PROTOCOL_VERSION;
use crate::crypto::test_utils::try_test_cipher_suite_provider;
use crate::group::key_schedule::{
get_welcome_secret, kdf_derive_secret, kdf_expand_with_label,
};
use crate::group::GroupContext;
use alloc::string::String;
use alloc::vec::Vec;
use mls_rs_codec::MlsEncode;
use mls_rs_core::crypto::CipherSuiteProvider;
use mls_rs_core::extension::ExtensionList;
#[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
use crate::{
crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
group::{
key_schedule::KeyScheduleDerivationResult, test_utils::random_bytes, InitSecret,
PskSecret,
},
};
#[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
use alloc::{string::ToString, vec};
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test as test;
use zeroize::Zeroizing;
use super::test_utils::get_test_key_schedule;
use super::KeySchedule;
#[derive(serde::Deserialize, serde::Serialize)]
struct TestCase {
cipher_suite: u16,
#[serde(with = "hex::serde")]
group_id: Vec<u8>,
#[serde(with = "hex::serde")]
initial_init_secret: Vec<u8>,
epochs: Vec<KeyScheduleEpoch>,
}
#[derive(serde::Deserialize, serde::Serialize)]
struct KeyScheduleEpoch {
#[serde(with = "hex::serde")]
commit_secret: Vec<u8>,
#[serde(with = "hex::serde")]
psk_secret: Vec<u8>,
#[serde(with = "hex::serde")]
confirmed_transcript_hash: Vec<u8>,
#[serde(with = "hex::serde")]
tree_hash: Vec<u8>,
#[serde(with = "hex::serde")]
group_context: Vec<u8>,
#[serde(with = "hex::serde")]
joiner_secret: Vec<u8>,
#[serde(with = "hex::serde")]
welcome_secret: Vec<u8>,
#[serde(with = "hex::serde")]
init_secret: Vec<u8>,
#[serde(with = "hex::serde")]
sender_data_secret: Vec<u8>,
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
#[serde(with = "hex::serde")]
encryption_secret: Vec<u8>,
#[serde(with = "hex::serde")]
exporter_secret: Vec<u8>,
#[serde(with = "hex::serde")]
epoch_authenticator: Vec<u8>,
#[serde(with = "hex::serde")]
external_secret: Vec<u8>,
#[serde(with = "hex::serde")]
confirmation_key: Vec<u8>,
#[serde(with = "hex::serde")]
membership_key: Vec<u8>,
#[cfg(feature = "psk")]
#[serde(with = "hex::serde")]
resumption_psk: Vec<u8>,
#[serde(with = "hex::serde")]
external_pub: Vec<u8>,
exporter: KeyScheduleExporter,
}
#[derive(serde::Deserialize, serde::Serialize)]
struct KeyScheduleExporter {
label: String,
#[serde(with = "hex::serde")]
context: Vec<u8>,
length: usize,
#[serde(with = "hex::serde")]
secret: Vec<u8>,
}
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_key_schedule() {
let test_cases: Vec<TestCase> =
load_test_case_json!(key_schedule_test_vector, generate_test_vector());
for test_case in test_cases {
let Some(cs_provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
continue;
};
let mut key_schedule = get_test_key_schedule(cs_provider.cipher_suite());
key_schedule.init_secret.0 = Zeroizing::new(test_case.initial_init_secret);
for (i, epoch) in test_case.epochs.into_iter().enumerate() {
let context = GroupContext {
protocol_version: TEST_PROTOCOL_VERSION,
cipher_suite: cs_provider.cipher_suite(),
group_id: test_case.group_id.clone(),
epoch: i as u64,
tree_hash: epoch.tree_hash,
confirmed_transcript_hash: epoch.confirmed_transcript_hash.into(),
extensions: ExtensionList::new(),
};
assert_eq!(context.mls_encode_to_vec().unwrap(), epoch.group_context);
let psk = epoch.psk_secret.into();
let commit = epoch.commit_secret.into();
let key_schedule_res = KeySchedule::from_key_schedule(
&key_schedule,
&commit,
&context,
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
32,
&psk,
&cs_provider,
)
.await
.unwrap();
key_schedule = key_schedule_res.key_schedule;
let welcome =
get_welcome_secret(&cs_provider, &key_schedule_res.joiner_secret, &psk)
.await
.unwrap();
assert_eq!(*welcome, epoch.welcome_secret);
let expected: Vec<u8> = key_schedule_res.joiner_secret.into();
assert_eq!(epoch.joiner_secret, expected);
assert_eq!(&key_schedule.init_secret.0.to_vec(), &epoch.init_secret);
assert_eq!(
epoch.sender_data_secret,
*key_schedule_res.epoch_secrets.sender_data_secret.to_vec()
);
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
assert_eq!(
epoch.encryption_secret,
*key_schedule_res.epoch_secrets.secret_tree.get_root_secret()
);
assert_eq!(epoch.exporter_secret, key_schedule.exporter_secret.to_vec());
assert_eq!(
epoch.epoch_authenticator,
key_schedule.authentication_secret.to_vec()
);
assert_eq!(epoch.external_secret, key_schedule.external_secret.to_vec());
assert_eq!(
epoch.confirmation_key,
key_schedule_res.confirmation_key.to_vec()
);
assert_eq!(epoch.membership_key, key_schedule.membership_key.to_vec());
#[cfg(feature = "psk")]
{
let expected: Vec<u8> =
key_schedule_res.epoch_secrets.resumption_secret.to_vec();
assert_eq!(epoch.resumption_psk, expected);
}
let (_external_sec, external_pub) = key_schedule
.get_external_key_pair(&cs_provider)
.await
.unwrap();
assert_eq!(epoch.external_pub, *external_pub);
let exp = epoch.exporter;
let exported = key_schedule
.export_secret(exp.label.as_bytes(), &exp.context, exp.length, &cs_provider)
.await
.unwrap();
assert_eq!(exported.to_vec(), exp.secret);
}
}
}
#[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
#[cfg_attr(coverage_nightly, coverage(off))]
fn generate_test_vector() -> Vec<TestCase> {
let mut test_cases = vec![];
for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
let cs_provider = test_cipher_suite_provider(cipher_suite);
let key_size = cs_provider.kdf_extract_size();
let mut group_context = GroupContext {
protocol_version: TEST_PROTOCOL_VERSION,
cipher_suite: cs_provider.cipher_suite(),
group_id: b"my group 5".to_vec(),
epoch: 0,
tree_hash: random_bytes(key_size),
confirmed_transcript_hash: random_bytes(key_size).into(),
extensions: Default::default(),
};
let initial_init_secret = InitSecret::random(&cs_provider).unwrap();
let mut key_schedule = get_test_key_schedule(cs_provider.cipher_suite());
key_schedule.init_secret = initial_init_secret.clone();
let commit_secret = random_bytes(key_size).into();
let psk_secret = PskSecret::new(&cs_provider);
let key_schedule_res = KeySchedule::from_key_schedule(
&key_schedule,
&commit_secret,
&group_context,
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
32,
&psk_secret,
&cs_provider,
)
.unwrap();
key_schedule = key_schedule_res.key_schedule.clone();
let epoch1 = KeyScheduleEpoch::new(
key_schedule_res,
psk_secret,
commit_secret.to_vec(),
&group_context,
&cs_provider,
);
group_context.epoch += 1;
group_context.confirmed_transcript_hash = random_bytes(key_size).into();
group_context.tree_hash = random_bytes(key_size);
let commit_secret = random_bytes(key_size).into();
let psk_secret = PskSecret::new(&cs_provider);
let key_schedule_res = KeySchedule::from_key_schedule(
&key_schedule,
&commit_secret,
&group_context,
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
32,
&psk_secret,
&cs_provider,
)
.unwrap();
let epoch2 = KeyScheduleEpoch::new(
key_schedule_res,
psk_secret,
commit_secret.to_vec(),
&group_context,
&cs_provider,
);
let test_case = TestCase {
cipher_suite: cs_provider.cipher_suite().into(),
group_id: group_context.group_id.clone(),
initial_init_secret: initial_init_secret.0.to_vec(),
epochs: vec![epoch1, epoch2],
};
test_cases.push(test_case);
}
test_cases
}
#[cfg(not(all(not(mls_build_async), feature = "rfc_compliant")))]
fn generate_test_vector() -> Vec<TestCase> {
panic!("Tests cannot be generated in async mode");
}
#[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
impl KeyScheduleEpoch {
#[cfg_attr(coverage_nightly, coverage(off))]
fn new<P: CipherSuiteProvider>(
key_schedule_res: KeyScheduleDerivationResult,
psk_secret: PskSecret,
commit_secret: Vec<u8>,
group_context: &GroupContext,
cs: &P,
) -> Self {
let (_external_sec, external_pub) = key_schedule_res
.key_schedule
.get_external_key_pair(cs)
.unwrap();
let mut exporter = KeyScheduleExporter {
label: "exporter label 15".to_string(),
context: b"exporter context".to_vec(),
length: 64,
secret: vec![],
};
exporter.secret = key_schedule_res
.key_schedule
.export_secret(
exporter.label.as_bytes(),
&exporter.context,
exporter.length,
cs,
)
.unwrap()
.to_vec();
let welcome_secret =
get_welcome_secret(cs, &key_schedule_res.joiner_secret, &psk_secret)
.unwrap()
.to_vec();
KeyScheduleEpoch {
commit_secret,
welcome_secret,
psk_secret: psk_secret.to_vec(),
group_context: group_context.mls_encode_to_vec().unwrap(),
joiner_secret: key_schedule_res.joiner_secret.into(),
init_secret: key_schedule_res.key_schedule.init_secret.0.to_vec(),
sender_data_secret: key_schedule_res.epoch_secrets.sender_data_secret.to_vec(),
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
encryption_secret: key_schedule_res.epoch_secrets.secret_tree.get_root_secret(),
exporter_secret: key_schedule_res.key_schedule.exporter_secret.to_vec(),
epoch_authenticator: key_schedule_res.key_schedule.authentication_secret.to_vec(),
external_secret: key_schedule_res.key_schedule.external_secret.to_vec(),
confirmation_key: key_schedule_res.confirmation_key.to_vec(),
membership_key: key_schedule_res.key_schedule.membership_key.to_vec(),
#[cfg(feature = "psk")]
resumption_psk: key_schedule_res.epoch_secrets.resumption_secret.to_vec(),
external_pub: external_pub.to_vec(),
exporter,
confirmed_transcript_hash: group_context.confirmed_transcript_hash.to_vec(),
tree_hash: group_context.tree_hash.clone(),
}
}
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct ExpandWithLabelTestCase {
#[serde(with = "hex::serde")]
secret: Vec<u8>,
label: String,
#[serde(with = "hex::serde")]
context: Vec<u8>,
length: usize,
#[serde(with = "hex::serde")]
out: Vec<u8>,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct DeriveSecretTestCase {
#[serde(with = "hex::serde")]
secret: Vec<u8>,
label: String,
#[serde(with = "hex::serde")]
out: Vec<u8>,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct InteropTestCase {
cipher_suite: u16,
expand_with_label: ExpandWithLabelTestCase,
derive_secret: DeriveSecretTestCase,
}
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_basic_crypto_test_vectors() {
// The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
let test_cases: Vec<InteropTestCase> =
load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
for test_case in test_cases {
if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
let test_exp = &test_case.expand_with_label;
let computed = kdf_expand_with_label(
&cs,
&test_exp.secret,
test_exp.label.as_bytes(),
&test_exp.context,
Some(test_exp.length),
)
.await
.unwrap();
assert_eq!(&computed.to_vec(), &test_exp.out);
let test_derive = &test_case.derive_secret;
let computed =
kdf_derive_secret(&cs, &test_derive.secret, test_derive.label.as_bytes())
.await
.unwrap();
assert_eq!(&computed.to_vec(), &test_derive.out);
}
}
}
}