Revision control

Copy as Markdown

Other Tools

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use alloc::vec::Vec;
use core::{
fmt::{self, Debug},
ops::Deref,
};
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use mls_rs_core::{crypto::CipherSuiteProvider, error::IntoAnyError};
use crate::{
client::MlsError,
group::{framing::FramedContent, MessageSignature},
WireFormat,
};
use super::{AuthenticatedContent, ConfirmationTag};
/// Confirmed transcript hash, held as raw bytes.
///
/// Produced by [`ConfirmedTranscriptHash::create`], which hashes the previous
/// interim transcript hash together with an authenticated message's wire
/// format, framed content and signature. Can also be built directly from bytes
/// via `From<Vec<u8>>`.
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConfirmedTranscriptHash(
    // Codec and serde go through dedicated byte-vector helpers
    // (`mls_rs_codec::byte_vec` / `mls_rs_core::vec_serde`) instead of the
    // default `Vec<u8>` handling.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    Vec<u8>,
);
impl Debug for ConfirmedTranscriptHash {
    /// Renders the hash through `mls_rs_core::debug::pretty_bytes`, labelled
    /// with the type name, instead of a raw `Vec<u8>` element listing.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bytes: &[u8] = &self.0;
        let labelled = mls_rs_core::debug::pretty_bytes(bytes).named("ConfirmedTranscriptHash");
        labelled.fmt(f)
    }
}
impl Deref for ConfirmedTranscriptHash {
type Target = Vec<u8>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Vec<u8>> for ConfirmedTranscriptHash {
fn from(value: Vec<u8>) -> Self {
Self(value)
}
}
impl ConfirmedTranscriptHash {
    /// Computes the confirmed transcript hash for `content`.
    ///
    /// The digest is `Hash(interim_transcript_hash || encoded_input)`, where
    /// `encoded_input` is the MLS encoding of `ConfirmedTranscriptHashInput`
    /// (wire format, framed content, signature) taken from `content`.
    ///
    /// # Errors
    ///
    /// Propagates an encoding failure for the hash input, and returns
    /// `MlsError::CryptoProviderError` if the cipher suite's hash fails.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn create<P: CipherSuiteProvider>(
        cipher_suite_provider: &P,
        interim_transcript_hash: &InterimTranscriptHash,
        content: &AuthenticatedContent,
    ) -> Result<Self, MlsError> {
        // Borrowed view of the fields covered by the hash; the field order
        // here determines the encoded byte order.
        #[derive(Debug, MlsSize, MlsEncode)]
        struct ConfirmedTranscriptHashInput<'a> {
            wire_format: WireFormat,
            content: &'a FramedContent,
            signature: &'a MessageSignature,
        }

        let encoded_input = ConfirmedTranscriptHashInput {
            wire_format: content.wire_format,
            content: &content.content,
            signature: &content.auth.signature,
        }
        .mls_encode_to_vec()?;

        // interim_transcript_hash || encoded_input
        let mut hash_input =
            Vec::with_capacity(interim_transcript_hash.len() + encoded_input.len());
        hash_input.extend_from_slice(interim_transcript_hash);
        hash_input.extend_from_slice(&encoded_input);

        let digest = cipher_suite_provider
            .hash(&hash_input)
            .await
            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;

