beginner – Rust Builder pattern derive-macro

Inspired by dtolnay’s procedural macro workshop (on GitHub), I have implemented a derive macro that automatically implements the Builder pattern for a Rust struct. The pattern lets a struct instance be created through chained method calls on an intermediate object, instead of having to provide all the required values at once. A more exact specification (along with a usage example) is given in the top-level doc comment in lib.rs.

This is my first “real” Rust project, so I’d be especially interested to hear about more idiomatic ways to approach this problem.

lib.rs

extern crate proc_macro;
mod fields;

use proc_macro::TokenStream;
use proc_macro2::{self, Ident, Span};
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, Field, Fields};

use fields::{AnalyzedField, MetaData};

/// Constructs an implementation of the builder pattern for the struct.
///
/// Generates a struct `{struct_name}Builder` that can be created using
/// the `builder()` method on the main struct.
///
/// Initialize fields of the builder using methods of the same name as
/// the fields. These can also be chained.
///
/// All fields must be set (see exceptions below) before the `build()` method
/// may be called, which will return an instance of the struct with the
/// values from the builder. If `build()` is called before all values have been set,
/// an `Incomplete{struct_name}Error` will be returned.
///
/// Fields of the form `Option<T>` in the struct do not have to be set before
/// calling `build()`.
///
/// Fields of the form `Vec<T>` can be annotated with `#(builder(each = "{arg}"))`
/// to generate a method `arg(T)` that will append its argument to the Vec.
/// A Vec that is annotated with `each` will be initialized to an empty Vec on
/// construction of the builder.
///
/// # Example
///
/// ```
/// use derive_builder::Builder;
///
/// #(derive(Builder))
/// struct Thing {
///     num: usize,
///     name: String,
///     optional: Option<String>,
///     #(builder(each = "arg"))
///     args: Vec<i32>,
/// }
///
/// let instance = Thing::builder()
///                 .num(25)
///                 .name("John".to_owned())
///                 .arg(5)
///                 .arg(34)
///                 .build()
///                 .unwrap();
///
/// assert_eq!(instance.num, 25);
/// assert_eq!(instance.name, "John".to_owned());
/// assert_eq!(instance.optional, None);
/// assert_eq!(instance.args, vec!(5, 34));
/// ```
#(proc_macro_derive(Builder, attributes(builder)))
pub fn derive(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;

    let fields = match input.data {
        Data::Struct(syn::DataStruct {
            fields: Fields::Named(fields),
            ..
        }) => fields.named.into_iter().collect::<Vec<Field>>(),
        _ => {
            let e = syn::Error::new_spanned(
                &input,
                "the builder macro may only be used with named fields",
            )
            .to_compile_error();
            return TokenStream::from(quote! {
                fn err() { #e }
            });
        }
    };

    let mut analyzed_fields: Vec<AnalyzedField> = Vec::with_capacity(fields.len());
    for field in fields {
        match AnalyzedField::new(field) {
            Ok(f) => analyzed_fields.push(f),
            Err(e) => {
                let e = e.to_compile_error();
                return TokenStream::from(quote! {
                    fn err() { #e }
                });
            }
        }
    }

    let builder_constructor = make_builder_constructor(name, &analyzed_fields);

    let builder_struct = make_builder_struct(name, &analyzed_fields);

    let builder_impl = make_builder_impl(name, &analyzed_fields);

    let error_type = make_builder_error_type(name);

    let tokens = TokenStream::from(quote! {
        #builder_constructor
        #builder_struct
        #builder_impl
        #error_type
    });

    tokens
}

/// Generates `impl {type_name} { pub fn builder() -> {type_name}Builder }`.
///
/// Every builder field starts out as `None`, except fields annotated with
/// `#[builder(each = "...")]`. Those start as `Some` of an empty `Vec`, so
/// their appender method can push to them immediately.
fn make_builder_constructor(
    type_name: &proc_macro2::Ident,
    fields: &[AnalyzedField],
) -> proc_macro2::TokenStream {
    let builder_initializers = fields.iter().map(|field| {
        let ident = field
            .get_field()
            .ident
            .as_ref()
            .expect("named fields always have an identifier");
        match field.get_meta_data() {
            // Fully qualified so the generated code does not depend on what
            // `Vec` happens to resolve to at the macro's call site.
            MetaData::Each(_, _) => {
                quote! { #ident: std::option::Option::Some(std::vec::Vec::new()) }
            }
            MetaData::Normal | MetaData::Optional(_) => {
                quote! { #ident: std::option::Option::None }
            }
        }
    });

    let name_builder = get_builder_name(type_name);

    quote! {
        impl #type_name {
            pub fn builder() -> #name_builder {
                #name_builder {
                    #(#builder_initializers),*
                }
            }
        }
    }
}

/// Generates the declaration of `{type_name}Builder`.
///
/// The builder mirrors the target struct field-for-field. Each field stores
/// the setter's argument type wrapped in `Option`, so unset fields can be
/// detected. For `Option<T>` fields the setter type is `T`, so they are
/// stored as `Option<T>` rather than `Option<Option<T>>`.
fn make_builder_struct(
    type_name: &proc_macro2::Ident,
    fields: &Vec<AnalyzedField>,
) -> proc_macro2::TokenStream {
    let builder_ident = get_builder_name(type_name);

    let mut declarations = Vec::with_capacity(fields.len());
    for analyzed in fields {
        let field_ident = analyzed.get_field().ident.as_ref().unwrap();
        let stored_ty = analyzed.get_setter_arg_type();
        declarations.push(quote! { #field_ident: std::option::Option<#stored_ty> });
    }

    quote! {
        pub struct #builder_ident {
            #(#declarations),*
        }
    }
}

/// Generates the error type returned by `{type_name}Builder::build()`.
///
/// The generated `Incomplete{type_name}Error` stores the name of a required
/// field that was not set. It implements `Debug`, `Display` and
/// `std::error::Error`.
fn make_builder_error_type(type_name: &proc_macro2::Ident) -> proc_macro2::TokenStream {
    let error_type_name = get_error_type_name(type_name);

    // The attribute must be written `#[...]`: inside `quote!`, `#( ... )`
    // denotes a repetition, not an attribute.
    quote! {
        #[derive(Debug)]
        pub struct #error_type_name {
            missing_field: &'static str,
        }

        impl std::fmt::Display for #error_type_name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "Field {} is missing from builder.", self.missing_field)
            }
        }

        impl std::error::Error for #error_type_name {}
    }
}

