sui_protocol_config_macros/
lib.rs1extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{Data, DeriveInput, Fields, Type, parse_macro_input};
9
10#[proc_macro_derive(ProtocolConfigAccessors)]
43pub fn accessors_macro(input: TokenStream) -> TokenStream {
44    let ast = parse_macro_input!(input as DeriveInput);
45
46    let struct_name = &ast.ident;
47    let data = &ast.data;
48    let mut inner_types = vec![];
49
50    let tokens = match data {
51        Data::Struct(data_struct) => match &data_struct.fields {
52            Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
54                let field_name = field.ident.as_ref().expect("Field must be named");
56                let field_type = &field.ty;
57                match field_type {
59                    Type::Path(type_path)
60                        if type_path
61                            .path
62                            .segments
63                            .last().is_some_and(|segment| segment.ident == "Option") =>
64                    {
65                        let inner_type = if let syn::PathArguments::AngleBracketed(
67                            angle_bracketed_generic_arguments,
68                        ) = &type_path.path.segments.last().unwrap().arguments
69                        {
70                            if let Some(syn::GenericArgument::Type(ty)) =
71                                angle_bracketed_generic_arguments.args.first()
72                            {
73                                ty.clone()
74                            } else {
75                                panic!("Expected a type argument.");
76                            }
77                        } else {
78                            panic!("Expected angle bracketed arguments.");
79                        };
80
81                        let as_option_name = format!("{field_name}_as_option");
82                        let as_option_name: proc_macro2::TokenStream =
83                        as_option_name.parse().unwrap();
84                        let test_setter_name: proc_macro2::TokenStream =
85                            format!("set_{field_name}_for_testing").parse().unwrap();
86                        let test_un_setter_name: proc_macro2::TokenStream =
87                            format!("disable_{field_name}_for_testing").parse().unwrap();
88                        let test_setter_from_str_name: proc_macro2::TokenStream =
89                            format!("set_{field_name}_from_str_for_testing").parse().unwrap();
90
91                        let getter = quote! {
92                            pub fn #field_name(&self) -> #inner_type {
94                                self.#field_name.expect(Self::CONSTANT_ERR_MSG)
95                            }
96
97                            pub fn #as_option_name(&self) -> #field_type {
98                                self.#field_name
99                            }
100                        };
101
102                        let test_setter = quote! {
103                            pub fn #test_setter_name(&mut self, val: #inner_type) {
105                                self.#field_name = Some(val);
106                            }
107
108                            pub fn #test_setter_from_str_name(&mut self, val: String) {
110                                use std::str::FromStr;
111                                self.#test_setter_name(#inner_type::from_str(&val).unwrap());
112                            }
113
114                            pub fn #test_un_setter_name(&mut self) {
116                                self.#field_name = None;
117                            }
118                        };
119
120                        let value_setter = quote! {
121                            stringify!(#field_name) => self.#test_setter_from_str_name(val),
122                        };
123
124
125                        let value_lookup = quote! {
126                            stringify!(#field_name) => self.#field_name.map(|v| ProtocolConfigValue::#inner_type(v)),
127                        };
128
129                        let field_name_str = quote! {
130                            stringify!(#field_name)
131                        };
132
133                        if inner_types.contains(&inner_type) {
135                            None
136                        } else {
137                            inner_types.push(inner_type.clone());
138                            Some(quote! {
139                               #inner_type
140                            })
141                        };
142
143                        Some(((getter, (test_setter, value_setter)), (value_lookup, field_name_str)))
144                    }
145                    _ => None,
146                }
147            }),
148            _ => panic!("Only named fields are supported."),
149        },
150        _ => panic!("Only structs supported."),
151    };
152
153    #[allow(clippy::type_complexity)]
154    let ((getters, (test_setters, value_setters)), (value_lookup, field_names_str)): (
155        (Vec<_>, (Vec<_>, Vec<_>)),
156        (Vec<_>, Vec<_>),
157    ) = tokens.unzip();
158    let output = quote! {
159        impl #struct_name {
161            const CONSTANT_ERR_MSG: &'static str = "protocol constant not present in current protocol version";
162            #(#getters)*
163
164            pub fn lookup_attr(&self, value: String) -> Option<ProtocolConfigValue> {
166                match value.as_str() {
167                    #(#value_lookup)*
168                    _ => None,
169                }
170            }
171
172            pub fn attr_map(&self) -> std::collections::BTreeMap<String, Option<ProtocolConfigValue>> {
174                vec![
175                    #(((#field_names_str).to_owned(), self.lookup_attr((#field_names_str).to_owned())),)*
176                    ].into_iter().collect()
177            }
178
179            pub fn lookup_feature(&self, value: String) -> Option<bool> {
181                self.feature_flags.lookup_attr(value)
182            }
183
184            pub fn feature_map(&self) -> std::collections::BTreeMap<String, bool> {
185                self.feature_flags.attr_map()
186            }
187        }
188
189        impl #struct_name {
191            #(#test_setters)*
192
193            pub fn set_attr_for_testing(&mut self, attr: String, val: String) {
194                match attr.as_str() {
195                    #(#value_setters)*
196                    _ => panic!("Attempting to set unknown attribute: {}", attr),
197                }
198            }
199        }
200
201        #[allow(non_camel_case_types)]
202        #[derive(Clone, Serialize, Debug, PartialEq, Deserialize, schemars::JsonSchema)]
203        pub enum ProtocolConfigValue {
204            #(#inner_types(#inner_types),)*
205        }
206
207        impl std::fmt::Display for ProtocolConfigValue {
208            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209                use std::fmt::Write;
210                let mut writer = String::new();
211                match self {
212                    #(
213                        ProtocolConfigValue::#inner_types(x) => {
214                            write!(writer, "{}", x)?;
215                        }
216                    )*
217                }
218                write!(f, "{}", writer)
219            }
220        }
221    };
222
223    TokenStream::from(output)
224}
225
226#[proc_macro_derive(ProtocolConfigOverride)]
227pub fn protocol_config_override_macro(input: TokenStream) -> TokenStream {
228    let ast = parse_macro_input!(input as DeriveInput);
229
230    let struct_name = &ast.ident;
232    let optional_struct_name =
233        syn::Ident::new(&format!("{}Optional", struct_name), struct_name.span());
234
235    let fields = match &ast.data {
237        Data::Struct(data_struct) => match &data_struct.fields {
238            Fields::Named(fields_named) => &fields_named.named,
239            _ => panic!("ProtocolConfig must have named fields"),
240        },
241        _ => panic!("ProtocolConfig must be a struct"),
242    };
243
244    let optional_fields = fields.iter().map(|field| {
246        let field_name = &field.ident;
247        let field_type = &field.ty;
248        quote! {
249            #field_name: Option<#field_type>
250        }
251    });
252
253    let update_fields = fields.iter().map(|field| {
255        let field_name = &field.ident;
256        quote! {
257            if let Some(value) = self.#field_name {
258                tracing::warn!(
259                    "ProtocolConfig field \"{}\" has been overridden with the value: {value:?}",
260                    stringify!(#field_name),
261                );
262                config.#field_name = value;
263            }
264        }
265    });
266
267    let output = quote! {
269        #[derive(serde::Deserialize, Debug)]
270        pub struct #optional_struct_name {
271            #(#optional_fields,)*
272        }
273
274        impl #optional_struct_name {
275            pub fn apply_to(self, config: &mut #struct_name) {
276                #(#update_fields)*
277            }
278        }
279    };
280
281    TokenStream::from(output)
282}
283
284#[proc_macro_derive(ProtocolConfigFeatureFlagsGetters)]
285pub fn feature_flag_getters_macro(input: TokenStream) -> TokenStream {
286    let ast = parse_macro_input!(input as DeriveInput);
287
288    let struct_name = &ast.ident;
289    let data = &ast.data;
290
291    let getters = match data {
292        Data::Struct(data_struct) => match &data_struct.fields {
293            Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
295                let field_name = field.ident.as_ref().expect("Field must be named");
297                let field_type = &field.ty;
298                match field_type {
300                    Type::Path(type_path)
301                        if type_path
302                            .path
303                            .segments
304                            .last()
305                            .is_some_and(|segment| segment.ident == "bool") =>
306                    {
307                        Some((
308                            quote! {
309                                pub fn #field_name(&self) -> #field_type {
311                                    self.#field_name
312                                }
313                            },
314                            (
315                                quote! {
316                                    stringify!(#field_name) => Some(self.#field_name),
317                                },
318                                quote! {
319                                    stringify!(#field_name)
320                                },
321                            ),
322                        ))
323                    }
324                    _ => None,
325                }
326            }),
327            _ => panic!("Only named fields are supported."),
328        },
329        _ => panic!("Only structs supported."),
330    };
331
332    let (by_fn_getters, (string_name_getters, field_names)): (Vec<_>, (Vec<_>, Vec<_>)) =
333        getters.unzip();
334
335    let output = quote! {
336        impl #struct_name {
338            #(#by_fn_getters)*
339
340            pub fn lookup_attr(&self, value: String) -> Option<bool> {
342                match value.as_str() {
343                    #(#string_name_getters)*
344                    _ => None,
345                }
346            }
347
348            pub fn attr_map(&self) -> std::collections::BTreeMap<String, bool> {
350                vec![
351                    #(((#field_names).to_owned(), self.lookup_attr((#field_names).to_owned()).unwrap()),)*
353                    ].into_iter().collect()
354            }
355        }
356    };
357
358    TokenStream::from(output)
359}