sui_protocol_config_macros/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{Data, DeriveInput, Fields, Type, parse_macro_input};
9
10/// This proc macro generates getters, attribute lookup, etc for protocol config fields of type `Option<T>`
11/// and for the feature flags
12/// Example for a field: `new_constant: Option<u64>`, and for feature flags `feature: bool`, we derive
13/// ```rust,ignore
14///     /// Returns the value of the field if exists at the given version, otherise panic
15///     pub fn new_constant(&self) -> u64 {
16///         self.new_constant.expect(Self::CONSTANT_ERR_MSG)
17///     }
18///     /// Returns the value of the field if exists at the given version, otherise None.
19///     pub fn new_constant_as_option(&self) -> Option<u64> {
20///         self.new_constant
21///     }
22///     // We auto derive an enum such that the variants are all the types of the fields
23///     pub enum ProtocolConfigValue {
24///        u32(u32),
25///        u64(u64),
26///        ..............
27///     }
28///     // This enum is used to return field values so that the type is also encoded in the response
29///
30///     /// Returns the value of the field if exists at the given version, otherise None
31///     pub fn lookup_attr(&self, value: String) -> Option<ProtocolConfigValue>;
32///
33///     /// Returns a map of all configs to values
34///     pub fn attr_map(&self) -> std::collections::BTreeMap<String, Option<ProtocolConfigValue>>;
35///
36///     /// Returns a feature by the string name or None if it doesn't exist
37///     pub fn lookup_feature(&self, value: String) -> Option<bool>;
38///
39///     /// Returns a map of all features to values
40///     pub fn feature_map(&self) -> std::collections::BTreeMap<String, bool>;
41/// ```
42///
43/// Every field (scalar and non-scalar) is also emitted into a typed
44/// `render<F: Format>(&self, meter: &mut impl Meter) -> Result<BTreeMap<String, F>, MeterError>`
45/// method, where each value is produced via `mysten_common::rpc_format::ToFormat`. Fields that
46/// aren't configured at the current protocol version render as `F::null(...)` rather than being
47/// absent from the map, so the keyset is stable across protocol versions. This is the path RPC
48/// code should use to expose protocol config to clients; the same call site can target
49/// `serde_json::Value`, `prost_types::Value`, or any other `Format` impl by choosing `F`.
50///
51/// Scalar (`u16`/`u32`/`u64`/`bool`) fields continue to feed `ProtocolConfigValue` / `attr_map`
52/// for back-compat with existing consumers. Non-scalar fields appear only in `render`. Add
53/// `#[skip_accessor]` to keep a field internal and out of every generated surface.
54#[proc_macro_derive(ProtocolConfigAccessors, attributes(skip_accessor))]
55pub fn accessors_macro(input: TokenStream) -> TokenStream {
56    let ast = parse_macro_input!(input as DeriveInput);
57
58    let struct_name = &ast.ident;
59    let data = &ast.data;
60
61    let fields: Vec<AccessorField> = match data {
62        Data::Struct(data_struct) => match &data_struct.fields {
63            Fields::Named(fields_named) => fields_named
64                .named
65                .iter()
66                .filter_map(parse_accessor_field)
67                .collect(),
68            _ => panic!("Only named fields are supported."),
69        },
70        _ => panic!("Only structs supported."),
71    };
72
73    let expanded: Vec<ExpandedField> = fields.iter().map(expand_field).collect();
74
75    let accessors = expanded.iter().map(|e| &e.accessor);
76    let setters = expanded.iter().map(|e| &e.setter);
77    let render_arms = expanded.iter().map(|e| &e.render_arm);
78
79    // Scalar-only collections — driven by the optional `ScalarExtras` extension on each field.
80    let scalar_extras: Vec<&ScalarExtras> = expanded
81        .iter()
82        .filter_map(|e| e.scalar_extras.as_ref())
83        .collect();
84    let scalar_value_setter_arms = scalar_extras.iter().map(|s| &s.value_setter_arm);
85    let scalar_lookup_arms = scalar_extras.iter().map(|s| &s.lookup_arm);
86    let scalar_field_names = scalar_extras.iter().map(|s| &s.field_name_str);
87
88    // Multiple scalar fields of the same primitive type all share a single
89    // `ProtocolConfigValue` variant (e.g. every `u64` field maps to `ProtocolConfigValue::u64`).
90    let mut variant_decls = Vec::new();
91    let mut display_variants = Vec::new();
92    let mut seen = std::collections::HashSet::new();
93    for s in &scalar_extras {
94        if !seen.insert(s.variant_ident.to_string()) {
95            continue;
96        }
97        let ident = &s.variant_ident;
98        let inner = &s.inner_type;
99        variant_decls.push(quote! { #ident(#inner) });
100        display_variants.push(s.variant_ident.clone());
101    }
102
103    let output = quote! {
104        impl #struct_name {
105            const CONSTANT_ERR_MSG: &'static str = "protocol constant not present in current protocol version";
106            #(#accessors)*
107
108            /// Lookup a scalar config attribute by its string representation.
109            pub fn lookup_attr(&self, value: String) -> Option<ProtocolConfigValue> {
110                match value.as_str() {
111                    #(#scalar_lookup_arms)*
112                    _ => None,
113                }
114            }
115
116            /// Get a map of all scalar config attributes from string representations.
117            ///
118            /// Non-scalar (e.g. list-typed) fields aren't represented here — use
119            /// `Self::render` for a typed view that includes every field.
120            pub fn attr_map(&self) -> std::collections::BTreeMap<String, Option<ProtocolConfigValue>> {
121                vec![
122                    #(((#scalar_field_names).to_owned(), self.lookup_attr((#scalar_field_names).to_owned())),)*
123                    ].into_iter().collect()
124            }
125
126            /// Render every protocol-config attribute into the chosen `Format`.
127            ///
128            /// Fields that aren't configured at this protocol version render as `F::null(...)`,
129            /// so the keyset is stable across versions and callers can distinguish "unknown
130            /// key" from "present but unset".
131            pub fn render<F>(
132                &self,
133                meter: &mut impl ::mysten_common::rpc_format::Meter,
134            ) -> ::std::result::Result<
135                std::collections::BTreeMap<String, F>,
136                ::mysten_common::rpc_format::MeterError,
137            >
138            where
139                F: ::mysten_common::rpc_format::Format,
140            {
141                let mut map = std::collections::BTreeMap::new();
142                #(#render_arms)*
143                Ok(map)
144            }
145
146            /// Get the feature flags
147            pub fn lookup_feature(&self, value: String) -> Option<bool> {
148                self.feature_flags.lookup_attr(value)
149            }
150
151            pub fn feature_map(&self) -> std::collections::BTreeMap<String, bool> {
152                self.feature_flags.attr_map()
153            }
154        }
155
156        impl #struct_name {
157            #(#setters)*
158
159            pub fn set_attr_for_testing(&mut self, attr: String, val: String) {
160                match attr.as_str() {
161                    #(#scalar_value_setter_arms)*
162                    _ => panic!(
163                        "Attempting to set unknown or non-string-settable attribute: {}",
164                        attr,
165                    ),
166                }
167            }
168        }
169
170        #[allow(non_camel_case_types)]
171        #[derive(Clone, Serialize, Debug, PartialEq, Deserialize, schemars::JsonSchema)]
172        pub enum ProtocolConfigValue {
173            #(#variant_decls,)*
174        }
175
176        impl std::fmt::Display for ProtocolConfigValue {
177            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178                use std::fmt::Write;
179                let mut writer = String::new();
180                match self {
181                    #(
182                        ProtocolConfigValue::#display_variants(x) => {
183                            write!(writer, "{}", x)?;
184                        }
185                    )*
186                }
187                write!(f, "{}", writer)
188            }
189        }
190    };
191
192    TokenStream::from(output)
193}
194
195/// Token streams emitted for a single `ProtocolConfig` field. Every field contributes an
196/// accessor, a setter, and a `render` arm; scalar fields additionally populate
197/// [`ScalarExtras`] for the `ProtocolConfigValue` / `attr_map` / `set_attr_for_testing` paths.
198struct ExpandedField {
199    /// `fn field_name(&self) -> T` (scalars only) + `fn field_name_as_option(&self) -> Option<T>`.
200    accessor: proc_macro2::TokenStream,
201    /// `set_x_for_testing` + `disable_x_for_testing` (always) and the `from_str` variant
202    /// (scalars only — non-scalars don't generally implement `FromStr`).
203    setter: proc_macro2::TokenStream,
204    /// One per-field block of `render` that calls `ToFormat::to_format` and inserts the result.
205    render_arm: proc_macro2::TokenStream,
206    /// Populated only when the field is a scalar (i.e. inner type is a single bare identifier
207    /// usable as a `ProtocolConfigValue` variant ident — `u16`/`u32`/`u64`/`bool`).
208    scalar_extras: Option<ScalarExtras>,
209}
210
211/// The extra tokens scalar fields contribute on top of the always-emitted accessor/setter/render
212/// pieces. Non-scalar fields don't participate in `ProtocolConfigValue` at all.
213struct ScalarExtras {
214    /// `stringify!(field_name) => self.set_x_from_str_for_testing(val),` — match arm for
215    /// `set_attr_for_testing`.
216    value_setter_arm: proc_macro2::TokenStream,
217    /// `stringify!(field_name) => self.field_name.map(ProtocolConfigValue::Variant),` — match
218    /// arm for `lookup_attr`.
219    lookup_arm: proc_macro2::TokenStream,
220    /// `stringify!(field_name)` — used to assemble `attr_map`.
221    field_name_str: proc_macro2::TokenStream,
222    /// Variant identifier used in `ProtocolConfigValue`.
223    variant_ident: syn::Ident,
224    /// Inner `T` of `Option<T>` — pairs with `variant_ident` when declaring the enum variant.
225    inner_type: syn::Type,
226}
227
228fn expand_field(f: &AccessorField) -> ExpandedField {
229    let field_name = &f.field_name;
230    let field_type = &f.field_type;
231    let inner_type = &f.inner_type;
232    let as_option_name: proc_macro2::TokenStream =
233        format!("{field_name}_as_option").parse().unwrap();
234    let test_setter_name: proc_macro2::TokenStream =
235        format!("set_{field_name}_for_testing").parse().unwrap();
236    let test_un_setter_name: proc_macro2::TokenStream =
237        format!("disable_{field_name}_for_testing").parse().unwrap();
238
239    let render_arm = quote! {
240        {
241            let value = match self.#field_name.as_ref() {
242                Some(v) => <_ as ::mysten_common::rpc_format::ToFormat>::to_format::<F, _>(v, meter)?,
243                None => <F as ::mysten_common::rpc_format::Format>::null(meter)?,
244            };
245            map.insert(stringify!(#field_name).to_owned(), value);
246        }
247    };
248
249    // `_as_option` is always emitted. The plain getter and the string-based setter only make
250    // sense for scalars (the plain getter unwraps to a `Copy` primitive; the `from_str` setter
251    // requires `FromStr` on the inner type). Non-scalars provide custom getters next to the
252    // field definition when they want one (e.g. borrowed-slice ergonomics).
253    let as_option_emit = quote! {
254        pub fn #as_option_name(&self) -> #field_type {
255            self.#field_name.clone()
256        }
257    };
258    let common_setters = quote! {
259        pub fn #test_setter_name(&mut self, val: #inner_type) {
260            self.#field_name = Some(val);
261        }
262
263        pub fn #test_un_setter_name(&mut self) {
264            self.#field_name = None;
265        }
266    };
267
268    match &f.scalar_variant {
269        Some(variant_ident) => {
270            let test_setter_from_str_name: proc_macro2::TokenStream =
271                format!("set_{field_name}_from_str_for_testing")
272                    .parse()
273                    .unwrap();
274            ExpandedField {
275                accessor: quote! {
276                    pub fn #field_name(&self) -> #inner_type {
277                        self.#field_name.expect(Self::CONSTANT_ERR_MSG)
278                    }
279
280                    pub fn #as_option_name(&self) -> #field_type {
281                        self.#field_name
282                    }
283                },
284                setter: quote! {
285                    #common_setters
286
287                    pub fn #test_setter_from_str_name(&mut self, val: String) {
288                        use std::str::FromStr;
289                        self.#test_setter_name(#inner_type::from_str(&val).unwrap());
290                    }
291                },
292                render_arm,
293                scalar_extras: Some(ScalarExtras {
294                    value_setter_arm: quote! {
295                        stringify!(#field_name) => self.#test_setter_from_str_name(val),
296                    },
297                    lookup_arm: quote! {
298                        stringify!(#field_name) => self
299                            .#field_name
300                            .map(ProtocolConfigValue::#variant_ident),
301                    },
302                    field_name_str: quote! { stringify!(#field_name) },
303                    variant_ident: variant_ident.clone(),
304                    inner_type: inner_type.clone(),
305                }),
306            }
307        }
308        None => ExpandedField {
309            accessor: as_option_emit,
310            setter: common_setters,
311            render_arm,
312            scalar_extras: None,
313        },
314    }
315}
316
317/// Per-field metadata extracted from a `ProtocolConfig` field while expanding the
318/// `ProtocolConfigAccessors` derive.
319struct AccessorField {
320    /// The `#field_name` identifier — used both for accessor method names and as the string key
321    /// in the generated maps.
322    field_name: syn::Ident,
323    /// The full `Option<T>` type as written in the struct.
324    field_type: syn::Type,
325    /// The inner `T` extracted from `Option<T>`.
326    inner_type: syn::Type,
327    /// `Some(ident)` when the inner type is a single bare identifier usable directly as a
328    /// `ProtocolConfigValue` variant ident (`u16`/`u32`/`u64`/`bool`). `None` for non-scalar
329    /// fields, which never appear in `ProtocolConfigValue`.
330    scalar_variant: Option<syn::Ident>,
331}
332
333fn parse_accessor_field(field: &syn::Field) -> Option<AccessorField> {
334    let field_name = field.ident.clone().expect("Field must be named");
335
336    let skip_accessor = field
337        .attrs
338        .iter()
339        .any(|attr| attr.path.is_ident("skip_accessor"));
340    if skip_accessor {
341        return None;
342    }
343
344    let field_type = &field.ty;
345    let type_path = match field_type {
346        Type::Path(p) => p,
347        _ => return None,
348    };
349    let last_segment = type_path.path.segments.last()?;
350    if last_segment.ident != "Option" {
351        return None;
352    }
353    let inner_type = match &last_segment.arguments {
354        syn::PathArguments::AngleBracketed(args) => match args.args.first()? {
355            syn::GenericArgument::Type(ty) => ty.clone(),
356            _ => panic!("Expected a type argument inside Option<...> for `{field_name}`"),
357        },
358        _ => panic!("Expected angle bracketed arguments inside Option<...> for `{field_name}`"),
359    };
360
361    let scalar_variant = inferred_scalar_variant_ident(&inner_type);
362
363    Some(AccessorField {
364        field_name,
365        field_type: field_type.clone(),
366        inner_type,
367        scalar_variant,
368    })
369}
370
371fn inferred_scalar_variant_ident(ty: &syn::Type) -> Option<syn::Ident> {
372    const SCALAR_PRIMITIVES: &[&str] = &["bool", "u8", "u16", "u32", "u64", "u128", "usize"];
373
374    let Type::Path(path) = ty else { return None };
375    if path.qself.is_some() {
376        return None;
377    }
378    if path.path.segments.len() != 1 {
379        return None;
380    }
381    let segment = path.path.segments.first()?;
382    if !matches!(segment.arguments, syn::PathArguments::None) {
383        return None;
384    }
385    if !SCALAR_PRIMITIVES.iter().any(|p| segment.ident == *p) {
386        return None;
387    }
388    Some(segment.ident.clone())
389}
390
391#[proc_macro_derive(ProtocolConfigOverride)]
392pub fn protocol_config_override_macro(input: TokenStream) -> TokenStream {
393    let ast = parse_macro_input!(input as DeriveInput);
394
395    // Create a new struct name by appending "Optional".
396    let struct_name = &ast.ident;
397    let optional_struct_name =
398        syn::Ident::new(&format!("{}Optional", struct_name), struct_name.span());
399
400    // Extract the fields from the struct
401    let fields = match &ast.data {
402        Data::Struct(data_struct) => match &data_struct.fields {
403            Fields::Named(fields_named) => &fields_named.named,
404            _ => panic!("ProtocolConfig must have named fields"),
405        },
406        _ => panic!("ProtocolConfig must be a struct"),
407    };
408
409    // Create new fields with types wrapped in Option.
410    let optional_fields = fields.iter().map(|field| {
411        let field_name = &field.ident;
412        let field_type = &field.ty;
413        quote! {
414            #field_name: Option<#field_type>
415        }
416    });
417
418    // Generate the function to update the original struct.
419    let update_fields = fields.iter().map(|field| {
420        let field_name = &field.ident;
421        quote! {
422            if let Some(value) = self.#field_name {
423                tracing::warn!(
424                    "ProtocolConfig field \"{}\" has been overridden with the value: {value:?}",
425                    stringify!(#field_name),
426                );
427                config.#field_name = value;
428            }
429        }
430    });
431
432    // Generate the new struct definition.
433    let output = quote! {
434        #[derive(serde::Deserialize, Debug)]
435        pub struct #optional_struct_name {
436            #(#optional_fields,)*
437        }
438
439        impl #optional_struct_name {
440            pub fn apply_to(self, config: &mut #struct_name) {
441                #(#update_fields)*
442            }
443        }
444    };
445
446    TokenStream::from(output)
447}
448
449#[proc_macro_derive(ProtocolConfigFeatureFlagsGetters, attributes(skip_accessor))]
450pub fn feature_flag_getters_macro(input: TokenStream) -> TokenStream {
451    let ast = parse_macro_input!(input as DeriveInput);
452
453    let struct_name = &ast.ident;
454    let data = &ast.data;
455
456    let getters = match data {
457        Data::Struct(data_struct) => match &data_struct.fields {
458            // Operate on each field of the ProtocolConfig struct
459            Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
460                // Extract field name and type
461                let field_name = field.ident.as_ref().expect("Field must be named");
462                let field_type = &field.ty;
463                let skip_accessor = field
464                    .attrs
465                    .iter()
466                    .any(|attr| attr.path.is_ident("skip_accessor"));
467                if skip_accessor {
468                    return None;
469                }
470                // Check if field is of type bool
471                match field_type {
472                    Type::Path(type_path)
473                        if type_path
474                            .path
475                            .segments
476                            .last()
477                            .is_some_and(|segment| segment.ident == "bool") =>
478                    {
479                        Some((
480                            quote! {
481                                // Derive the getter
482                                pub fn #field_name(&self) -> #field_type {
483                                    self.#field_name
484                                }
485                            },
486                            (
487                                quote! {
488                                    stringify!(#field_name) => Some(self.#field_name),
489                                },
490                                quote! {
491                                    stringify!(#field_name)
492                                },
493                            ),
494                        ))
495                    }
496                    _ => None,
497                }
498            }),
499            _ => panic!("Only named fields are supported."),
500        },
501        _ => panic!("Only structs supported."),
502    };
503
504    let (by_fn_getters, (string_name_getters, field_names)): (Vec<_>, (Vec<_>, Vec<_>)) =
505        getters.unzip();
506
507    let output = quote! {
508        // For each getter, expand it out into a function in the impl block
509        impl #struct_name {
510            #(#by_fn_getters)*
511
512            /// Lookup a feature flag by its string representation
513            pub fn lookup_attr(&self, value: String) -> Option<bool> {
514                match value.as_str() {
515                    #(#string_name_getters)*
516                    _ => None,
517                }
518            }
519
520            /// Get a map of all feature flags from string representations
521            pub fn attr_map(&self) -> std::collections::BTreeMap<String, bool> {
522                vec![
523                    // Okay to unwrap since we added all above
524                    #(((#field_names).to_owned(), self.lookup_attr((#field_names).to_owned()).unwrap()),)*
525                    ].into_iter().collect()
526            }
527        }
528    };
529
530    TokenStream::from(output)
531}