sui_protocol_config_macros/
lib.rs1extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{Data, DeriveInput, Fields, Type, parse_macro_input};
9
10#[proc_macro_derive(ProtocolConfigAccessors, attributes(skip_accessor))]
55pub fn accessors_macro(input: TokenStream) -> TokenStream {
56 let ast = parse_macro_input!(input as DeriveInput);
57
58 let struct_name = &ast.ident;
59 let data = &ast.data;
60
61 let fields: Vec<AccessorField> = match data {
62 Data::Struct(data_struct) => match &data_struct.fields {
63 Fields::Named(fields_named) => fields_named
64 .named
65 .iter()
66 .filter_map(parse_accessor_field)
67 .collect(),
68 _ => panic!("Only named fields are supported."),
69 },
70 _ => panic!("Only structs supported."),
71 };
72
73 let expanded: Vec<ExpandedField> = fields.iter().map(expand_field).collect();
74
75 let accessors = expanded.iter().map(|e| &e.accessor);
76 let setters = expanded.iter().map(|e| &e.setter);
77 let render_arms = expanded.iter().map(|e| &e.render_arm);
78
79 let scalar_extras: Vec<&ScalarExtras> = expanded
81 .iter()
82 .filter_map(|e| e.scalar_extras.as_ref())
83 .collect();
84 let scalar_value_setter_arms = scalar_extras.iter().map(|s| &s.value_setter_arm);
85 let scalar_lookup_arms = scalar_extras.iter().map(|s| &s.lookup_arm);
86 let scalar_field_names = scalar_extras.iter().map(|s| &s.field_name_str);
87
88 let mut variant_decls = Vec::new();
91 let mut display_variants = Vec::new();
92 let mut seen = std::collections::HashSet::new();
93 for s in &scalar_extras {
94 if !seen.insert(s.variant_ident.to_string()) {
95 continue;
96 }
97 let ident = &s.variant_ident;
98 let inner = &s.inner_type;
99 variant_decls.push(quote! { #ident(#inner) });
100 display_variants.push(s.variant_ident.clone());
101 }
102
103 let output = quote! {
104 impl #struct_name {
105 const CONSTANT_ERR_MSG: &'static str = "protocol constant not present in current protocol version";
106 #(#accessors)*
107
108 pub fn lookup_attr(&self, value: String) -> Option<ProtocolConfigValue> {
110 match value.as_str() {
111 #(#scalar_lookup_arms)*
112 _ => None,
113 }
114 }
115
116 pub fn attr_map(&self) -> std::collections::BTreeMap<String, Option<ProtocolConfigValue>> {
121 vec![
122 #(((#scalar_field_names).to_owned(), self.lookup_attr((#scalar_field_names).to_owned())),)*
123 ].into_iter().collect()
124 }
125
126 pub fn render<F>(
132 &self,
133 meter: &mut impl ::mysten_common::rpc_format::Meter,
134 ) -> ::std::result::Result<
135 std::collections::BTreeMap<String, F>,
136 ::mysten_common::rpc_format::MeterError,
137 >
138 where
139 F: ::mysten_common::rpc_format::Format,
140 {
141 let mut map = std::collections::BTreeMap::new();
142 #(#render_arms)*
143 Ok(map)
144 }
145
146 pub fn lookup_feature(&self, value: String) -> Option<bool> {
148 self.feature_flags.lookup_attr(value)
149 }
150
151 pub fn feature_map(&self) -> std::collections::BTreeMap<String, bool> {
152 self.feature_flags.attr_map()
153 }
154 }
155
156 impl #struct_name {
157 #(#setters)*
158
159 pub fn set_attr_for_testing(&mut self, attr: String, val: String) {
160 match attr.as_str() {
161 #(#scalar_value_setter_arms)*
162 _ => panic!(
163 "Attempting to set unknown or non-string-settable attribute: {}",
164 attr,
165 ),
166 }
167 }
168 }
169
170 #[allow(non_camel_case_types)]
171 #[derive(Clone, Serialize, Debug, PartialEq, Deserialize, schemars::JsonSchema)]
172 pub enum ProtocolConfigValue {
173 #(#variant_decls,)*
174 }
175
176 impl std::fmt::Display for ProtocolConfigValue {
177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178 use std::fmt::Write;
179 let mut writer = String::new();
180 match self {
181 #(
182 ProtocolConfigValue::#display_variants(x) => {
183 write!(writer, "{}", x)?;
184 }
185 )*
186 }
187 write!(f, "{}", writer)
188 }
189 }
190 };
191
192 TokenStream::from(output)
193}
194
195struct ExpandedField {
199 accessor: proc_macro2::TokenStream,
201 setter: proc_macro2::TokenStream,
204 render_arm: proc_macro2::TokenStream,
206 scalar_extras: Option<ScalarExtras>,
209}
210
211struct ScalarExtras {
214 value_setter_arm: proc_macro2::TokenStream,
217 lookup_arm: proc_macro2::TokenStream,
220 field_name_str: proc_macro2::TokenStream,
222 variant_ident: syn::Ident,
224 inner_type: syn::Type,
226}
227
228fn expand_field(f: &AccessorField) -> ExpandedField {
229 let field_name = &f.field_name;
230 let field_type = &f.field_type;
231 let inner_type = &f.inner_type;
232 let as_option_name: proc_macro2::TokenStream =
233 format!("{field_name}_as_option").parse().unwrap();
234 let test_setter_name: proc_macro2::TokenStream =
235 format!("set_{field_name}_for_testing").parse().unwrap();
236 let test_un_setter_name: proc_macro2::TokenStream =
237 format!("disable_{field_name}_for_testing").parse().unwrap();
238
239 let render_arm = quote! {
240 {
241 let value = match self.#field_name.as_ref() {
242 Some(v) => <_ as ::mysten_common::rpc_format::ToFormat>::to_format::<F, _>(v, meter)?,
243 None => <F as ::mysten_common::rpc_format::Format>::null(meter)?,
244 };
245 map.insert(stringify!(#field_name).to_owned(), value);
246 }
247 };
248
249 let as_option_emit = quote! {
254 pub fn #as_option_name(&self) -> #field_type {
255 self.#field_name.clone()
256 }
257 };
258 let common_setters = quote! {
259 pub fn #test_setter_name(&mut self, val: #inner_type) {
260 self.#field_name = Some(val);
261 }
262
263 pub fn #test_un_setter_name(&mut self) {
264 self.#field_name = None;
265 }
266 };
267
268 match &f.scalar_variant {
269 Some(variant_ident) => {
270 let test_setter_from_str_name: proc_macro2::TokenStream =
271 format!("set_{field_name}_from_str_for_testing")
272 .parse()
273 .unwrap();
274 ExpandedField {
275 accessor: quote! {
276 pub fn #field_name(&self) -> #inner_type {
277 self.#field_name.expect(Self::CONSTANT_ERR_MSG)
278 }
279
280 pub fn #as_option_name(&self) -> #field_type {
281 self.#field_name
282 }
283 },
284 setter: quote! {
285 #common_setters
286
287 pub fn #test_setter_from_str_name(&mut self, val: String) {
288 use std::str::FromStr;
289 self.#test_setter_name(#inner_type::from_str(&val).unwrap());
290 }
291 },
292 render_arm,
293 scalar_extras: Some(ScalarExtras {
294 value_setter_arm: quote! {
295 stringify!(#field_name) => self.#test_setter_from_str_name(val),
296 },
297 lookup_arm: quote! {
298 stringify!(#field_name) => self
299 .#field_name
300 .map(ProtocolConfigValue::#variant_ident),
301 },
302 field_name_str: quote! { stringify!(#field_name) },
303 variant_ident: variant_ident.clone(),
304 inner_type: inner_type.clone(),
305 }),
306 }
307 }
308 None => ExpandedField {
309 accessor: as_option_emit,
310 setter: common_setters,
311 render_arm,
312 scalar_extras: None,
313 },
314 }
315}
316
317struct AccessorField {
320 field_name: syn::Ident,
323 field_type: syn::Type,
325 inner_type: syn::Type,
327 scalar_variant: Option<syn::Ident>,
331}
332
333fn parse_accessor_field(field: &syn::Field) -> Option<AccessorField> {
334 let field_name = field.ident.clone().expect("Field must be named");
335
336 let skip_accessor = field
337 .attrs
338 .iter()
339 .any(|attr| attr.path.is_ident("skip_accessor"));
340 if skip_accessor {
341 return None;
342 }
343
344 let field_type = &field.ty;
345 let type_path = match field_type {
346 Type::Path(p) => p,
347 _ => return None,
348 };
349 let last_segment = type_path.path.segments.last()?;
350 if last_segment.ident != "Option" {
351 return None;
352 }
353 let inner_type = match &last_segment.arguments {
354 syn::PathArguments::AngleBracketed(args) => match args.args.first()? {
355 syn::GenericArgument::Type(ty) => ty.clone(),
356 _ => panic!("Expected a type argument inside Option<...> for `{field_name}`"),
357 },
358 _ => panic!("Expected angle bracketed arguments inside Option<...> for `{field_name}`"),
359 };
360
361 let scalar_variant = inferred_scalar_variant_ident(&inner_type);
362
363 Some(AccessorField {
364 field_name,
365 field_type: field_type.clone(),
366 inner_type,
367 scalar_variant,
368 })
369}
370
371fn inferred_scalar_variant_ident(ty: &syn::Type) -> Option<syn::Ident> {
372 const SCALAR_PRIMITIVES: &[&str] = &["bool", "u8", "u16", "u32", "u64", "u128", "usize"];
373
374 let Type::Path(path) = ty else { return None };
375 if path.qself.is_some() {
376 return None;
377 }
378 if path.path.segments.len() != 1 {
379 return None;
380 }
381 let segment = path.path.segments.first()?;
382 if !matches!(segment.arguments, syn::PathArguments::None) {
383 return None;
384 }
385 if !SCALAR_PRIMITIVES.iter().any(|p| segment.ident == *p) {
386 return None;
387 }
388 Some(segment.ident.clone())
389}
390
391#[proc_macro_derive(ProtocolConfigOverride)]
392pub fn protocol_config_override_macro(input: TokenStream) -> TokenStream {
393 let ast = parse_macro_input!(input as DeriveInput);
394
395 let struct_name = &ast.ident;
397 let optional_struct_name =
398 syn::Ident::new(&format!("{}Optional", struct_name), struct_name.span());
399
400 let fields = match &ast.data {
402 Data::Struct(data_struct) => match &data_struct.fields {
403 Fields::Named(fields_named) => &fields_named.named,
404 _ => panic!("ProtocolConfig must have named fields"),
405 },
406 _ => panic!("ProtocolConfig must be a struct"),
407 };
408
409 let optional_fields = fields.iter().map(|field| {
411 let field_name = &field.ident;
412 let field_type = &field.ty;
413 quote! {
414 #field_name: Option<#field_type>
415 }
416 });
417
418 let update_fields = fields.iter().map(|field| {
420 let field_name = &field.ident;
421 quote! {
422 if let Some(value) = self.#field_name {
423 tracing::warn!(
424 "ProtocolConfig field \"{}\" has been overridden with the value: {value:?}",
425 stringify!(#field_name),
426 );
427 config.#field_name = value;
428 }
429 }
430 });
431
432 let output = quote! {
434 #[derive(serde::Deserialize, Debug)]
435 pub struct #optional_struct_name {
436 #(#optional_fields,)*
437 }
438
439 impl #optional_struct_name {
440 pub fn apply_to(self, config: &mut #struct_name) {
441 #(#update_fields)*
442 }
443 }
444 };
445
446 TokenStream::from(output)
447}
448
449#[proc_macro_derive(ProtocolConfigFeatureFlagsGetters, attributes(skip_accessor))]
450pub fn feature_flag_getters_macro(input: TokenStream) -> TokenStream {
451 let ast = parse_macro_input!(input as DeriveInput);
452
453 let struct_name = &ast.ident;
454 let data = &ast.data;
455
456 let getters = match data {
457 Data::Struct(data_struct) => match &data_struct.fields {
458 Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
460 let field_name = field.ident.as_ref().expect("Field must be named");
462 let field_type = &field.ty;
463 let skip_accessor = field
464 .attrs
465 .iter()
466 .any(|attr| attr.path.is_ident("skip_accessor"));
467 if skip_accessor {
468 return None;
469 }
470 match field_type {
472 Type::Path(type_path)
473 if type_path
474 .path
475 .segments
476 .last()
477 .is_some_and(|segment| segment.ident == "bool") =>
478 {
479 Some((
480 quote! {
481 pub fn #field_name(&self) -> #field_type {
483 self.#field_name
484 }
485 },
486 (
487 quote! {
488 stringify!(#field_name) => Some(self.#field_name),
489 },
490 quote! {
491 stringify!(#field_name)
492 },
493 ),
494 ))
495 }
496 _ => None,
497 }
498 }),
499 _ => panic!("Only named fields are supported."),
500 },
501 _ => panic!("Only structs supported."),
502 };
503
504 let (by_fn_getters, (string_name_getters, field_names)): (Vec<_>, (Vec<_>, Vec<_>)) =
505 getters.unzip();
506
507 let output = quote! {
508 impl #struct_name {
510 #(#by_fn_getters)*
511
512 pub fn lookup_attr(&self, value: String) -> Option<bool> {
514 match value.as_str() {
515 #(#string_name_getters)*
516 _ => None,
517 }
518 }
519
520 pub fn attr_map(&self) -> std::collections::BTreeMap<String, bool> {
522 vec![
523 #(((#field_names).to_owned(), self.lookup_attr((#field_names).to_owned()).unwrap()),)*
525 ].into_iter().collect()
526 }
527 }
528 };
529
530 TokenStream::from(output)
531}