Revision control
Copy as Markdown
Other Tools
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use crate::cipher_suite::CipherSuite;
use crate::client::MlsError;
use crate::crypto::HpkePublicKey;
use crate::hash_reference::HashReference;
use crate::identity::SigningIdentity;
use crate::protocol_version::ProtocolVersion;
use crate::signer::Signable;
use crate::tree_kem::leaf_node::{LeafNode, LeafNodeSource};
use crate::CipherSuiteProvider;
use alloc::vec::Vec;
use core::{
fmt::{self, Debug},
ops::Deref,
};
use mls_rs_codec::MlsDecode;
use mls_rs_codec::MlsEncode;
use mls_rs_codec::MlsSize;
use mls_rs_core::extension::ExtensionList;
mod validator;
pub(crate) use validator::*;
pub(crate) mod generator;
pub(crate) use generator::*;
/// A signed MLS key package: protocol version, cipher suite, HPKE init key,
/// leaf node and extensions, together with a signature over those fields.
///
/// The derived `MlsEncode`/`MlsDecode` impls follow field declaration order,
/// so the order of the fields below is part of the encoded form — do not
/// reorder them.
#[non_exhaustive]
#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
// #[cfg_attr(
// all(feature = "ffi", not(test)),
// safer_ffi_gen::ffi_type(clone, opaque)
// )]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KeyPackage {
    /// Protocol version this key package was created for.
    pub version: ProtocolVersion,
    /// Cipher suite of this key package; `to_reference` refuses to hash with a
    /// provider for any other suite.
    pub cipher_suite: CipherSuite,
    /// HPKE public init key.
    pub hpke_init_key: HpkePublicKey,
    /// Leaf node carried by this key package. Crate-private; its signing
    /// identity is exposed through `signing_identity()`.
    pub(crate) leaf_node: LeafNode,
    /// Extensions attached at the key-package level.
    pub extensions: ExtensionList,
    /// Signature over all preceding fields (the `KeyPackageTBS` content built
    /// by `Signable::signable_content`). Encoded with `mls_rs_codec::byte_vec`
    /// and, with the `serde` feature, serialized via `mls_rs_core::vec_serde`.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    pub signature: Vec<u8>,
}
impl Debug for KeyPackage {
    /// Formats every field of the key package in declaration order.
    ///
    /// The signature is rendered through `mls_rs_core::debug::pretty_bytes`
    /// rather than as a plain list of bytes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("KeyPackage");

        out.field("version", &self.version);
        out.field("cipher_suite", &self.cipher_suite);
        out.field("hpke_init_key", &self.hpke_init_key);
        out.field("leaf_node", &self.leaf_node);
        out.field("extensions", &self.extensions);
        out.field(
            "signature",
            &mls_rs_core::debug::pretty_bytes(&self.signature),
        );

