Revision control
Copy as Markdown
Other Tools
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use alloc::vec;
use alloc::vec::Vec;
use core::{
fmt::{self, Debug},
ops::Deref,
};
use mls_rs_core::crypto::CipherSuiteProvider;
use zeroize::Zeroizing;
#[cfg(feature = "psk")]
use mls_rs_codec::MlsEncode;
#[cfg(feature = "psk")]
use mls_rs_core::{error::IntoAnyError, psk::PreSharedKey};
#[cfg(feature = "psk")]
use crate::{
client::MlsError,
group::key_schedule::kdf_expand_with_label,
psk::{PSKLabel, PreSharedKeyID},
};
#[cfg(feature = "psk")]
#[derive(Clone)]
/// A pre-shared key paired with the identifier it is known under.
///
/// A slice of these values is the input to `PskSecret::calculate`.
pub(crate) struct PskSecretInput {
/// Identifier that `PskSecret::calculate` embeds in the `PSKLabel` for this key.
pub id: PreSharedKeyID,
/// Key material that `PskSecret::calculate` feeds to `kdf_extract`.
pub psk: PreSharedKey,
}
/// PSK secret bytes, stored in a `Zeroizing` buffer.
///
/// The value dereferences to `[u8]`, and its `Debug` output is rendered through
/// `mls_rs_core::debug::pretty_bytes` under the name `"PskSecret"`.
#[derive(PartialEq, Eq, Clone)]
pub(crate) struct PskSecret(Zeroizing<Vec<u8>>);
impl Debug for PskSecret {
    /// Formats the secret via `mls_rs_core::debug::pretty_bytes`, labelled `"PskSecret"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = mls_rs_core::debug::pretty_bytes(&self.0).named("PskSecret");

        Debug::fmt(&rendered, f)
    }
}
#[cfg(test)]
impl From<Vec<u8>> for PskSecret {
    /// Test-only conversion that wraps raw bytes as a `PskSecret`.
    fn from(value: Vec<u8>) -> Self {
        let wrapped = Zeroizing::new(value);
        Self(wrapped)
    }
}
impl Deref for PskSecret {
    type Target = [u8];

    /// Exposes the secret as a byte slice.
    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
impl PskSecret {
    /// Creates the initial PSK secret: `provider.kdf_extract_size()` zero bytes.
    ///
    /// [`PskSecret::calculate`] starts from this value and returns it unchanged when it is
    /// given no PSKs.
    pub(crate) fn new<P: CipherSuiteProvider>(provider: &P) -> PskSecret {
        PskSecret(Zeroizing::new(vec![0u8; provider.kdf_extract_size()]))
    }

