Source code
Revision control
Copy as Markdown
Other Tools
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use core::{
fmt::{self, Debug},
ops::Deref,
};
use crate::error::{AnyError, IntoAnyError};
use alloc::vec::Vec;
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
mod list;
pub use list::*;
/// Wrapper type representing an extension identifier along with default values
/// defined by the MLS RFC.
#[derive(
Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode,
)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
// #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(transparent)]
pub struct ExtensionType(u16);
// #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl ExtensionType {
pub const APPLICATION_ID: ExtensionType = ExtensionType(1);
pub const RATCHET_TREE: ExtensionType = ExtensionType(2);
pub const REQUIRED_CAPABILITIES: ExtensionType = ExtensionType(3);
pub const EXTERNAL_PUB: ExtensionType = ExtensionType(4);
pub const EXTERNAL_SENDERS: ExtensionType = ExtensionType(5);
/// Default extension types defined
pub const DEFAULT: &'static [ExtensionType] = &[
ExtensionType::APPLICATION_ID,
ExtensionType::RATCHET_TREE,
ExtensionType::REQUIRED_CAPABILITIES,
ExtensionType::EXTERNAL_PUB,
ExtensionType::EXTERNAL_SENDERS,
];
/// Extension type from a raw value
pub const fn new(raw_value: u16) -> Self {
ExtensionType(raw_value)
}
/// Raw numerical wrapped value.
pub const fn raw_value(&self) -> u16 {
self.0
}
/// Determines if this extension type is required to be implemented
/// by the MLS RFC.
pub const fn is_default(&self) -> bool {
self.0 <= 5
}
}
impl From<u16> for ExtensionType {
fn from(value: u16) -> Self {
ExtensionType(value)
}
}
impl Deref for ExtensionType {
type Target = u16;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
pub enum ExtensionError {
#[cfg_attr(feature = "std", error(transparent))]
SerializationError(AnyError),
#[cfg_attr(feature = "std", error(transparent))]
DeserializationError(AnyError),
#[cfg_attr(feature = "std", error("incorrect extension type: {0:?}"))]
IncorrectType(ExtensionType),
}
impl IntoAnyError for ExtensionError {
#[cfg(feature = "std")]
fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
Ok(self.into())
}
}
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
// #[cfg_attr(
// all(feature = "ffi", not(test)),
// safer_ffi_gen::ffi_type(clone, opaque)
// )]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
/// An MLS protocol [extension](https://messaginglayersecurity.rocks/mls-protocol/draft-ietf-mls-protocol.html#name-extensions).
///
/// Extensions are used as customization points in various parts of the
/// MLS protocol and are inserted into an [ExtensionList](self::ExtensionList).
pub struct Extension {
/// Extension type of this extension
pub extension_type: ExtensionType,
/// Data held within this extension
#[mls_codec(with = "mls_rs_codec::byte_vec")]
#[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
pub extension_data: Vec<u8>,
}
impl Debug for Extension {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Extension")
.field("extension_type", &self.extension_type)
.field(
"extension_data",
&crate::debug::pretty_bytes(&self.extension_data),
)
.finish()
}
}
// #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl Extension {
/// Create an extension with specified type and data properties.
pub fn new(extension_type: ExtensionType, extension_data: Vec<u8>) -> Extension {
Extension {
extension_type,
extension_data,
}
}
/// Extension type of this extension
#[cfg(feature = "ffi")]
pub fn extension_type(&self) -> ExtensionType {
self.extension_type
}
/// Data held within this extension
#[cfg(feature = "ffi")]
pub fn extension_data(&self) -> &[u8] {
&self.extension_data
}
}
/// Trait used to convert a type to and from an [Extension]
pub trait MlsExtension: Sized {
/// Error type of the underlying serializer that can convert this type into a `Vec<u8>`.
type SerializationError: IntoAnyError;
/// Error type of the underlying deserializer that can convert a `Vec<u8>` into this type.
type DeserializationError: IntoAnyError;
/// Extension type value that this type represents.
fn extension_type() -> ExtensionType;
/// Convert this type to opaque bytes.
fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>;
/// Create this type from opaque bytes.
fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError>;
/// Convert this type into an [Extension].
fn into_extension(self) -> Result<Extension, ExtensionError> {
Ok(Extension::new(
Self::extension_type(),
self.to_bytes()
.map_err(|e| ExtensionError::SerializationError(e.into_any_error()))?,
))
}
/// Create this type from an [Extension].
fn from_extension(ext: &Extension) -> Result<Self, ExtensionError> {
if ext.extension_type != Self::extension_type() {
return Err(ExtensionError::IncorrectType(ext.extension_type));
}
Self::from_bytes(&ext.extension_data)
.map_err(|e| ExtensionError::DeserializationError(e.into_any_error()))
}
}
/// Convenience trait for custom extension types that use
/// [mls_rs_codec] as an underlying serialization mechanism
pub trait MlsCodecExtension: MlsSize + MlsEncode + MlsDecode {
fn extension_type() -> ExtensionType;
}
impl<T> MlsExtension for T
where
T: MlsCodecExtension,
{
type SerializationError = mls_rs_codec::Error;
type DeserializationError = mls_rs_codec::Error;
fn extension_type() -> ExtensionType {
<Self as MlsCodecExtension>::extension_type()
}
fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
self.mls_encode_to_vec()
}
fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError> {
Self::mls_decode(&mut &*data)
}
}
#[cfg(test)]
mod tests {
use core::convert::Infallible;
use alloc::vec;
use alloc::vec::Vec;
use assert_matches::assert_matches;
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use super::{Extension, ExtensionError, ExtensionType, MlsCodecExtension, MlsExtension};
struct TestExtension;
#[derive(Debug, MlsSize, MlsEncode, MlsDecode)]
struct AnotherTestExtension;
impl MlsExtension for TestExtension {
type SerializationError = Infallible;
type DeserializationError = Infallible;
fn extension_type() -> super::ExtensionType {
ExtensionType(42)
}
fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
Ok(vec![0])
}
fn from_bytes(_data: &[u8]) -> Result<Self, Self::DeserializationError> {
Ok(TestExtension)
}
}
impl MlsCodecExtension for AnotherTestExtension {
fn extension_type() -> ExtensionType {
ExtensionType(43)
}
}
#[test]
fn into_extension() {
assert_eq!(
TestExtension.into_extension().unwrap(),
Extension::new(42.into(), vec![0])
)
}
#[test]
fn incorrect_type_is_discovered() {
let ext = Extension::new(42.into(), vec![0]);
assert_matches!(AnotherTestExtension::from_extension(&ext), Err(ExtensionError::IncorrectType(found)) if found == 42.into());
}
}