sui_protocol_config_macros/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{Data, DeriveInput, Fields, Type, parse_macro_input};
9
10/// This proc macro generates getters, attribute lookup, etc for protocol config fields of type `Option<T>`
11/// and for the feature flags
12/// Example for a field: `new_constant: Option<u64>`, and for feature flags `feature: bool`, we derive
13/// ```rust,ignore
14///     /// Returns the value of the field if exists at the given version, otherise panic
15///     pub fn new_constant(&self) -> u64 {
16///         self.new_constant.expect(Self::CONSTANT_ERR_MSG)
17///     }
18///     /// Returns the value of the field if exists at the given version, otherise None.
19///     pub fn new_constant_as_option(&self) -> Option<u64> {
20///         self.new_constant
21///     }
22///     // We auto derive an enum such that the variants are all the types of the fields
23///     pub enum ProtocolConfigValue {
24///        u32(u32),
25///        u64(u64),
26///        ..............
27///     }
28///     // This enum is used to return field values so that the type is also encoded in the response
29///
30///     /// Returns the value of the field if exists at the given version, otherise None
31///     pub fn lookup_attr(&self, value: String) -> Option<ProtocolConfigValue>;
32///
33///     /// Returns a map of all configs to values
34///     pub fn attr_map(&self) -> std::collections::BTreeMap<String, Option<ProtocolConfigValue>>;
35///
36///     /// Returns a feature by the string name or None if it doesn't exist
37///     pub fn lookup_feature(&self, value: String) -> Option<bool>;
38///
39///     /// Returns a map of all features to values
40///     pub fn feature_map(&self) -> std::collections::BTreeMap<String, bool>;
41/// ```
42#[proc_macro_derive(ProtocolConfigAccessors)]
43pub fn accessors_macro(input: TokenStream) -> TokenStream {
44    let ast = parse_macro_input!(input as DeriveInput);
45
46    let struct_name = &ast.ident;
47    let data = &ast.data;
48    let mut inner_types = vec![];
49
50    let tokens = match data {
51        Data::Struct(data_struct) => match &data_struct.fields {
52            // Operate on each field of the ProtocolConfig struct
53            Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
54                // Extract field name and type
55                let field_name = field.ident.as_ref().expect("Field must be named");
56                let field_type = &field.ty;
57                // Check if field is of type Option<T>
58                match field_type {
59                    Type::Path(type_path)
60                        if type_path
61                            .path
62                            .segments
63                            .last().is_some_and(|segment| segment.ident == "Option") =>
64                    {
65                        // Extract inner type T from Option<T>
66                        let inner_type = if let syn::PathArguments::AngleBracketed(
67                            angle_bracketed_generic_arguments,
68                        ) = &type_path.path.segments.last().unwrap().arguments
69                        {
70                            if let Some(syn::GenericArgument::Type(ty)) =
71                                angle_bracketed_generic_arguments.args.first()
72                            {
73                                ty.clone()
74                            } else {
75                                panic!("Expected a type argument.");
76                            }
77                        } else {
78                            panic!("Expected angle bracketed arguments.");
79                        };
80
81                        // Skip vec types - write your own accessor
82                        let is_vec = matches!(&inner_type, Type::Path(tp)
83                            if tp.path.segments.last()
84                                .is_some_and(|s| s.ident == "Vec"));
85                        if is_vec {
86                            return None;
87                        }
88
89                        let as_option_name = format!("{field_name}_as_option");
90                        let as_option_name: proc_macro2::TokenStream =
91                        as_option_name.parse().unwrap();
92                        let test_setter_name: proc_macro2::TokenStream =
93                            format!("set_{field_name}_for_testing").parse().unwrap();
94                        let test_un_setter_name: proc_macro2::TokenStream =
95                            format!("disable_{field_name}_for_testing").parse().unwrap();
96                        let test_setter_from_str_name: proc_macro2::TokenStream =
97                            format!("set_{field_name}_from_str_for_testing").parse().unwrap();
98
99                        let getter = quote! {
100                            // Derive the getter
101                            pub fn #field_name(&self) -> #inner_type {
102                                self.#field_name.expect(Self::CONSTANT_ERR_MSG)
103                            }
104
105                            pub fn #as_option_name(&self) -> #field_type {
106                                self.#field_name
107                            }
108                        };
109
110                        let test_setter = quote! {
111                            // Derive the setter
112                            pub fn #test_setter_name(&mut self, val: #inner_type) {
113                                self.#field_name = Some(val);
114                            }
115
116                            // Derive the setter from String
117                            pub fn #test_setter_from_str_name(&mut self, val: String) {
118                                use std::str::FromStr;
119                                self.#test_setter_name(#inner_type::from_str(&val).unwrap());
120                            }
121
122                            // Derive the un-setter
123                            pub fn #test_un_setter_name(&mut self) {
124                                self.#field_name = None;
125                            }
126                        };
127
128                        let value_setter = quote! {
129                            stringify!(#field_name) => self.#test_setter_from_str_name(val),
130                        };
131
132
133                        let value_lookup = quote! {
134                            stringify!(#field_name) => self.#field_name.map(|v| ProtocolConfigValue::#inner_type(v)),
135                        };
136
137                        let field_name_str = quote! {
138                            stringify!(#field_name)
139                        };
140
141                        // Track all the types seen
142                        if inner_types.contains(&inner_type) {
143                            None
144                        } else {
145                            inner_types.push(inner_type.clone());
146                            Some(quote! {
147                               #inner_type
148                            })
149                        };
150
151                        Some(((getter, (test_setter, value_setter)), (value_lookup, field_name_str)))
152                    }
153                    _ => None,
154                }
155            }),
156            _ => panic!("Only named fields are supported."),
157        },
158        _ => panic!("Only structs supported."),
159    };
160
161    #[allow(clippy::type_complexity)]
162    let ((getters, (test_setters, value_setters)), (value_lookup, field_names_str)): (
163        (Vec<_>, (Vec<_>, Vec<_>)),
164        (Vec<_>, Vec<_>),
165    ) = tokens.unzip();
166    let output = quote! {
167        // For each getter, expand it out into a function in the impl block
168        impl #struct_name {
169            const CONSTANT_ERR_MSG: &'static str = "protocol constant not present in current protocol version";
170            #(#getters)*
171
172            /// Lookup a config attribute by its string representation
173            pub fn lookup_attr(&self, value: String) -> Option<ProtocolConfigValue> {
174                match value.as_str() {
175                    #(#value_lookup)*
176                    _ => None,
177                }
178            }
179
180            /// Get a map of all config attribute from string representations
181            pub fn attr_map(&self) -> std::collections::BTreeMap<String, Option<ProtocolConfigValue>> {
182                vec![
183                    #(((#field_names_str).to_owned(), self.lookup_attr((#field_names_str).to_owned())),)*
184                    ].into_iter().collect()
185            }
186
187            /// Get the feature flags
188            pub fn lookup_feature(&self, value: String) -> Option<bool> {
189                self.feature_flags.lookup_attr(value)
190            }
191
192            pub fn feature_map(&self) -> std::collections::BTreeMap<String, bool> {
193                self.feature_flags.attr_map()
194            }
195        }
196
197        // For each attr, derive a setter from the raw value and from string repr
198        impl #struct_name {
199            #(#test_setters)*
200
201            pub fn set_attr_for_testing(&mut self, attr: String, val: String) {
202                match attr.as_str() {
203                    #(#value_setters)*
204                    _ => panic!("Attempting to set unknown attribute: {}", attr),
205                }
206            }
207        }
208
209        #[allow(non_camel_case_types)]
210        #[derive(Clone, Serialize, Debug, PartialEq, Deserialize, schemars::JsonSchema)]
211        pub enum ProtocolConfigValue {
212            #(#inner_types(#inner_types),)*
213        }
214
215        impl std::fmt::Display for ProtocolConfigValue {
216            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217                use std::fmt::Write;
218                let mut writer = String::new();
219                match self {
220                    #(
221                        ProtocolConfigValue::#inner_types(x) => {
222                            write!(writer, "{}", x)?;
223                        }
224                    )*
225                }
226                write!(f, "{}", writer)
227            }
228        }
229    };
230
231    TokenStream::from(output)
232}
233
234#[proc_macro_derive(ProtocolConfigOverride)]
235pub fn protocol_config_override_macro(input: TokenStream) -> TokenStream {
236    let ast = parse_macro_input!(input as DeriveInput);
237
238    // Create a new struct name by appending "Optional".
239    let struct_name = &ast.ident;
240    let optional_struct_name =
241        syn::Ident::new(&format!("{}Optional", struct_name), struct_name.span());
242
243    // Extract the fields from the struct
244    let fields = match &ast.data {
245        Data::Struct(data_struct) => match &data_struct.fields {
246            Fields::Named(fields_named) => &fields_named.named,
247            _ => panic!("ProtocolConfig must have named fields"),
248        },
249        _ => panic!("ProtocolConfig must be a struct"),
250    };
251
252    // Create new fields with types wrapped in Option.
253    let optional_fields = fields.iter().map(|field| {
254        let field_name = &field.ident;
255        let field_type = &field.ty;
256        quote! {
257            #field_name: Option<#field_type>
258        }
259    });
260
261    // Generate the function to update the original struct.
262    let update_fields = fields.iter().map(|field| {
263        let field_name = &field.ident;
264        quote! {
265            if let Some(value) = self.#field_name {
266                tracing::warn!(
267                    "ProtocolConfig field \"{}\" has been overridden with the value: {value:?}",
268                    stringify!(#field_name),
269                );
270                config.#field_name = value;
271            }
272        }
273    });
274
275    // Generate the new struct definition.
276    let output = quote! {
277        #[derive(serde::Deserialize, Debug)]
278        pub struct #optional_struct_name {
279            #(#optional_fields,)*
280        }
281
282        impl #optional_struct_name {
283            pub fn apply_to(self, config: &mut #struct_name) {
284                #(#update_fields)*
285            }
286        }
287    };
288
289    TokenStream::from(output)
290}
291
292#[proc_macro_derive(ProtocolConfigFeatureFlagsGetters)]
293pub fn feature_flag_getters_macro(input: TokenStream) -> TokenStream {
294    let ast = parse_macro_input!(input as DeriveInput);
295
296    let struct_name = &ast.ident;
297    let data = &ast.data;
298
299    let getters = match data {
300        Data::Struct(data_struct) => match &data_struct.fields {
301            // Operate on each field of the ProtocolConfig struct
302            Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
303                // Extract field name and type
304                let field_name = field.ident.as_ref().expect("Field must be named");
305                let field_type = &field.ty;
306                // Check if field is of type bool
307                match field_type {
308                    Type::Path(type_path)
309                        if type_path
310                            .path
311                            .segments
312                            .last()
313                            .is_some_and(|segment| segment.ident == "bool") =>
314                    {
315                        Some((
316                            quote! {
317                                // Derive the getter
318                                pub fn #field_name(&self) -> #field_type {
319                                    self.#field_name
320                                }
321                            },
322                            (
323                                quote! {
324                                    stringify!(#field_name) => Some(self.#field_name),
325                                },
326                                quote! {
327                                    stringify!(#field_name)
328                                },
329                            ),
330                        ))
331                    }
332                    _ => None,
333                }
334            }),
335            _ => panic!("Only named fields are supported."),
336        },
337        _ => panic!("Only structs supported."),
338    };
339
340    let (by_fn_getters, (string_name_getters, field_names)): (Vec<_>, (Vec<_>, Vec<_>)) =
341        getters.unzip();
342
343    let output = quote! {
344        // For each getter, expand it out into a function in the impl block
345        impl #struct_name {
346            #(#by_fn_getters)*
347
348            /// Lookup a feature flag by its string representation
349            pub fn lookup_attr(&self, value: String) -> Option<bool> {
350                match value.as_str() {
351                    #(#string_name_getters)*
352                    _ => None,
353                }
354            }
355
356            /// Get a map of all feature flags from string representations
357            pub fn attr_map(&self) -> std::collections::BTreeMap<String, bool> {
358                vec![
359                    // Okay to unwrap since we added all above
360                    #(((#field_names).to_owned(), self.lookup_attr((#field_names).to_owned()).unwrap()),)*
361                    ].into_iter().collect()
362            }
363        }
364    };
365
366    TokenStream::from(output)
367}