    /// Folds every PSK in `input`, in order, into a single PSK secret.
    ///
    /// For each entry the PSK is extracted with an all-zero salt, expanded with the
    /// `"derived psk"` label over the encoded `PSKLabel` (id, index, count), and the result
    /// is used as the salt for an extract over the running secret.
    ///
    /// # Errors
    ///
    /// * `MlsError::TooManyPskIds` if `input.len()` does not fit in a `u16`.
    /// * `MlsError::CryptoProviderError` if a `kdf_extract` call fails.
    /// * Any error propagated from encoding the `PSKLabel` or from
    ///   `kdf_expand_with_label`.
    #[cfg(feature = "psk")]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn calculate<P: CipherSuiteProvider>(
        input: &[PskSecretInput],
        cipher_suite_provider: &P,
    ) -> Result<PskSecret, MlsError> {
        let len = u16::try_from(input.len()).map_err(|_| MlsError::TooManyPskIds)?;

        // The zero salt is identical for every PSK, so allocate it once instead of once
        // per loop iteration.
        let zero_salt = vec![0u8; cipher_suite_provider.kdf_extract_size()];

        let mut psk_secret = PskSecret::new(cipher_suite_provider);

        for (index, psk_secret_input) in input.iter().enumerate() {
            // Lossless: `index < input.len()`, and `input.len()` was verified above to fit
            // in a `u16`.
            let index = index as u16;

            let label = PSKLabel {
                id: &psk_secret_input.id,
                index,
                count: len,
            };

            let psk_extracted = cipher_suite_provider
                .kdf_extract(&zero_salt, &psk_secret_input.psk)
                .await
                .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;

            let psk_input = kdf_expand_with_label(
                cipher_suite_provider,
                &psk_extracted,
                b"derived psk",
                &label.mls_encode_to_vec()?,
                None,
            )
            .await?;

            // Chain step: the derived input is the salt, the running secret is the key
            // material.
            psk_secret = cipher_suite_provider
                .kdf_extract(&psk_input, &psk_secret)
                .await
                .map(PskSecret)
                .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
        }

        Ok(psk_secret)
    }
}
#[cfg(feature = "psk")]
#[cfg(test)]
mod tests {
use alloc::vec::Vec;
#[cfg(not(mls_build_async))]
use core::iter;
use serde::{Deserialize, Serialize};
use crate::{
crypto::test_utils::try_test_cipher_suite_provider,
psk::ExternalPskId,
psk::{JustPreSharedKeyID, PreSharedKeyID, PskNonce},
CipherSuiteProvider,
};
#[cfg(not(mls_build_async))]
use crate::{
crypto::test_utils::test_cipher_suite_provider, psk::test_utils::make_external_psk_id,
CipherSuite,
};
use super::{PskSecret, PskSecretInput};
/// Serialized form of one PSK inside a test scenario.
///
/// Every field is (de)serialized through `hex::serde`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
struct PskInfo {
/// Bytes wrapped in an `ExternalPskId` when converting to `PskSecretInput`.
#[serde(with = "hex::serde")]
id: Vec<u8>,
/// Raw pre-shared key bytes.
#[serde(with = "hex::serde")]
psk: Vec<u8>,
/// Bytes wrapped in a `PskNonce` when converting to `PskSecretInput`.
#[serde(with = "hex::serde")]
nonce: Vec<u8>,
}
impl From<PskInfo> for PskSecretInput {
fn from(info: PskInfo) -> Self {
let id = PreSharedKeyID {
key_id: JustPreSharedKeyID::External(ExternalPskId::new(info.id)),
psk_nonce: PskNonce(info.nonce),
};
PskSecretInput {
id,
psk: info.psk.into(),
}
}
}
/// One PSK-secret test vector: a cipher suite, the PSKs to combine, and the expected result.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
struct TestScenario {
/// Numeric cipher suite value, passed to `try_test_cipher_suite_provider` by the test.
cipher_suite: u16,
/// PSKs that are fed, in order, to `PskSecret::calculate`.
psks: Vec<PskInfo>,
/// Expected PSK secret, hex-encoded in the serialized form.
#[serde(with = "hex::serde")]
psk_secret: Vec<u8>,
}
impl TestScenario {
    /// Builds `n` random PSK descriptions for the given cipher suite provider.
    ///
    /// Each entry gets an id from `make_external_psk_id`, `cs.kdf_extract_size()` random
    /// key bytes, and a nonce from `make_nonce`.
    #[cfg_attr(coverage_nightly, coverage(off))]
    #[cfg(not(mls_build_async))]
    fn make_psk_list<CS: CipherSuiteProvider>(cs: &CS, n: usize) -> Vec<PskInfo> {
        iter::repeat_with(
            #[cfg_attr(coverage_nightly, coverage(off))]
            || PskInfo {
                id: make_external_psk_id(cs).to_vec(),
                psk: cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(),
                nonce: crate::psk::test_utils::make_nonce(cs.cipher_suite()).0,
            },
        )
        .take(n)
        .collect::<Vec<_>>()
    }

    /// Generates fresh scenarios: for every cipher suite, one scenario for each PSK count
    /// from 1 through 10.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate() -> Vec<TestScenario> {
        CipherSuite::all()
            .flat_map(
                #[cfg_attr(coverage_nightly, coverage(off))]
                |cs| (1..=10).map(move |n| (cs, n)),
            )
            .map(
                #[cfg_attr(coverage_nightly, coverage(off))]
                |(cs, n)| {
                    let provider = test_cipher_suite_provider(cs);
                    let psks = Self::make_psk_list(&provider, n);
                    // `compute_psk_secret` consumes its list, so it gets a copy; the
                    // original is then moved into the scenario without a second copy.
                    let psk_secret = Self::compute_psk_secret(&provider, psks.clone());

                    TestScenario {
                        cipher_suite: cs.into(),
                        psks,
                        psk_secret: psk_secret.to_vec(),
                    }
                },
            )
            .collect()
    }

    /// Async builds cannot generate scenarios; this variant always panics.
    #[cfg(mls_build_async)]
    fn generate() -> Vec<TestScenario> {
        panic!("Tests cannot be generated in async mode");
    }

    /// Converts `psks` into `PskSecretInput`s and runs `PskSecret::calculate` on them.
    ///
    /// # Panics
    ///
    /// Panics if `PskSecret::calculate` returns an error.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn compute_psk_secret<P: CipherSuiteProvider>(
        provider: &P,
        psks: Vec<PskInfo>,
    ) -> PskSecret {
        let input = psks
            .into_iter()
            .map(PskSecretInput::from)
            .collect::<Vec<_>>();

        PskSecret::calculate(&input, provider).await.unwrap()
    }
}
/// Checks that `PskSecret::calculate` reproduces the expected secret of every scenario
/// loaded through `load_test_case_json!`.
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn expected_psk_secret_is_produced() {
    let scenarios: Vec<TestScenario> =
        load_test_case_json!(psk_secret, TestScenario::generate());

    for scenario in scenarios {
        // Scenarios whose cipher suite has no test provider are skipped.
        if let Some(provider) = try_test_cipher_suite_provider(scenario.cipher_suite) {
            // `scenario` is owned by this iteration, so its PSK list can be moved rather
            // than cloned.
            let computed = TestScenario::compute_psk_secret(&provider, scenario.psks).await;

            assert_eq!(scenario.psk_secret, computed.to_vec());
        }
    }
}
}