sui_default_config/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use proc_macro::TokenStream;
5
6use quote::format_ident;
7use quote::quote;
8use syn::Attribute;
9use syn::Data;
10use syn::DataStruct;
11use syn::DeriveInput;
12use syn::Fields;
13use syn::FieldsNamed;
14use syn::Meta;
15use syn::MetaList;
16use syn::MetaNameValue;
17use syn::NestedMeta;
18use syn::parse_macro_input;
19
20/// Attribute macro to be applied to config-based structs. It ensures that the struct derives serde
21/// traits, and `Debug`, that all fields are renamed with "kebab case", and adds a `#[serde(default
22/// = ...)]` implementation for each field that ensures that if the field is not present during
23/// deserialization, it is replaced with its default value, from the `Default` implementation for
24/// the config struct.
25#[allow(non_snake_case)]
26#[proc_macro_attribute]
27pub fn DefaultConfig(_attr: TokenStream, input: TokenStream) -> TokenStream {
28    let DeriveInput {
29        attrs,
30        vis,
31        ident,
32        generics,
33        data,
34    } = parse_macro_input!(input as DeriveInput);
35
36    let Data::Struct(DataStruct {
37        struct_token,
38        fields,
39        semi_token,
40    }) = data
41    else {
42        panic!("Default configs must be structs.");
43    };
44
45    let Fields::Named(FieldsNamed {
46        brace_token: _,
47        named,
48    }) = fields
49    else {
50        panic!("Default configs must have named fields.");
51    };
52
53    // Extract field names once to avoid having to check for their existence multiple times.
54    let fields_with_names: Vec<_> = named
55        .iter()
56        .map(|field| {
57            let Some(ident) = &field.ident else {
58                panic!("All fields must have an identifier.");
59            };
60
61            (ident, field)
62        })
63        .collect();
64
65    // Generate the fields with the `#[serde(default = ...)]` attribute.
66    let fields = fields_with_names.iter().map(|(name, field)| {
67        let default = format!("{ident}::__default_{name}");
68        quote! { #[serde(default = #default)] #field }
69    });
70
71    // Generate the default implementations for each field.
72    let defaults = fields_with_names.iter().map(|(name, field)| {
73        let ty = &field.ty;
74        let fn_name = format_ident!("__default_{}", name);
75        let cfg = extract_cfg(&field.attrs);
76
77        quote! {
78            #[doc(hidden)] #cfg
79            fn #fn_name() -> #ty {
80                <Self as std::default::Default>::default().#name
81            }
82        }
83    });
84
85    // Check if there's already a serde rename_all attribute
86    let has_rename_all = attrs.iter().any(|attr| {
87        if !attr.path.is_ident("serde") {
88            return false;
89        };
90
91        let Ok(Meta::List(MetaList { nested, .. })) = attr.parse_meta() else {
92            return false;
93        };
94
95        nested.iter().any(|nested| {
96            if let NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, .. })) = nested {
97                path.is_ident("rename_all")
98            } else {
99                false
100            }
101        })
102    });
103
104    // Only include the default rename_all if none exists
105    let rename_all = if !has_rename_all {
106        quote! { #[serde(rename_all = "kebab-case")] }
107    } else {
108        quote! {}
109    };
110
111    TokenStream::from(quote! {
112        #[derive(serde::Serialize, serde::Deserialize)]
113        #rename_all
114        #(#attrs)* #vis #struct_token #ident #generics {
115            #(#fields),*
116        } #semi_token
117
118        impl #ident {
119            #(#defaults)*
120        }
121    })
122}
123
124/// Find the attribute that corresponds to a `#[cfg(...)]` annotation, if it exists.
125fn extract_cfg(attrs: &[Attribute]) -> Option<&Attribute> {
126    attrs.iter().find(|attr| {
127        let meta = attr.parse_meta().ok();
128        meta.is_some_and(|m| m.path().is_ident("cfg"))
129    })
130}