        Ok(digest.into())
    }
}
/// Interim transcript hash, held as raw bytes.
///
/// Produced by [`InterimTranscriptHash::create`] from a confirmed transcript
/// hash and a confirmation tag, and consumed as the prefix input of
/// [`ConfirmedTranscriptHash::create`]. Can also be built directly from bytes
/// via `From<Vec<u8>>`.
// `Eq` is derived for parity with `ConfirmedTranscriptHash`; equality is plain
// byte equality of the wrapped `Vec<u8>`, which is a total equivalence.
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct InterimTranscriptHash(
    // Codec and serde go through dedicated byte-vector helpers
    // (`mls_rs_codec::byte_vec` / `mls_rs_core::vec_serde`) instead of the
    // default `Vec<u8>` handling.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    Vec<u8>,
);
impl Debug for InterimTranscriptHash {
    /// Renders the hash through `mls_rs_core::debug::pretty_bytes`, labelled
    /// with the type name, instead of a raw `Vec<u8>` element listing.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        Debug::fmt(
            &mls_rs_core::debug::pretty_bytes(&self.0).named("InterimTranscriptHash"),
            f,
        )
    }
}
impl Deref for InterimTranscriptHash {
type Target = Vec<u8>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Vec<u8>> for InterimTranscriptHash {
fn from(value: Vec<u8>) -> Self {
Self(value)
}
}
impl InterimTranscriptHash {
    /// Computes the interim transcript hash that follows `confirmed`.
    ///
    /// The digest is `Hash(confirmed || encoded_input)`, where `encoded_input`
    /// is the MLS encoding of `InterimTranscriptHashInput`, i.e. of
    /// `confirmation_tag`.
    ///
    /// # Errors
    ///
    /// Propagates an encoding failure for the hash input, and returns
    /// `MlsError::CryptoProviderError` if the cipher suite's hash fails.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn create<P: CipherSuiteProvider>(
        cipher_suite_provider: &P,
        confirmed: &ConfirmedTranscriptHash,
        confirmation_tag: &ConfirmationTag,
    ) -> Result<Self, MlsError> {
        #[derive(Debug, MlsSize, MlsEncode)]
        struct InterimTranscriptHashInput<'a> {
            confirmation_tag: &'a ConfirmationTag,
        }

        let encoded_tag = InterimTranscriptHashInput { confirmation_tag }.mls_encode_to_vec()?;

        // confirmed || encoded_tag
        let mut hash_input = Vec::with_capacity(confirmed.len() + encoded_tag.len());
        hash_input.extend_from_slice(confirmed);
        hash_input.extend_from_slice(&encoded_tag);

        let digest = cipher_suite_provider
            .hash(&hash_input)
            .await
            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;