/// Generates `impl {type_name}Builder` containing the per-field methods and
/// `build()`.
///
/// `build()` works in two phases:
/// 1. It checks every required field (neither `Option` nor `each`). For the
///    first one that is unset, it returns `Incomplete{type_name}Error` naming
///    that field.
/// 2. Only after all checks pass does it move values out of the builder.
///
/// So a failed `build()` leaves the builder unchanged. The caller can set the
/// missing field and call `build()` again.
///
/// After a successful `build()`, required and `Option` fields are reset to
/// `None`. `each` fields are reset to an empty `Vec`, so their appenders keep
/// working.
fn make_builder_impl(
    type_name: &proc_macro2::Ident,
    fields: &Vec<AnalyzedField>,
) -> proc_macro2::TokenStream {
    let methods = make_builder_methods(fields);
    let error_type_name = get_error_type_name(type_name);
    let name_builder = get_builder_name(type_name);

    // Phase 1: presence checks for required fields. Nothing is taken out of
    // the builder yet.
    let checks = fields.iter().filter_map(|field| {
        let ident = field.get_field().ident.as_ref().unwrap();
        match field.get_meta_data() {
            MetaData::Normal => {
                let field_name = ident.to_string();
                Some(quote! {
                    if self.#ident.is_none() {
                        return std::result::Result::Err(
                            #error_type_name { missing_field: #field_name }
                        );
                    }
                })
            }
            MetaData::Optional(_) | MetaData::Each(_, _) => None,
        }
    });

    // Phase 2: move each value into a local named after its field.
    let extractions = fields.iter().map(|field| {
        let ident = field.get_field().ident.as_ref().unwrap();
        match field.get_meta_data() {
            MetaData::Normal => quote! {
                let #ident = self.#ident.take().expect("presence checked above");
            },
            MetaData::Optional(_) => quote! {
                let #ident = self.#ident.take();
            },
            // Leave an empty Vec behind so the appender (which expects
            // `Some`) remains usable after `build()`.
            MetaData::Each(_, _) => quote! {
                let #ident = self.#ident
                    .replace(std::vec::Vec::new())
                    .unwrap_or_else(std::vec::Vec::new);
            },
        }
    });

    // Field-init shorthand: each local has the same name as its field.
    let idents = fields
        .iter()
        .map(|field| field.get_field().ident.as_ref().unwrap());

    quote! {
        impl #name_builder {
            #(#methods)*

            pub fn build(&mut self) -> std::result::Result<#type_name, #error_type_name> {
                #(#checks)*
                #(#extractions)*
                std::result::Result::Ok(#type_name {
                    #(#idents),*
                })
            }
        }
    }
}

