Revision control
Copy as Markdown
Other Tools
//! Create extensions for types you don't own with [extension traits] but without the boilerplate.
//!
//! Example:
//!
//! ```rust
//! use extend::ext;
//!
//! #[ext]
//! impl<T: Ord> Vec<T> {
//! fn sorted(mut self) -> Self {
//! self.sort();
//! self
//! }
//! }
//!
//! assert_eq!(
//! vec![1, 2, 3],
//! vec![2, 3, 1].sorted(),
//! );
//! ```
//!
//! # How does it work?
//!
//! Under the hood it generates a trait with methods in your `impl` and implements those for the
//! type you specify. The code shown above expands roughly to:
//!
//! ```rust
//! trait VecExt<T: Ord> {
//! fn sorted(self) -> Self;
//! }
//!
//! impl<T: Ord> VecExt<T> for Vec<T> {
//! fn sorted(mut self) -> Self {
//! self.sort();
//! self
//! }
//! }
//! ```
//!
//! # Supported items
//!
//! Extensions can contain methods or associated constants:
//!
//! ```rust
//! use extend::ext;
//!
//! #[ext]
//! impl String {
//! const CONSTANT: &'static str = "FOO";
//!
//! fn method() {
//! // ...
//! # todo!()
//! }
//! }
//! ```
//!
//! # Configuration
//!
//! You can configure:
//!
//! - The visibility of the trait. Use `pub impl ...` to generate `pub trait ...`. The default
//! visibility is private.
//! - The name of the generated extension trait. Example: `#[ext(name = MyExt)]`. By default we
//! generate a name based on what you extend.
//! - Which supertraits the generated extension trait should have. Default is no supertraits.
//! Example: `#[ext(supertraits = Default + Clone)]`.
//!
//! More examples:
//!
//! ```rust
//! use extend::ext;
//!
//! #[ext(name = SortedVecExt)]
//! impl<T: Ord> Vec<T> {
//! fn sorted(mut self) -> Self {
//! self.sort();
//! self
//! }
//! }
//!
//! #[ext]
//! pub(crate) impl i32 {
//! fn double(self) -> i32 {
//! self * 2
//! }
//! }
//!
//! #[ext(name = ResultSafeUnwrapExt)]
//! pub impl<T> Result<T, std::convert::Infallible> {
//! fn safe_unwrap(self) -> T {
//! match self {
//! Ok(t) => t,
//! Err(_) => unreachable!(),
//! }
//! }
//! }
//!
//! #[ext(supertraits = Default + Clone)]
//! impl String {
//! fn my_length(self) -> usize {
//! self.len()
//! }
//! }
//! ```
//!
//! For backwards compatibility you can also declare the visibility as the first argument to `#[ext]`:
//!
//! ```
//! use extend::ext;
//!
//! #[ext(pub)]
//! impl i32 {
//! fn double(self) -> i32 {
//! self * 2
//! }
//! }
//! ```
//!
//! # async-trait compatibility
//!
//! Async extensions are supported via [async-trait](https://crates.io/crates/async-trait).
//!
//! Be aware that you need to add `#[async_trait]` _below_ `#[ext]`. Otherwise the `ext` macro
//! cannot see the `#[async_trait]` attribute and pass it along in the generated code.
//!
//! Example:
//!
//! ```
//! use extend::ext;
//! use async_trait::async_trait;
//!
//! #[ext]
//! #[async_trait]
//! impl String {
//! async fn read_file() -> String {
//! // ...
//! # todo!()
//! }
//! }
//! ```
//!
//! # Other attributes
//!
//! Other attributes provided _below_ `#[ext]` will be passed along to both the generated trait and
//! the implementation. See [async-trait compatibility](#async-trait-compatibility) above for an
//! example.
//!
//! [extension traits]: https://dev.to/matsimitsu/extending-existing-functionality-in-rust-with-traits-in-rust-3622
#![allow(clippy::let_and_return)]
#![deny(unused_variables, dead_code, unused_must_use, unused_imports)]
use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens};
use std::convert::{TryFrom, TryInto};
use syn::{
parse::{self, Parse, ParseStream},
parse_macro_input, parse_quote,
punctuated::Punctuated,
spanned::Spanned,
token::{Plus, Semi},
Ident, ImplItem, ItemImpl, Result, Token, TraitItemConst, TraitItemFn, Type, TypeArray,
TypeBareFn, TypeGroup, TypeNever, TypeParamBound, TypeParen, TypePath, TypePtr, TypeReference,
TypeSlice, TypeTraitObject, TypeTuple, Visibility,
};
/// Parsed form of the item that `#[ext]` / `#[ext_sized]` is attached to.
#[derive(Debug)]
struct Input {
/// The `impl` block itself; outer attributes written above it are appended to its `attrs`.
item_impl: ItemImpl,
/// Visibility written in front of `impl` (e.g. `pub impl ...`), or `None` when nothing
/// other than inherited visibility was parsed.
vis: Option<Visibility>,
}
impl Parse for Input {
    /// Parses `<outer attributes> <optional visibility> <impl block>`.
    ///
    /// Plain Rust does not allow a visibility in front of `impl`, so it is parsed
    /// separately here and the rest is handed to `ItemImpl`'s own parser.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // `parse_outer` accepts zero or more `#[...]` attributes, so it can be called
        // unconditionally.
        let outer_attrs = syn::Attribute::parse_outer(input)?;

