Revision control
Copy as Markdown
Other Tools
extern crate proc_macro;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::*;
mod container_attributes;
mod field_attributes;
use container_attributes::ContainerAttributes;
use field_attributes::{determine_field_constructor, FieldConstructor};
// Name of the helper attribute accepted by this derive; interpolated into the
// error message reported for an invalid attribute.
static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
// Name of the lifetime parameter created for the generated
// `arbitrary::Arbitrary<'arbitrary>` implementation.
static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
/// Entry point of `#[derive(Arbitrary)]`.
///
/// Parses the annotated item as a `syn::DeriveInput` and expands it. Any
/// `syn::Error` produced during expansion is turned into a `compile_error!`
/// token stream instead of panicking.
#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
    let expanded = match expand_derive_arbitrary(input) {
        Ok(expanded) => expanded,
        Err(error) => error.into_compile_error(),
    };
    proc_macro::TokenStream::from(expanded)
}
fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
let container_attrs = ContainerAttributes::from_derive_input(&input)?;
let (lifetime_without_bounds, lifetime_with_bounds) =
build_arbitrary_lifetime(input.generics.clone());
let recursive_count = syn::Ident::new(
&format!("RECURSIVE_COUNT_{}", input.ident),
Span::call_site(),
);
let arbitrary_method =
gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
let size_hint_method = gen_size_hint_method(&input)?;
let name = input.ident;
// Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
let generics = apply_trait_bounds(
input.generics,
lifetime_without_bounds.clone(),
&container_attrs,
)?;
let mut generics_with_lifetime = generics.clone();
generics_with_lifetime
.params
.push(GenericParam::Lifetime(lifetime_with_bounds));
let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
// Build TypeGenerics and WhereClause without a lifetime
let (_, ty_generics, where_clause) = generics.split_for_impl();
Ok(quote! {
const _: () = {
std::thread_local! {
#[allow(non_upper_case_globals)]
static #recursive_count: std::cell::Cell<u32> = std::cell::Cell::new(0);
}
#[automatically_derived]
impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
#arbitrary_method
#size_hint_method
}
};
})
}
/// Creates the `'arbitrary` lifetime parameter in two flavours.
///
/// Returns `(lifetime without bounds, lifetime with bounds)`, where the bounded
/// one outlives every lifetime parameter declared in `generics`.
///
/// Example: `('arbitrary, 'arbitrary: 'a + 'b)`
fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam) {
    let unbounded = LifetimeParam::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));

    let mut bounded = unbounded.clone();
    bounded
        .bounds
        .extend(generics.lifetimes().map(|param| param.lifetime.clone()));

    (unbounded, bounded)
}
fn apply_trait_bounds(
mut generics: Generics,
lifetime: LifetimeParam,
container_attrs: &ContainerAttributes,
) -> Result<Generics> {
// If user-supplied bounds exist, apply them to their matching type parameters.
if let Some(config_bounds) = &container_attrs.bounds {
let mut config_bounds_applied = 0;
for param in generics.params.iter_mut() {
if let GenericParam::Type(type_param) = param {
if let Some(replacement) = config_bounds
.iter()
.flatten()
.find(|p| p.ident == type_param.ident)
{
*type_param = replacement.clone();
config_bounds_applied += 1;
} else {
// If no user-supplied bounds exist for this type, delete the original bounds.
// This mimics serde.
type_param.bounds = Default::default();
type_param.default = None;
}
}
}
let config_bounds_supplied = config_bounds
.iter()
.map(|bounds| bounds.len())
.sum::<usize>();
if config_bounds_applied != config_bounds_supplied {
return Err(Error::new(
Span::call_site(),
format!(
"invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
),
));
}
Ok(generics)
} else {
// Otherwise, inject a `T: Arbitrary` bound for every parameter.
Ok(add_trait_bounds(generics, lifetime))
}
}
// Add a bound `T: Arbitrary` to every type parameter T.
fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics {
for param in generics.params.iter_mut() {
if let GenericParam::Type(type_param) = param {
type_param
.bounds
.push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
}
}
generics
}
/// Wraps `expr` in generated code that limits recursion once the input is empty.
///
/// The emitted statements expect a binding named `u` to be in scope and use the
/// counter named by `recursive_count` through its `.with(...)` accessor. The
/// emitted code evaluates to the value produced by `expr`.
fn with_recursive_count_guard(
recursive_count: &syn::Ident,
expr: impl quote::ToTokens,
) -> impl quote::ToTokens {
quote! {
// Only guard when there is no input left to consume.
let guard_against_recursion = u.is_empty();
if guard_against_recursion {
#recursive_count.with(|count| {
// A non-zero count means a guarded call is already in progress:
// bail out instead of recursing further.
if count.get() > 0 {
return Err(arbitrary::Error::NotEnoughData);
}
count.set(count.get() + 1);
Ok(())
})?;
}
// Run `expr` inside an immediately-invoked closure so that an early return
// within it (e.g. via `?`) only leaves the closure and the decrement below
// still runs.
let result = (|| { #expr })();
if guard_against_recursion {
#recursive_count.with(|count| {
count.set(count.get() - 1);
});
}
result
}
}
fn gen_arbitrary_method(
input: &DeriveInput,
lifetime: LifetimeParam,
recursive_count: &syn::Ident,
) -> Result<TokenStream> {
fn arbitrary_structlike(
fields: &Fields,
ident: &syn::Ident,
lifetime: LifetimeParam,
recursive_count: &syn::Ident,
) -> Result<TokenStream> {
let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });
let arbitrary_take_rest = construct_take_rest(fields)?;
let take_rest_body =
with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });
Ok(quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#body
}
fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#take_rest_body
}
})
}
let ident = &input.ident;
let output = match &input.data {
Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?,
Data::Union(data) => arbitrary_structlike(
&Fields::Named(data.fields.clone()),
ident,
lifetime,
recursive_count,
)?,
Data::Enum(data) => {
let variants: Vec<TokenStream> = data
.variants
.iter()
.enumerate()
.map(|(i, variant)| {
let idx = i as u64;
let variant_name = &variant.ident;
construct(&variant.fields, |_, field| gen_constructor_for_field(field))
.map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
})
.collect::<Result<_>>()?;
let variants_take_rest: Vec<TokenStream> = data
.variants
.iter()
.enumerate()
.map(|(i, variant)| {
let idx = i as u64;
let variant_name = &variant.ident;
construct_take_rest(&variant.fields)
.map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
})
.collect::<Result<_>>()?;
let count = data.variants.len() as u64;
let arbitrary = with_recursive_count_guard(
recursive_count,
quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 {
#(#variants,)*
_ => unreachable!()
})
},
);
let arbitrary_take_rest = with_recursive_count_guard(
recursive_count,
quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 {
#(#variants_take_rest,)*
_ => unreachable!()
})
},
);
quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary
}
fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary_take_rest
}
}
}
};
Ok(output)
}
/// Builds the constructor tokens that follow a type or variant name.
///
/// `ctor` is called with each field's index and definition and returns the
/// expression that produces the field's value. The result is
/// `{ name: expr, ... }` for named fields, `( expr, ... )` for tuple fields and
/// nothing for unit items. The first error returned by `ctor` is propagated.
fn construct(
    fields: &Fields,
    ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
) -> Result<TokenStream> {
    match fields {
        Fields::Unit => Ok(quote!()),
        Fields::Named(named_fields) => {
            let mut initializers: Vec<TokenStream> =
                Vec::with_capacity(named_fields.named.len());
            for (index, field) in named_fields.named.iter().enumerate() {
                // Fields of a named field list always carry an identifier.
                let name = field.ident.as_ref().unwrap();
                let value = ctor(index, field)?;
                initializers.push(quote! { #name: #value });
            }
            Ok(quote! { { #(#initializers,)* } })
        }
        Fields::Unnamed(tuple_fields) => {
            let mut values: Vec<TokenStream> = Vec::with_capacity(tuple_fields.unnamed.len());
            for (index, field) in tuple_fields.unnamed.iter().enumerate() {
                values.push(ctor(index, field)?);
            }
            Ok(quote! { ( #(#values),* ) })
        }
    }
}
/// Builds the constructor tokens used by `arbitrary_take_rest`.
///
/// Works like the regular constructor, except that `u` is an owned value here:
/// fields borrow it as `&mut u`, and a last field built through `Arbitrary`
/// consumes it with `arbitrary_take_rest(u)`.
fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
    construct(fields, |idx, field| {
        let tokens = match determine_field_constructor(field)? {
            FieldConstructor::Default => quote!(Default::default()),
            FieldConstructor::Arbitrary => {
                let is_last = idx + 1 == fields.len();
                if is_last {
                    quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
                } else {
                    quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
                }
            }
            FieldConstructor::With(function_or_closure) => {
                quote!((#function_or_closure)(&mut u)?)
            }
            FieldConstructor::Value(value) => quote!(#value),
        };
        Ok(tokens)
    })
}
fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
let size_hint_fields = |fields: &Fields| {
fields
.iter()
.map(|f| {
let ty = &f.ty;
determine_field_constructor(f).map(|field_constructor| {
match field_constructor {
FieldConstructor::Default | FieldConstructor::Value(_) => {
quote!((0, Some(0)))
}
FieldConstructor::Arbitrary => {
quote! { <#ty as arbitrary::Arbitrary>::size_hint(depth) }
}
// Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
// just an educated guess, although it's gonna be inaccurate for dynamically
// allocated types (Vec, HashMap, etc.).
FieldConstructor::With(_) => {
quote! { (::core::mem::size_of::<#ty>(), None) }
}
}
})
})
.collect::<Result<Vec<TokenStream>>>()
.map(|hints| {
quote! {
arbitrary::size_hint::and_all(&[
#( #hints ),*
])
}
})
};
let size_hint_structlike = |fields: &Fields| {
size_hint_fields(fields).map(|hint| {
quote! {
#[inline]
fn size_hint(depth: usize) -> (usize, Option<usize>) {
arbitrary::size_hint::recursion_guard(depth, |depth| #hint)
}
}
})
};
match &input.data {
Data::Struct(data) => size_hint_structlike(&data.fields),
Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
Data::Enum(data) => data
.variants
.iter()
.map(|v| size_hint_fields(&v.fields))
.collect::<Result<Vec<TokenStream>>>()
.map(|variants| {
quote! {
#[inline]
fn size_hint(depth: usize) -> (usize, Option<usize>) {
arbitrary::size_hint::and(
<u32 as arbitrary::Arbitrary>::size_hint(depth),
arbitrary::size_hint::recursion_guard(depth, |depth| {
arbitrary::size_hint::or_all(&[ #( #variants ),* ])
}),
)
}
}
}),
}
}
/// Returns the expression that produces the value of `field` inside the
/// generated `arbitrary` method, where `u` is a `&mut Unstructured`.
///
/// The expression depends on the field's attributes: `Default::default()`, a
/// call to `Arbitrary::arbitrary`, a call to a user-provided function or
/// closure, or a fixed user-provided value.
fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
    determine_field_constructor(field).map(|constructor| match constructor {
        FieldConstructor::Default => quote!(Default::default()),
        FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
        FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
        FieldConstructor::Value(value) => quote!(#value),
    })
}