/// Generates the per-field builder methods.
///
/// Every field gets a setter with the field's name. It stores its argument
/// and returns `&mut Self` for chaining.
///
/// Fields annotated with `#[builder(each = "name")]` also get an appender
/// `name(item)` that pushes one element onto the `Vec`. If the appender's
/// name equals the field's name, the whole-`Vec` setter is omitted so the
/// two methods do not clash.
fn make_builder_methods(fields: &[AnalyzedField]) -> Vec<proc_macro2::TokenStream> {
    fields
        .iter()
        .map(|field| {
            let ident = field.get_field().ident.as_ref().unwrap();
            let ty = field.get_setter_arg_type();

            let setter = quote! {
                pub fn #ident(&mut self, #ident: #ty) -> &mut Self {
                    self.#ident = std::option::Option::Some(#ident);
                    self
                }
            };

            match field.get_meta_data() {
                MetaData::Each(appender_name, item_ty) => {
                    let appender_ident = Ident::new(appender_name, Span::call_site());
                    // `get_or_insert_with` rather than `unwrap()`: the Vec may
                    // have been moved out of the builder, e.g. by `build()`.
                    let appender = quote! {
                        pub fn #appender_ident(&mut self, #appender_ident: #item_ty) -> &mut Self {
                            self.#ident
                                .get_or_insert_with(std::vec::Vec::new)
                                .push(#appender_ident);
                            self
                        }
                    };
                    if ident == appender_name {
                        appender
                    } else {
                        quote! {
                            #setter
                            #appender
                        }
                    }
                }
                MetaData::Normal | MetaData::Optional(_) => setter,
            }
        })
        .collect()
}

/// Returns the identifier of the builder type generated for `type_name`,
/// i.e. `{type_name}Builder`.
fn get_builder_name(type_name: &proc_macro2::Ident) -> proc_macro2::Ident {
    format_ident!("{base}Builder", base = type_name)
}

fn get_error_type_name(type_name: &proc_macro2::Ident) -> proc_macro2::Ident {
    format_ident!("Incomplete{}Error", type_name)
}

fields.rs

use proc_macro2::{self, Ident};
use syn::Field;

/// A struct field that has been analyzed for
/// properties of interest to the Builder macro.
/// The results of the analysis are given as [`MetaData`].
pub struct AnalyzedField {
    /// The original field as parsed by `syn`.
    field: Field,
    /// What the analysis determined about `field`.
    meta: MetaData,
}

/// Holds additional information on fields to be initialized
/// in the builder.
pub enum MetaData {
    /// A field with no special treatment.
    Normal,
    /// A field that is an `Option<T>` in the final struct
    /// and is thus not obligatory to set in the builder.
    /// The contained type is the `T` from the Option.
    Optional(syn::Type),
    /// A field that is a `Vec<T>` in the struct and can be appended to in
    /// the builder using the method whose name is given by the string.
    /// The contained type is the `T` from the Vec.
    Each(String, syn::Type),
}

/// Key inside `#[builder(...)]` that names the appender method.
const EACH_KEYWORD: &str = "each";
/// Path of the helper attribute recognized on fields.
const ATTR_PATH: &str = "builder";
/// Type name a field must have (as `Vec<T>`) to be annotated with `each`.
const EACH_REQ_TYPE: &str = "Vec";
/// Type name (as `Option<T>`) that marks a field as optional in the builder.
const OPTIONAL_TYPE: &str = "Option";