        Ok(digest.into())
    }
}
// The test vectors come from the MLS interop repository and contain a proposal by reference,
// so these tests are only built when the `by_ref_proposal` feature is enabled.
#[cfg(feature = "by_ref_proposal")]
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use mls_rs_codec::MlsDecode;

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        group::{framing::ContentType, message_signature::AuthenticatedContent, transcript_hashes},
    };

    #[cfg(not(mls_build_async))]
    use alloc::{boxed::Box, vec};

    #[cfg(not(mls_build_async))]
    use crate::{
        crypto::test_utils::test_cipher_suite_provider,
        group::{
            confirmation_tag::ConfirmationTag,
            framing::Content,
            proposal::{Proposal, ProposalOrRef, RemoveProposal},
            test_utils::get_test_group_context,
            Commit, LeafIndex, Sender,
        },
        mls_rs_codec::MlsEncode,
        CipherSuite, CipherSuiteProvider, WireFormat,
    };

    #[cfg(not(mls_build_async))]
    use super::{ConfirmedTranscriptHash, InterimTranscriptHash};

    /// One transcript-hash test vector; every byte field is hex-encoded in JSON.
    #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
    struct TestCase {
        pub cipher_suite: u16,
        #[serde(with = "hex::serde")]
        pub confirmation_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub authenticated_content: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub interim_transcript_hash_before: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub confirmed_transcript_hash_after: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub interim_transcript_hash_after: Vec<u8>,
    }

    /// Checks, for each vector, that the commit's confirmation tag verifies
    /// against the expected confirmed hash and that `transcript_hashes`
    /// reproduces both expected "after" hashes.
    ///
    /// NOTE(review): `load_test_case_json!` is given `generate_test_vector()` as
    /// a fallback source, presumably used when the stored JSON is absent;
    /// confirm against the macro.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn transcript_hash() {
        let test_cases: Vec<TestCase> =
            load_test_case_json!(interop_transcript_hashes, generate_test_vector());

        for test_case in test_cases {
            // Vectors whose cipher suite yields no provider here are skipped.
            let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
                continue;
            };

            let auth_content =
                AuthenticatedContent::mls_decode(&mut &*test_case.authenticated_content).unwrap();

            assert!(
                auth_content.content.content_type() == ContentType::Commit,
                "transcript hash test vectors must contain a commit"
            );

            let conf_key = &test_case.confirmation_key;
            let conf_hash_after = test_case.confirmed_transcript_hash_after.into();

            // Borrow the tag rather than cloning it; `auth_content` is still
            // needed below.
            let conf_tag = auth_content
                .auth
                .confirmation_tag
                .as_ref()
                .expect("commit test vector must carry a confirmation tag");

            let tag_matches = conf_tag
                .matches(conf_key, &conf_hash_after, &cs)
                .await
                .unwrap();

            assert!(tag_matches);

            // These are the values computed by the code under test; the test
            // vector fields are the expected ones.
            let (computed_interim, computed_conf) = transcript_hashes(
                &cs,
                &test_case.interim_transcript_hash_before.into(),
                &auth_content,
            )
            .await
            .unwrap();

            assert_eq!(*computed_interim, test_case.interim_transcript_hash_after);
            assert_eq!(computed_conf, conf_hash_after);
        }
    }

    /// Builds one test vector per cipher suite in `CipherSuite::all()`.
    ///
    /// Each vector holds a signed commit containing a single remove proposal.
    /// The previous interim hash and the confirmation key are random. The
    /// confirmed hash, confirmation tag and next interim hash are derived from
    /// those inputs.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<TestCase> {
        CipherSuite::all()
            .map(|cs| {
                let cs = test_cipher_suite_provider(cs);
                let context = get_test_group_context(0x3456, cs.cipher_suite());

                let proposal = Proposal::Remove(RemoveProposal {
                    to_remove: LeafIndex(1),
                });

                let proposal = ProposalOrRef::Proposal(Box::new(proposal));

                let commit = Commit {
                    proposals: vec![proposal],
                    path: None,
                };

                let signer = cs.signature_key_generate().unwrap().0;

                let mut auth_content = AuthenticatedContent::new_signed(
                    &cs,
                    &context,
                    Sender::Member(0),
                    Content::Commit(Box::new(commit)),
                    &signer,
                    WireFormat::PublicMessage,
                    vec![],
                )
                .unwrap();

                let interim_hash_before =
                    cs.random_bytes_vec(cs.kdf_extract_size()).unwrap().into();

                let conf_hash_after =
                    ConfirmedTranscriptHash::create(&cs, &interim_hash_before, &auth_content)
                        .unwrap();

                let conf_key = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();
                let conf_tag = ConfirmationTag::create(&conf_key, &conf_hash_after, &cs).unwrap();

                let interim_hash_after =
                    InterimTranscriptHash::create(&cs, &conf_hash_after, &conf_tag).unwrap();

                // The tag is attached only after it has been used to derive
                // the next interim hash.
                auth_content.auth.confirmation_tag = Some(conf_tag);

                TestCase {
                    cipher_suite: cs.cipher_suite().into(),
                    confirmation_key: conf_key,
                    authenticated_content: auth_content.mls_encode_to_vec().unwrap(),
                    interim_transcript_hash_before: interim_hash_before.0,
                    confirmed_transcript_hash_after: conf_hash_after.0,
                    interim_transcript_hash_after: interim_hash_after.0,
                }
            })
            .collect()
    }

    /// Vector generation is only implemented for sync builds.
    ///
    /// # Panics
    ///
    /// Always panics.
    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<TestCase> {
        panic!("Tests cannot be generated in async mode");
    }
}