        out.finish()
    }
}
/// Reference identifying a [`KeyPackage`], as returned by
/// [`KeyPackage::to_reference`].
///
/// Thin wrapper around a [`HashReference`]; it dereferences to the raw
/// reference bytes.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
// #[cfg_attr(
// all(feature = "ffi", not(test)),
// safer_ffi_gen::ffi_type(clone, opaque)
// )]
pub struct KeyPackageRef(HashReference);
impl Deref for KeyPackageRef {
type Target = [u8];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Vec<u8>> for KeyPackageRef {
fn from(v: Vec<u8>) -> Self {
Self(HashReference::from(v))
}
}
/// Borrowed to-be-signed view of a [`KeyPackage`]: every field except the
/// signature, in the same order as in `KeyPackage`.
///
/// `KeyPackage::signable_content` encodes this struct to produce the bytes
/// covered by the key package signature, so the field order here must stay
/// in sync with `KeyPackage`.
#[derive(MlsSize, MlsEncode)]
struct KeyPackageData<'a> {
    pub version: ProtocolVersion,
    pub cipher_suite: CipherSuite,
    // NOTE(review): encoded via `byte_vec` here, whereas `KeyPackage` encodes
    // `hpke_init_key` with `HpkePublicKey`'s own codec; presumably both yield
    // identical bytes — confirm, as a mismatch would break signature checks.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub hpke_init_key: &'a HpkePublicKey,
    pub leaf_node: &'a LeafNode,
    pub extensions: &'a ExtensionList,
}
// #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl KeyPackage {
    /// Protocol version of this key package (accessor for the FFI layer).
    #[cfg(feature = "ffi")]
    pub fn version(&self) -> ProtocolVersion {
        self.version
    }

    /// Cipher suite of this key package (accessor for the FFI layer).
    #[cfg(feature = "ffi")]
    pub fn cipher_suite(&self) -> CipherSuite {
        self.cipher_suite
    }

    /// Signing identity stored in this key package's leaf node.
    pub fn signing_identity(&self) -> &SigningIdentity {
        let leaf = &self.leaf_node;
        &leaf.signing_identity
    }

    /// Computes the [`KeyPackageRef`] for this key package.
    ///
    /// The complete MLS encoding of the key package, including its signature,
    /// is hashed with the label `"MLS 1.0 KeyPackage Reference"` using
    /// `cipher_suite_provider`.
    ///
    /// # Errors
    ///
    /// * [`MlsError::CipherSuiteMismatch`] if the provider's suite differs from
    ///   `self.cipher_suite`.
    /// * Any encoding or hashing error, converted into [`MlsError`].
    // #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn to_reference<CP: CipherSuiteProvider>(
        &self,
        cipher_suite_provider: &CP,
    ) -> Result<KeyPackageRef, MlsError> {
        // Only hash with a provider for this key package's own suite.
        if cipher_suite_provider.cipher_suite() == self.cipher_suite {
            let encoded = self.mls_encode_to_vec()?;

            let digest = HashReference::compute(
                &encoded,
                b"MLS 1.0 KeyPackage Reference",
                cipher_suite_provider,
            )
            .await?;

            Ok(KeyPackageRef(digest))
        } else {
            Err(MlsError::CipherSuiteMismatch)
        }
    }

    /// Returns the `not_after` bound of the leaf node's key-package lifetime.
    ///
    /// NOTE(review): the unit of the returned `u64` is defined by `Lifetime`;
    /// presumably seconds since the Unix epoch — confirm.
    ///
    /// # Errors
    ///
    /// [`MlsError::InvalidLeafNodeSource`] if the leaf node's source is not
    /// `LeafNodeSource::KeyPackage`.
    pub fn expiration(&self) -> Result<u64, MlsError> {
        match &self.leaf_node.leaf_node_source {
            LeafNodeSource::KeyPackage(lifetime) => Ok(lifetime.not_after),
            _ => Err(MlsError::InvalidLeafNodeSource),
        }
    }
}
impl<'a> Signable<'a> for KeyPackage {
    /// Label used for key package signatures.
    const SIGN_LABEL: &'static str = "KeyPackageTBS";

    /// Key package signatures need no extra context.
    type SigningContext = ();

    /// The stored signature bytes.
    fn signature(&self) -> &[u8] {
        self.signature.as_slice()
    }

    /// Encodes every field except the signature, in `KeyPackage` field order,
    /// via the borrowed `KeyPackageData` view.
    fn signable_content(
        &self,
        _context: &Self::SigningContext,
    ) -> Result<Vec<u8>, mls_rs_codec::Error> {
        let to_be_signed = KeyPackageData {
            version: self.version,
            cipher_suite: self.cipher_suite,
            hpke_init_key: &self.hpke_init_key,
            leaf_node: &self.leaf_node,
            extensions: &self.extensions,
        };

        to_be_signed.mls_encode_to_vec()
    }

    /// Replaces the stored signature with `signature`.
    fn write_signature(&mut self, signature: Vec<u8>) {
        self.signature = signature;
    }
}
#[cfg(test)]
pub(crate) mod test_utils {
    use super::*;
    use crate::{
        crypto::test_utils::test_cipher_suite_provider,
        group::framing::MlsMessagePayload,
        identity::{basic::BasicIdentityProvider, test_utils::get_test_signing_identity},
        tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
        MlsMessage,
    };
    use mls_rs_core::crypto::SignatureSecretKey;

    /// Builds a signed test key package for identity `id` and discards the
    /// signing key.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> KeyPackage {
        let (key_package, _signing_key) =
            test_key_package_with_signer(protocol_version, cipher_suite, id).await;

