sui_default_config/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use syn::{
7    parse_macro_input, Attribute, Data, DataStruct, DeriveInput, Fields, FieldsNamed, Meta,
8    MetaList, MetaNameValue, NestedMeta,
9};
10
11/// Attribute macro to be applied to config-based structs. It ensures that the struct derives serde
12/// traits, and `Debug`, that all fields are renamed with "kebab case", and adds a `#[serde(default
13/// = ...)]` implementation for each field that ensures that if the field is not present during
14/// deserialization, it is replaced with its default value, from the `Default` implementation for
15/// the config struct.
16#[allow(non_snake_case)]
17#[proc_macro_attribute]
18pub fn DefaultConfig(_attr: TokenStream, input: TokenStream) -> TokenStream {
19    let DeriveInput {
20        attrs,
21        vis,
22        ident,
23        generics,
24        data,
25    } = parse_macro_input!(input as DeriveInput);
26
27    let Data::Struct(DataStruct {
28        struct_token,
29        fields,
30        semi_token,
31    }) = data
32    else {
33        panic!("Default configs must be structs.");
34    };
35
36    let Fields::Named(FieldsNamed {
37        brace_token: _,
38        named,
39    }) = fields
40    else {
41        panic!("Default configs must have named fields.");
42    };
43
44    // Extract field names once to avoid having to check for their existence multiple times.
45    let fields_with_names: Vec<_> = named
46        .iter()
47        .map(|field| {
48            let Some(ident) = &field.ident else {
49                panic!("All fields must have an identifier.");
50            };
51
52            (ident, field)
53        })
54        .collect();
55
56    // Generate the fields with the `#[serde(default = ...)]` attribute.
57    let fields = fields_with_names.iter().map(|(name, field)| {
58        let default = format!("{ident}::__default_{name}");
59        quote! { #[serde(default = #default)] #field }
60    });
61
62    // Generate the default implementations for each field.
63    let defaults = fields_with_names.iter().map(|(name, field)| {
64        let ty = &field.ty;
65        let fn_name = format_ident!("__default_{}", name);
66        let cfg = extract_cfg(&field.attrs);
67
68        quote! {
69            #[doc(hidden)] #cfg
70            fn #fn_name() -> #ty {
71                <Self as std::default::Default>::default().#name
72            }
73        }
74    });
75
76    // Check if there's already a serde rename_all attribute
77    let has_rename_all = attrs.iter().any(|attr| {
78        if !attr.path.is_ident("serde") {
79            return false;
80        };
81
82        let Ok(Meta::List(MetaList { nested, .. })) = attr.parse_meta() else {
83            return false;
84        };
85
86        nested.iter().any(|nested| {
87            if let NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, .. })) = nested {
88                path.is_ident("rename_all")
89            } else {
90                false
91            }
92        })
93    });
94
95    // Only include the default rename_all if none exists
96    let rename_all = if !has_rename_all {
97        quote! { #[serde(rename_all = "kebab-case")] }
98    } else {
99        quote! {}
100    };
101
102    TokenStream::from(quote! {
103        #[derive(serde::Serialize, serde::Deserialize)]
104        #rename_all
105        #(#attrs)* #vis #struct_token #ident #generics {
106            #(#fields),*
107        } #semi_token
108
109        impl #ident {
110            #(#defaults)*
111        }
112    })
113}
114
115/// Find the attribute that corresponds to a `#[cfg(...)]` annotation, if it exists.
116fn extract_cfg(attrs: &[Attribute]) -> Option<&Attribute> {
117    attrs.iter().find(|attr| {
118        let meta = attr.parse_meta().ok();
119        meta.is_some_and(|m| m.path().is_ident("cfg"))
120    })
121}