Revision control

Copy as Markdown

Other Tools

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use std::str::FromStr;
use darling::{
ast::{self, Fields},
FromDeriveInput, FromField, FromVariant,
};
use proc_macro2::{Literal, TokenStream};
use quote::quote;
use syn::{
parse_macro_input, parse_quote, Attribute, DeriveInput, Expr, Generics, Ident, Index, Lit, Path,
};
/// The codec operation a derive expands to. It selects the trait path, the
/// method name and the extra call arguments used in the generated code.
enum Operation {
/// Generates the body of `mls_encoded_len`.
Size,
/// Generates the body of `mls_encode`.
Encode,
/// Generates the body of `mls_decode`.
Decode,
}
impl Operation {
fn path(&self) -> Path {
match self {
Operation::Size => parse_quote! { mls_rs_codec::MlsSize },
Operation::Encode => parse_quote! { mls_rs_codec::MlsEncode },
Operation::Decode => parse_quote! { mls_rs_codec::MlsDecode },
}
}
fn call(&self) -> TokenStream {
match self {
Operation::Size => quote! { mls_encoded_len },
Operation::Encode => quote! { mls_encode },
Operation::Decode => quote! { mls_decode },
}
}
fn extras(&self) -> TokenStream {
match self {
Operation::Size => quote! {},
Operation::Encode => quote! { , writer },
Operation::Decode => quote! { reader },
}
}
fn is_result(&self) -> bool {
match self {
Operation::Size => false,
Operation::Encode => true,
Operation::Decode => true,
}
}
}
/// Per-field data gathered by darling, including options given through
/// `#[mls_codec(...)]`.
#[derive(Debug, FromField)]
#[darling(attributes(mls_codec))]
struct MlsFieldReceiver {
/// Field name; when `None`, generated code addresses the field by its
/// positional index instead.
ident: Option<Ident>,
/// Optional path that replaces the default codec trait path in the call
/// generated for this field.
with: Option<Path>,
}
impl MlsFieldReceiver {
    /// Tokens borrowing this field from `self`: `&self.name` for a named
    /// field, `&self.N` (using `index`) otherwise.
    pub fn call_tokens(&self, index: Index) -> TokenStream {
        match &self.ident {
            Some(field_name) => quote!(&self.#field_name),
            None => quote!(&self.#index),
        }
    }

    /// Tokens labelling this field inside a struct expression: `name:` for a
    /// named field, `N:` (using `index`) otherwise.
    pub fn name(&self, index: Index) -> TokenStream {
        match &self.ident {
            Some(field_name) => quote!(#field_name:),
            None => quote!(#index:),
        }
    }
}
/// Per-variant data gathered by darling for enums deriving the codec traits.
#[derive(Debug, FromVariant)]
#[darling(attributes(mls_codec))]
struct MlsVariantReceiver {
/// Variant name, used in generated match arms and constructors.
ident: Ident,
/// Explicit discriminant expression; `discriminant_for_variant` panics when
/// this is `None`.
discriminant: Option<Expr>,
/// Fields carried by the variant.
fields: ast::Fields<MlsFieldReceiver>,
}
/// Top-level data gathered by darling from the item a derive is applied to.
#[derive(FromDeriveInput)]
#[darling(attributes(mls_codec), forward_attrs(repr))]
struct MlsInputReceiver {
/// Forwarded attributes; only `repr` is forwarded (see `forward_attrs` above).
attrs: Vec<Attribute>,
/// Name of the type the trait is implemented for.
ident: Ident,
/// Generics of the type, split into impl/type/where parts by `derive_impl`.
generics: Generics,
/// Struct fields or enum variants of the input.
data: ast::Data<MlsVariantReceiver, MlsFieldReceiver>,
}
impl MlsInputReceiver {
    /// Builds the method body for `operation`, dispatching on whether the
    /// input is a struct or an enum.
    fn handle_input(&self, operation: Operation) -> TokenStream {
        match &self.data {
            ast::Data::Enum(variants) => enum_impl(&self.ident, &self.attrs, variants, operation),
            ast::Data::Struct(fields) => struct_impl(fields, operation),
        }
    }
}
fn repr_ident(attrs: &[Attribute]) -> Option<Ident> {
let repr_path = attrs
.iter()
.filter(|attr| matches!(attr.style, syn::AttrStyle::Outer))
.find(|attr| attr.path().is_ident("repr"))
.map(|repr| repr.parse_args())
.transpose()
.ok()
.flatten();
let Some(Expr::Path(path)) = repr_path else {
return None;
};
path.path
.segments
.iter()
.find(|s| s.ident != "C")
.map(|path| path.ident.clone())
}
/// Produces the tokens for a variant's discriminant.
///
/// An integer literal written without a type suffix gets the `repr_ident`
/// appended as its suffix (e.g. `1` with `repr(u8)` becomes `1u8`). Any other
/// expression, including an already-suffixed literal, is emitted unchanged.
///
/// # Panics
///
/// Panics if the variant has no explicit discriminant, or if the discriminant
/// is an unsuffixed integer literal and `repr_ident` is `None`.
fn discriminant_for_variant(
    variant: &MlsVariantReceiver,
    repr_ident: &Option<Ident>,
) -> TokenStream {
    let discriminant = variant
        .discriminant
        .as_ref()
        .expect("Enum discriminants must be explicitly defined");

    // Only unsuffixed integer literals need rewriting.
    let unsuffixed = match discriminant {
        Expr::Lit(literal_expr) => match &literal_expr.lit {
            Lit::Int(int) if int.suffix().is_empty() => Some(int),
            _ => None,
        },
        _ => None,
    };
    let Some(int) = unsuffixed else {
        return quote! { #discriminant };
    };

    // The suffixed literal is built textually and re-parsed.
    // TODO: Add proper support for correctly ignoring transparent, packed and modifiers
    let suffix = repr_ident.as_ref().expect("Expected a repr(u*) to be provided or for the variant's discriminant to be defined with suffixed literals.");
    let suffixed = format!("{}{}", int.base10_digits(), suffix);

    match Literal::from_str(&suffixed) {
        Ok(literal) => quote! { #literal },
        // Fall back to the discriminant as written if the text is not a valid literal.
        Err(_) => quote! { #discriminant },
    }
}
/// Builds the method body of `operation` for an enum.
///
/// * `ident` - name of the enum.
/// * `attrs` - forwarded attributes of the enum, searched for `repr`.
/// * `variants` - the enum's variants; each needs an explicit discriminant.
/// * `operation` - which trait method body to generate.
///
/// Decoding reads the discriminant first and matches on it to construct the
/// variant. Size and Encode match on `self`, handle the discriminant, and then
/// handle the variant's single field if it has one.
///
/// # Panics
///
/// Panics if any variant has more than one field, for every operation.
/// Previously only Decode rejected such variants; Size and Encode emitted a
/// one-binding pattern that failed later inside the generated code. Also
/// panics under the conditions documented on `discriminant_for_variant`.
fn enum_impl(
    ident: &Ident,
    attrs: &[Attribute],
    variants: &[MlsVariantReceiver],
    operation: Operation,
) -> TokenStream {
    // TODO: Support more than 1 field
    assert!(
        variants.iter().all(|variant| variant.fields.len() <= 1),
        "Enum discriminants with more than 1 field are not currently supported"
    );

    let handle_error = operation.is_result().then_some(quote! { ? });
    let path = operation.path();
    let call = operation.call();
    let extras = operation.extras();
    let enum_name = ident;
    let repr = repr_ident(attrs);

    if matches!(operation, Operation::Decode) {
        let cases = variants.iter().map(|variant| {
            let variant_name = &variant.ident;
            let discriminant = discriminant_for_variant(variant, &repr);

            // The arity check above guarantees at most one field.
            match variant.fields.fields.first() {
                None => quote! { #discriminant => Ok(#enum_name::#variant_name), },
                Some(field) => {
                    // A `with` override replaces the default trait path for the payload.
                    let field_path = field.with.as_ref().unwrap_or(&path);
                    quote! {
                        #discriminant => Ok(#enum_name::#variant_name(
                            #field_path::#call(#extras) #handle_error
                        )),
                    }
                }
            }
        });

        return quote! {
            let discriminant = #path::#call(#extras)#handle_error;
            match discriminant {
                #(#cases)*
                _ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant),
            }
        };
    }

    let cases = variants.iter().map(|variant| {
        let variant_name = &variant.ident;
        let discriminant = discriminant_for_variant(variant, &repr);
        // The discriminant always uses the default trait path, never `with`.
        let discrim = quote! { #path::#call (&#discriminant #extras) #handle_error };

        match variant.fields.fields.first() {
            None => quote! { #enum_name::#variant_name => { #discrim } },
            Some(field) => {
                let field_path = field.with.as_ref().unwrap_or(&path);
                // Sizes are summed; encode calls are sequenced as statements.
                let joiner = match operation {
                    Operation::Size => quote! { + },
                    Operation::Encode => quote! { ; },
                    Operation::Decode => unreachable!("decoding returns early above"),
                };
                quote! {
                    #enum_name::#variant_name(ref val) => {
                        #discrim #joiner #field_path::#call (val #extras) #handle_error
                    }
                }
            }
        }
    });

    let enum_impl = quote! {
        match self {
            #(#cases)*
        }
    };

    // Fallible operations wrap the match result in `Ok`.
    if operation.is_result() {
        quote! {
            Ok(#enum_impl)
        }
    } else {
        enum_impl
    }
}
/// Builds the method body of `operation` for a struct.
///
/// Every field contributes one call to the operation's trait method, or to
/// the path given by the field's `with` option. The calls are summed for
/// Size, sequenced and followed by `Ok(())` for Encode, and placed in a
/// `Self { .. }` expression for Decode.
fn struct_impl(s: &Fields<MlsFieldReceiver>, operation: Operation) -> TokenStream {
    // These tokens are the same for every field, so they are built once.
    let default_path = operation.path();
    let call = operation.call();
    let extras = operation.extras();
    let handle_error = operation.is_result().then_some(quote! { ? });

    let per_field = s.fields.iter().enumerate().map(|(position, field)| {
        let index = Index::from(position);
        let path = field.with.as_ref().unwrap_or(&default_path);

        match operation {
            // Size and Encode read the field from `self`.
            Operation::Size | Operation::Encode => {
                let receiver = field.call_tokens(index);
                quote! { #path::#call (#receiver #extras) #handle_error }
            }
            // Decode produces `label: <decoded value>` struct-expression entries.
            Operation::Decode => {
                let label = field.name(index);
                quote! { #label #path::#call (#extras) #handle_error }
            }
        }
    });

    match operation {
        Operation::Size => quote! { 0 #(+ #per_field)* },
        Operation::Encode => quote! { #(#per_field;)* Ok(()) },
        Operation::Decode => quote! { Ok(Self { #(#per_field,)* }) },
    }
}
/// Shared driver for the three derives.
///
/// * `input` - the token stream handed to the derive macro.
/// * `trait_name` - path of the trait being implemented.
/// * `function_def` - signature of the trait method, without its body.
/// * `internals` - produces the method body from the parsed input.
///
/// Returns an `impl` of `trait_name` for the input type. If the input cannot
/// be parsed, or darling rejects its attributes, the returned tokens are the
/// corresponding compile errors rather than a panic, so the user sees
/// diagnostics attached to the offending source.
fn derive_impl<F>(
    input: proc_macro::TokenStream,
    trait_name: TokenStream,
    function_def: TokenStream,
    internals: F,
) -> proc_macro::TokenStream
where
    F: FnOnce(&MlsInputReceiver) -> TokenStream,
{
    let input = parse_macro_input!(input as DeriveInput);

    // Surface darling errors as compile errors instead of unwrapping.
    let input = match MlsInputReceiver::from_derive_input(&input) {
        Ok(receiver) => receiver,
        Err(error) => return error.write_errors().into(),
    };

    let name = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // Generate the body of the trait method for this input.
    let function_impl = internals(&input);

    let expanded = quote! {
        // The generated impl.
        impl #impl_generics #trait_name for #name #ty_generics #where_clause {
            #function_def {
                #function_impl
            }
        }
    };

    // Hand the output tokens back to the compiler.
    proc_macro::TokenStream::from(expanded)
}
/// Derives `mls_rs_codec::MlsSize`, generating `mls_encoded_len`.
#[proc_macro_derive(MlsSize, attributes(mls_codec))]
pub fn derive_size(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    derive_impl(
        input,
        quote!(mls_rs_codec::MlsSize),
        quote!(fn mls_encoded_len(&self) -> usize),
        |receiver| receiver.handle_input(Operation::Size),
    )
}
/// Derives `mls_rs_codec::MlsEncode`, generating `mls_encode`.
#[proc_macro_derive(MlsEncode, attributes(mls_codec))]
pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    derive_impl(
        input,
        quote!(mls_rs_codec::MlsEncode),
        quote!(
            fn mls_encode(&self, writer: &mut mls_rs_codec::Vec<u8>) -> Result<(), mls_rs_codec::Error>
        ),
        |receiver| receiver.handle_input(Operation::Encode),
    )
}
#[proc_macro_derive(MlsDecode, attributes(mls_codec))]
pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let trait_name = quote! { mls_rs_codec::MlsDecode };
let function_def =
quote! { fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> };
derive_impl(input, trait_name, function_def, |input| {
input.handle_input(Operation::Decode)
})
}