        // A failed parse and an inherited (absent) visibility both mean "not specified".
        let vis = match input.parse::<Visibility>() {
            Ok(Visibility::Inherited) | Err(_) => None,
            Ok(explicit) => Some(explicit),
        };

        let mut item_impl: ItemImpl = input.parse()?;
        // The attributes were consumed above, so `ItemImpl` did not see them; re-attach.
        item_impl.attrs.extend(outer_attrs);

        Ok(Self { item_impl, vis })
    }
}
/// See crate docs for more info.
#[proc_macro_attribute]
#[allow(clippy::unneeded_field_pattern)]
pub fn ext(
    attr: proc_macro::TokenStream,
    item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    // The annotated item is parsed before the attribute arguments, so a malformed item
    // is reported first. `parse_macro_input!` returns a compile error on failure.
    let parsed_item = parse_macro_input!(item as Input);
    let parsed_config = parse_macro_input!(attr as Config);

    // Expansion errors are turned into `compile_error!` tokens rather than panics.
    go(parsed_item, parsed_config).unwrap_or_else(|err| err.into_compile_error().into())
}
/// Like [`ext`](macro@crate::ext) but always add `Sized` as a supertrait.
///
/// This is provided as a convenience for generating extension traits that require `Self: Sized`
/// such as:
///
/// ```
/// use extend::ext_sized;
///
/// #[ext_sized]
/// impl i32 {
///     fn requires_sized(self) -> Option<Self> {
///         Some(self)
///     }
/// }
/// ```
///
/// Any supertraits given through `supertraits = ...` are kept; `Sized` is appended after them.
#[proc_macro_attribute]
#[allow(clippy::unneeded_field_pattern)]
pub fn ext_sized(
    attr: proc_macro::TokenStream,
    item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let item = parse_macro_input!(item as Input);
    let mut config: Config = parse_macro_input!(attr as Config);

    // Append `Sized` with `Punctuated::push`, which inserts a separating `+` only when
    // one is needed. Re-quoting as `#supertraits + Sized` would produce `A + + Sized`
    // (and make `parse_quote!` panic) when the configured bounds end in a trailing `+`.
    let mut supertraits = config.supertraits.take().unwrap_or_default();
    supertraits.push(parse_quote!(Sized));
    config.supertraits = Some(supertraits);

    match go(item, config) {
        Ok(tokens) => tokens,
        Err(err) => err.into_compile_error().into(),
    }
}
fn go(item: Input, mut config: Config) -> Result<proc_macro::TokenStream> {
if let Some(vis) = item.vis {
if config.visibility != Visibility::Inherited {
return Err(syn::Error::new(
config.visibility.span(),
"Cannot set visibility on `#[ext]` and `impl` block",
));
}
config.visibility = vis;
}
let ItemImpl {
attrs,
unsafety,
generics,
trait_,
self_ty,
items,
// What is defaultness?
defaultness: _,
impl_token: _,
brace_token: _,
} = item.item_impl;
if let Some((_, path, _)) = trait_ {
return Err(syn::Error::new(
path.span(),
"Trait impls cannot be used for #[ext]",
));
}
let self_ty = parse_self_ty(&self_ty)?;
let ext_trait_name = if let Some(ext_trait_name) = config.ext_trait_name {
ext_trait_name
} else {
ext_trait_name(&self_ty)?
};
let MethodsAndConsts {
trait_methods,
trait_consts,
} = extract_allowed_items(&items)?;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let visibility = &config.visibility;
let mut all_supertraits = Vec::<TypeParamBound>::new();
if let Some(supertraits_from_config) = config.supertraits {
all_supertraits.extend(supertraits_from_config);
}
let supertraits_quoted = if all_supertraits.is_empty() {
quote! {}
} else {
let supertraits_quoted = punctuated_from_iter::<_, _, Plus>(all_supertraits);
quote! { : #supertraits_quoted }
};
let code = (quote! {
#[allow(non_camel_case_types)]
#(#attrs)*
#visibility
#unsafety
trait #ext_trait_name #impl_generics #supertraits_quoted #where_clause {
#(
#trait_consts
)*
#(
#[allow(
patterns_in_fns_without_body,
clippy::inline_fn_without_body,
unused_attributes
)]
#trait_methods
)*
}
#(#attrs)*
impl #impl_generics #ext_trait_name #ty_generics for #self_ty #where_clause {
#(#items)*
}
})
.into();
Ok(code)
}
/// Borrowed view of the kinds of `syn::Type` that `#[ext]` accepts as the self type.
///
/// Each variant wraps a reference to the corresponding `syn` type node. Kinds of
/// `syn::Type` that have no variant here are rejected by `parse_self_ty`.
#[derive(Debug, Clone)]
enum ExtType<'a> {
Array(&'a TypeArray),
Group(&'a TypeGroup),
Never(&'a TypeNever),
Paren(&'a TypeParen),
Path(&'a TypePath),
Ptr(&'a TypePtr),
Reference(&'a TypeReference),
Slice(&'a TypeSlice),
Tuple(&'a TypeTuple),
BareFn(&'a TypeBareFn),
TraitObject(&'a TypeTraitObject),
}
/// Classifies the self type of an `impl` block as one of the supported [`ExtType`] kinds.
///
/// # Errors
///
/// Returns an error spanning `self_ty` for every kind of type without an `ExtType`
/// variant (`impl Trait`, `_`, macro invocations, verbatim tokens and anything else).
#[allow(clippy::wildcard_in_or_patterns)]
fn parse_self_ty(self_ty: &Type) -> Result<ExtType> {
    Ok(match self_ty {
        Type::Array(array) => ExtType::Array(array),
        Type::Group(group) => ExtType::Group(group),
        Type::Never(never) => ExtType::Never(never),
        Type::Paren(paren) => ExtType::Paren(paren),
        Type::Path(path) => ExtType::Path(path),
        Type::Ptr(ptr) => ExtType::Ptr(ptr),
        Type::Reference(reference) => ExtType::Reference(reference),
        Type::Slice(slice) => ExtType::Slice(slice),
        Type::Tuple(tuple) => ExtType::Tuple(tuple),
        Type::BareFn(bare_fn) => ExtType::BareFn(bare_fn),
        Type::TraitObject(trait_object) => ExtType::TraitObject(trait_object),
        // The known unsupported kinds are spelled out for readers; the trailing `_`
        // covers every remaining `syn::Type` variant.
        Type::ImplTrait(_) | Type::Infer(_) | Type::Macro(_) | Type::Verbatim(_) | _ => {
            return Err(syn::Error::new(
                self_ty.span(),
                "#[ext] is not supported for this kind of type",
            ));
        }
    })
}
impl<'a> TryFrom<&'a Type> for ExtType<'a> {
    type Error = syn::Error;

    /// Same conversion as [`parse_self_ty`], exposed through `TryFrom`/`TryInto`.
    fn try_from(ty: &'a Type) -> Result<Self> {
        parse_self_ty(ty)
    }
}
impl<'a> ToTokens for ExtType<'a> {
    /// Emits exactly the tokens of the wrapped `syn` type node.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        // Pick the wrapped node once, then forward to its own `ToTokens` impl.
        let wrapped: &dyn ToTokens = match self {
            ExtType::Array(ty) => *ty,
            ExtType::Group(ty) => *ty,
            ExtType::Never(ty) => *ty,
            ExtType::Paren(ty) => *ty,
            ExtType::Path(ty) => *ty,
            ExtType::Ptr(ty) => *ty,
            ExtType::Reference(ty) => *ty,
            ExtType::Slice(ty) => *ty,
            ExtType::Tuple(ty) => *ty,
            ExtType::BareFn(ty) => *ty,
            ExtType::TraitObject(ty) => *ty,
        };
        wrapped.to_tokens(tokens);
    }
}
/// Derives the default name of the extension trait from the type being extended.
///
/// The name is built recursively from the shape of the type (for example a `Ref`,
/// `SliceOf` or `TupleOf` prefix followed by the names of the inner types) and always
/// ends in `Ext`.
///
/// # Errors
///
/// Fails when a nested type is of a kind `#[ext]` does not support, when a path
/// contains no identifiers, or when a trait object has a bound that is neither a trait
/// nor a lifetime.
fn ext_trait_name(self_ty: &ExtType) -> Result<Ident> {
    /// Name fragment for a nested `syn::Type`: convert it to an `ExtType`, then recurse.
    fn fragment_of(ty: &Type) -> Result<Ident> {
        fragment(&ty.try_into()?)
    }

    /// Name fragment for an already classified type, without the `Ext` suffix.
    fn fragment(self_ty: &ExtType) -> Result<Ident> {
        let name = match self_ty {
            ExtType::Path(path) => find_and_combine_idents(path)?,
            ExtType::Never(_) => format_ident!("Never"),
            ExtType::Reference(reference) => {
                let target = fragment_of(&reference.elem)?;
                if reference.mutability.is_none() {
                    format_ident!("Ref{}", target)
                } else {
                    format_ident!("RefMut{}", target)
                }
            }
            ExtType::Array(array) => format_ident!("ListOf{}", fragment_of(&array.elem)?),
            ExtType::Group(group) => format_ident!("Group{}", fragment_of(&group.elem)?),
            ExtType::Paren(paren) => format_ident!("Paren{}", fragment_of(&paren.elem)?),
            ExtType::Ptr(ptr) => format_ident!("PointerTo{}", fragment_of(&ptr.elem)?),
            ExtType::Slice(slice) => format_ident!("SliceOf{}", fragment_of(&slice.elem)?),
            ExtType::Tuple(tuple) => tuple.elems.iter().try_fold(
                format_ident!("TupleOf"),
                |acc, elem| -> Result<Ident> {
                    Ok(format_ident!("{}{}", acc, fragment_of(elem)?))
                },
            )?,
            ExtType::BareFn(bare_fn) => {
                // Argument types first, in order ...
                let with_inputs = bare_fn.inputs.iter().try_fold(
                    format_ident!("BareFn"),
                    |acc, input| -> Result<Ident> {
                        Ok(format_ident!("{}{}", acc, fragment_of(&input.ty)?))
                    },
                )?;
                // ... then the return type, with `Unit` standing in for a missing one.
                match &bare_fn.output {
                    syn::ReturnType::Type(_, ret) => {
                        format_ident!("{}{}", with_inputs, fragment_of(ret)?)
                    }
                    syn::ReturnType::Default => format_ident!("{}Unit", with_inputs),
                }
            }
            ExtType::TraitObject(trait_object) => {
                let mut acc = format_ident!("TraitObject");
                for bound in &trait_object.bounds {
                    match bound {
                        TypeParamBound::Lifetime(lifetime) => {
                            acc = format_ident!("{}{}", acc, lifetime.ident);
                        }
                        TypeParamBound::Trait(trait_bound) => {
                            // Every path segment contributes, e.g. `std::fmt::Debug`
                            // adds `stdfmtDebug`.
                            for segment in &trait_bound.path.segments {
                                acc = format_ident!("{}{}", acc, segment.ident);
                            }
                        }
                        unsupported => {
                            return Err(syn::Error::new(unsupported.span(), "unsupported bound"));
                        }
                    }
                }
                acc
            }
        };
        Ok(name)
    }

    let base = fragment(self_ty)?;
    Ok(format_ident!("{}Ext", base))
}
fn find_and_combine_idents(type_path: &TypePath) -> Result<Ident> {
use syn::visit::{self, Visit};
struct IdentVisitor<'a>(Vec<&'a Ident>);
impl<'a> Visit<'a> for IdentVisitor<'a> {
fn visit_ident(&mut self, i: &'a Ident) {
self.0.push(i);
}
}
let mut visitor = IdentVisitor(Vec::new());
visit::visit_type_path(&mut visitor, type_path);
let idents = visitor.0;
if idents.is_empty() {
Err(syn::Error::new(type_path.span(), "Empty type path"))
} else {
let start = &idents[0].span();
let combined_span = idents
.iter()
.map(|i| i.span())
.fold(*start, |a, b| a.join(b).unwrap_or(a));
let combined_name = idents.iter().map(|i| i.to_string()).collect::<String>();
Ok(Ident::new(&combined_name, combined_span))
}
}
/// Trait items derived from the contents of an `#[ext]` impl block, as produced by
/// `extract_allowed_items`.
#[derive(Debug, Default)]
struct MethodsAndConsts {
/// Method declarations without bodies, one per method in the impl block.
trait_methods: Vec<TraitItemFn>,
/// Associated constant declarations without values, one per constant in the impl block.
trait_consts: Vec<TraitItemConst>,
}
#[allow(clippy::wildcard_in_or_patterns)]
fn extract_allowed_items(items: &[ImplItem]) -> Result<MethodsAndConsts> {
let mut acc = MethodsAndConsts::default();
for item in items {
match item {
ImplItem::Fn(method) => acc.trait_methods.push(TraitItemFn {
attrs: method.attrs.clone(),
sig: {
let mut sig = method.sig.clone();
sig.inputs = sig
.inputs
.into_iter()
.map(|fn_arg| match fn_arg {
syn::FnArg::Receiver(recv) => syn::FnArg::Receiver(recv),
syn::FnArg::Typed(mut pat_type) => {
pat_type.pat = Box::new(match *pat_type.pat {
syn::Pat::Ident(pat_ident) => syn::Pat::Ident(pat_ident),
_ => {
parse_quote!(_)
}
});
syn::FnArg::Typed(pat_type)
}
})
.collect();
sig
},
default: None,
semi_token: Some(Semi::default()),
}),
ImplItem::Const(const_) => acc.trait_consts.push(TraitItemConst {
attrs: const_.attrs.clone(),
generics: const_.generics.clone(),
const_token: Default::default(),
ident: const_.ident.clone(),
colon_token: Default::default(),
ty: const_.ty.clone(),
default: None,
semi_token: Default::default(),
}),
ImplItem::Type(_) => {
return Err(syn::Error::new(
item.span(),
"Associated types are not allowed in #[ext] impls",
))
}
ImplItem::Macro(_) => {
return Err(syn::Error::new(
item.span(),
"Macros are not allowed in #[ext] impls",
))
}
ImplItem::Verbatim(_) | _ => {
return Err(syn::Error::new(item.span(), "Not allowed in #[ext] impls"))
}
}
}
Ok(acc)
}
/// Options given as arguments to `#[ext(...)]` / `#[ext_sized(...)]`.
#[derive(Debug)]
struct Config {
/// Trait name set with `name = ...`; `None` means the name is derived from the self type.
ext_trait_name: Option<Ident>,
/// Visibility of the generated trait; `Visibility::Inherited` unless one was given.
visibility: Visibility,
/// Supertraits set with `supertraits = A + B`; `None` when the option is absent.
supertraits: Option<Punctuated<TypeParamBound, Plus>>,
}
impl Parse for Config {
    /// Parses the arguments of `#[ext(...)]`.
    ///
    /// Grammar: an optional leading visibility (kept for backwards compatibility),
    /// followed by comma separated `key = value` options in any order:
    ///
    /// - `name = Ident`: name of the generated trait.
    /// - `supertraits = A + B`: supertraits of the generated trait. A trailing `+` is
    ///   accepted.
    ///
    /// # Errors
    ///
    /// Fails on an unknown option name, on a missing `=`, or when a value cannot be
    /// parsed.
    fn parse(input: ParseStream) -> parse::Result<Self> {
        let mut config = Config::default();

        if let Ok(visibility) = input.parse::<Visibility>() {
            config.visibility = visibility;
        }
        // The separator after the visibility is optional.
        input.parse::<Token![,]>().ok();

        while !input.is_empty() {
            let ident = input.parse::<Ident>()?;
            input.parse::<Token![=]>()?;

            match &*ident.to_string() {
                "name" => config.ext_trait_name = Some(input.parse()?),
                "supertraits" => {
                    // Stop at the next `,` instead of consuming the rest of the input.
                    // `Punctuated::parse_terminated` would demand a `+` after the last
                    // bound and thereby reject any option following `supertraits`.
                    let mut bounds = Punctuated::<TypeParamBound, Plus>::new();
                    while !input.is_empty() && !input.peek(Token![,]) {
                        bounds.push_value(input.parse()?);
                        if !input.peek(Token![+]) {
                            break;
                        }
                        bounds.push_punct(input.parse()?);
                    }
                    if !input.is_empty() && !input.peek(Token![,]) {
                        return Err(input.error("expected `+` or `,`"));
                    }
                    config.supertraits = Some(bounds);
                }
                _ => return Err(syn::Error::new(ident.span(), "Unknown configuration name")),
            }

            // Options are comma separated; a trailing comma is allowed.
            input.parse::<Token![,]>().ok();
        }

        Ok(config)
    }
}
impl Default for Config {
fn default() -> Self {
Self {
ext_trait_name: None,
visibility: Visibility::Inherited,
supertraits: None,
}
}
}
/// Collects `i` into a `Punctuated` sequence whose values are separated by
/// `P::default()`, without trailing punctuation.
fn punctuated_from_iter<I, T, P>(i: I) -> Punctuated<T, P>
where
    P: Default,
    I: IntoIterator<Item = T>,
{
    let mut sequence = Punctuated::new();
    for value in i {
        // `push` inserts a default separator before every value except the first, so
        // the sequence never ends in punctuation.
        sequence.push(value);
    }
    sequence
}
#[cfg(test)]
mod test {
    #[allow(unused_imports)]
    use super::*;

    /// Runs the `trybuild` UI suites: every file under `tests/compile_pass` must
    /// compile, every file under `tests/compile_fail` must fail to compile.
    #[test]
    fn test_ui() {
        let cases = trybuild::TestCases::new();
        cases.pass("tests/compile_pass/*.rs");
        cases.compile_fail("tests/compile_fail/*.rs");
    }
}