impl AnalyzedField {

    /// Analyzes the field and initializes the attached metadata.
    /// Returns an error if one occured while parsing a builder attribute
    /// or if there was a type mismatch between attribute requirements
    /// and the given types in the struct (e.g. "each" requires a (Vec<T>)).
    pub fn new(field: Field) -> Result<Self, syn::Error> {
        let builder_attr = Self::get_builder_attr(&field);
        if let Some(attr) = builder_attr {
            let appender_name = Self::parse_each_attribute(attr)?;
            match Self::get_generic_type(&field, EACH_REQ_TYPE) {
                Some(t) => Ok(Self {
                    field,
                    meta: MetaData::Each(appender_name, t),
                }),
                None => Err(syn::Error::new_spanned(
                    field,
                    format!("`{}` requires type {}<T>", EACH_KEYWORD, EACH_REQ_TYPE),
                )),
            }
        } else if let Some(t) = Self::get_generic_type(&field, OPTIONAL_TYPE) {
            Ok(Self {
                field,
                meta: MetaData::Optional(t),
            })
        } else {
            Ok(Self {
                field,
                meta: MetaData::Normal,
            })
        }
    }

    /// Returns the field this instance was initialized with.
    pub fn get_field(&self) -> &Field {
        &self.field
    }

    /// Returns the meta data associated with the field.
    pub fn get_meta_data(&self) -> &MetaData {
        &self.meta
    }

    /// If the field's type is of the form ty<T>, returns Some(T).
    fn get_generic_type(field: &Field, ty: &str) -> Option<syn::Type> {
        if let &syn::Type::Path(syn::TypePath {
            qself: None,
            ref path,
        }) = &field.ty
        {
            if path.segments.len() != 1 {
                return None;
            }

            let segment = path.segments.first().unwrap();
            if segment.ident.to_string() != ty.to_owned() {
                return None;
            }

            match segment.arguments {
                syn::PathArguments::AngleBracketed(ref args) => {
                    Self::extract_generic_argument(args)
                }
                _ => None,
            }
        } else {
            None
        }
    }

    /// If there is exactly one generic arguments, returns it.
    fn extract_generic_argument(args: &syn::AngleBracketedGenericArguments) -> Option<syn::Type> {
        if args.args.len() == 1 {
            match args.args.first().unwrap() {
                syn::GenericArgument::Type(ref t) => Some(t.clone()),
                _ => None,
            }
        } else {
            None
        }
    }

    /// Returns an attribute of the form #(builder(...)) if one exists on the field.
    fn get_builder_attr(field: &Field) -> Option<&syn::Attribute> {
        field
            .attrs
            .iter()
            .find(|attr| attr.path.get_ident().map(Ident::to_string) == Some(ATTR_PATH.to_owned()))
    }

    /// Parses an attribute of the form #(builder(each = "...")) and returns
    /// the literal string argument. Returns an error if the attribute cannot
    /// be parsed in this manner.
    fn parse_each_attribute(attr: &syn::Attribute) -> Result<String, syn::Error> {
        let meta = attr.parse_meta()?;

        let error = Err(syn::Error::new_spanned(
            &meta,
            format!("expected `{}({} = "...")`", ATTR_PATH, EACH_KEYWORD),
        ));

        let args = if let syn::Meta::List(ref list) = meta {
            &list.nested
        } else {
            return error;
        };

        if args.len() != 1 {
            return error;
        }

        if let Some(&syn::NestedMeta::Meta(syn::Meta::NameValue(ref name))) = args.first() {
            if name.path.get_ident().map(Ident::to_string) == Some(EACH_KEYWORD.to_owned()) {
                match &name.lit {
                    &syn::Lit::Str(ref lit_str) => return Ok(lit_str.value()),
                    _ => return error,
                }
            } else {
                return error;
            }
        } else {
            return error;
        }
    }

    /// Gets the type that should be used by a setter in the builder.
    pub fn get_setter_arg_type(&self) -> &syn::Type {
        match &self.meta {
            &MetaData::Optional(ref t) => t,
            _ => &self.field.ty,
        }
    }
}