        key_package
    }

    /// Builds a signed test key package for identity `id` and returns it with
    /// the secret key that signed it.
    ///
    /// The package uses a one-year lifetime, the test capabilities, and empty
    /// extension lists for both arguments passed to `generate`. Panics if
    /// the lifetime or generation step fails.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_with_signer(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> (KeyPackage, SignatureSecretKey) {
        let (signing_identity, secret_key) =
            get_test_signing_identity(cipher_suite, id.as_bytes()).await;

        let provider = test_cipher_suite_provider(cipher_suite);

        let generator = KeyPackageGenerator {
            protocol_version,
            cipher_suite_provider: &provider,
            signing_identity: &signing_identity,
            signing_key: &secret_key,
            identity_provider: &BasicIdentityProvider,
        };

        let generated_package = generator
            .generate(
                Lifetime::years(1).unwrap(),
                get_test_capabilities(),
                ExtensionList::default(),
                ExtensionList::default(),
            )
            .await
            .unwrap()
            .key_package;

        (generated_package, secret_key)
    }

    /// Wraps a fresh test key package for `id` in an `MlsMessage`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_message(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> MlsMessage {
        let key_package = test_key_package(protocol_version, cipher_suite, id).await;
        let payload = MlsMessagePayload::KeyPackage(key_package);

        MlsMessage::new(protocol_version, payload)
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
    };

    use super::{test_utils::test_key_package, *};
    use alloc::format;
    use assert_matches::assert_matches;

    /// One `KeyPackageRef` test vector.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        /// Raw cipher suite identifier.
        cipher_suite: u16,
        /// MLS encoding of the key package (hex-encoded when serialized).
        #[serde(with = "hex::serde")]
        input: Vec<u8>,
        /// Expected `KeyPackageRef` bytes for `input` (hex-encoded when serialized).
        #[serde(with = "hex::serde")]
        output: Vec<u8>,
    }

    impl TestCase {
        /// Generates one case for every (protocol version, cipher suite) pair.
        ///
        /// Each key package uses the identity `alice{i}`, where `i` is the
        /// enumeration index, so changing the iteration order changes the
        /// generated vectors.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        #[cfg_attr(coverage_nightly, coverage(off))]
        async fn generate() -> Vec<TestCase> {
            let mut test_cases = Vec::new();

            for (i, (protocol_version, cipher_suite)) in ProtocolVersion::all()
                .flat_map(|p| CipherSuite::all().map(move |cs| (p, cs)))
                .enumerate()
            {
                let pkg =
                    test_key_package(protocol_version, cipher_suite, &format!("alice{i}")).await;

                let pkg_ref = pkg
                    .to_reference(&test_cipher_suite_provider(cipher_suite))
                    .await
                    .unwrap();

                let case = TestCase {
                    cipher_suite: cipher_suite.into(),
                    input: pkg.mls_encode_to_vec().unwrap(),
                    output: pkg_ref.to_vec(),
                };

                test_cases.push(case);
            }

            test_cases
        }
    }

    // NOTE(review): `load_test_case_json!` is defined elsewhere; presumably it
    // loads the stored `key_package_ref` vectors and falls back to the given
    // generator expression when none exist — confirm against the macro.
    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate().await)
    }

    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate())
    }

    /// Decodes each vector's key package and checks that `to_reference`
    /// reproduces the expected reference. Suites with no available test
    /// provider are skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_key_package_ref() {
        let cases = load_test_cases().await;

        for one_case in cases {
            let Some(provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else {
                continue;
            };

            let key_package = KeyPackage::mls_decode(&mut one_case.input.as_slice()).unwrap();

            let key_package_ref = key_package.to_reference(&provider).await.unwrap();

            let expected_out = KeyPackageRef::from(one_case.output);
            assert_eq!(expected_out, key_package_ref);
        }
    }

    /// `to_reference` must return `CipherSuiteMismatch` for a provider of any
    /// other available cipher suite.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn key_package_ref_fails_invalid_cipher_suite() {
        let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await;

        for another_cipher_suite in CipherSuite::all().filter(|cs| cs != &TEST_CIPHER_SUITE) {
            if let Some(cs) = try_test_cipher_suite_provider(*another_cipher_suite) {
                let res = key_package.to_reference(&cs).await;

                assert_matches!(res, Err(MlsError::CipherSuiteMismatch));
            }
        }
    }
}