sui_protocol_config_macros/
lib.rs1extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{Data, DeriveInput, Fields, Type, parse_macro_input};
9
10#[proc_macro_derive(ProtocolConfigAccessors)]
43pub fn accessors_macro(input: TokenStream) -> TokenStream {
44 let ast = parse_macro_input!(input as DeriveInput);
45
46 let struct_name = &ast.ident;
47 let data = &ast.data;
48 let mut inner_types = vec![];
49
50 let tokens = match data {
51 Data::Struct(data_struct) => match &data_struct.fields {
52 Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
54 let field_name = field.ident.as_ref().expect("Field must be named");
56 let field_type = &field.ty;
57 match field_type {
59 Type::Path(type_path)
60 if type_path
61 .path
62 .segments
63 .last().is_some_and(|segment| segment.ident == "Option") =>
64 {
65 let inner_type = if let syn::PathArguments::AngleBracketed(
67 angle_bracketed_generic_arguments,
68 ) = &type_path.path.segments.last().unwrap().arguments
69 {
70 if let Some(syn::GenericArgument::Type(ty)) =
71 angle_bracketed_generic_arguments.args.first()
72 {
73 ty.clone()
74 } else {
75 panic!("Expected a type argument.");
76 }
77 } else {
78 panic!("Expected angle bracketed arguments.");
79 };
80
81 let is_vec = matches!(&inner_type, Type::Path(tp)
83 if tp.path.segments.last()
84 .is_some_and(|s| s.ident == "Vec"));
85 if is_vec {
86 return None;
87 }
88
89 let as_option_name = format!("{field_name}_as_option");
90 let as_option_name: proc_macro2::TokenStream =
91 as_option_name.parse().unwrap();
92 let test_setter_name: proc_macro2::TokenStream =
93 format!("set_{field_name}_for_testing").parse().unwrap();
94 let test_un_setter_name: proc_macro2::TokenStream =
95 format!("disable_{field_name}_for_testing").parse().unwrap();
96 let test_setter_from_str_name: proc_macro2::TokenStream =
97 format!("set_{field_name}_from_str_for_testing").parse().unwrap();
98
99 let getter = quote! {
100 pub fn #field_name(&self) -> #inner_type {
102 self.#field_name.expect(Self::CONSTANT_ERR_MSG)
103 }
104
105 pub fn #as_option_name(&self) -> #field_type {
106 self.#field_name
107 }
108 };
109
110 let test_setter = quote! {
111 pub fn #test_setter_name(&mut self, val: #inner_type) {
113 self.#field_name = Some(val);
114 }
115
116 pub fn #test_setter_from_str_name(&mut self, val: String) {
118 use std::str::FromStr;
119 self.#test_setter_name(#inner_type::from_str(&val).unwrap());
120 }
121
122 pub fn #test_un_setter_name(&mut self) {
124 self.#field_name = None;
125 }
126 };
127
128 let value_setter = quote! {
129 stringify!(#field_name) => self.#test_setter_from_str_name(val),
130 };
131
132
133 let value_lookup = quote! {
134 stringify!(#field_name) => self.#field_name.map(|v| ProtocolConfigValue::#inner_type(v)),
135 };
136
137 let field_name_str = quote! {
138 stringify!(#field_name)
139 };
140
141 if inner_types.contains(&inner_type) {
143 None
144 } else {
145 inner_types.push(inner_type.clone());
146 Some(quote! {
147 #inner_type
148 })
149 };
150
151 Some(((getter, (test_setter, value_setter)), (value_lookup, field_name_str)))
152 }
153 _ => None,
154 }
155 }),
156 _ => panic!("Only named fields are supported."),
157 },
158 _ => panic!("Only structs supported."),
159 };
160
161 #[allow(clippy::type_complexity)]
162 let ((getters, (test_setters, value_setters)), (value_lookup, field_names_str)): (
163 (Vec<_>, (Vec<_>, Vec<_>)),
164 (Vec<_>, Vec<_>),
165 ) = tokens.unzip();
166 let output = quote! {
167 impl #struct_name {
169 const CONSTANT_ERR_MSG: &'static str = "protocol constant not present in current protocol version";
170 #(#getters)*
171
172 pub fn lookup_attr(&self, value: String) -> Option<ProtocolConfigValue> {
174 match value.as_str() {
175 #(#value_lookup)*
176 _ => None,
177 }
178 }
179
180 pub fn attr_map(&self) -> std::collections::BTreeMap<String, Option<ProtocolConfigValue>> {
182 vec![
183 #(((#field_names_str).to_owned(), self.lookup_attr((#field_names_str).to_owned())),)*
184 ].into_iter().collect()
185 }
186
187 pub fn lookup_feature(&self, value: String) -> Option<bool> {
189 self.feature_flags.lookup_attr(value)
190 }
191
192 pub fn feature_map(&self) -> std::collections::BTreeMap<String, bool> {
193 self.feature_flags.attr_map()
194 }
195 }
196
197 impl #struct_name {
199 #(#test_setters)*
200
201 pub fn set_attr_for_testing(&mut self, attr: String, val: String) {
202 match attr.as_str() {
203 #(#value_setters)*
204 _ => panic!("Attempting to set unknown attribute: {}", attr),
205 }
206 }
207 }
208
209 #[allow(non_camel_case_types)]
210 #[derive(Clone, Serialize, Debug, PartialEq, Deserialize, schemars::JsonSchema)]
211 pub enum ProtocolConfigValue {
212 #(#inner_types(#inner_types),)*
213 }
214
215 impl std::fmt::Display for ProtocolConfigValue {
216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 use std::fmt::Write;
218 let mut writer = String::new();
219 match self {
220 #(
221 ProtocolConfigValue::#inner_types(x) => {
222 write!(writer, "{}", x)?;
223 }
224 )*
225 }
226 write!(f, "{}", writer)
227 }
228 }
229 };
230
231 TokenStream::from(output)
232}
233
234#[proc_macro_derive(ProtocolConfigOverride)]
235pub fn protocol_config_override_macro(input: TokenStream) -> TokenStream {
236 let ast = parse_macro_input!(input as DeriveInput);
237
238 let struct_name = &ast.ident;
240 let optional_struct_name =
241 syn::Ident::new(&format!("{}Optional", struct_name), struct_name.span());
242
243 let fields = match &ast.data {
245 Data::Struct(data_struct) => match &data_struct.fields {
246 Fields::Named(fields_named) => &fields_named.named,
247 _ => panic!("ProtocolConfig must have named fields"),
248 },
249 _ => panic!("ProtocolConfig must be a struct"),
250 };
251
252 let optional_fields = fields.iter().map(|field| {
254 let field_name = &field.ident;
255 let field_type = &field.ty;
256 quote! {
257 #field_name: Option<#field_type>
258 }
259 });
260
261 let update_fields = fields.iter().map(|field| {
263 let field_name = &field.ident;
264 quote! {
265 if let Some(value) = self.#field_name {
266 tracing::warn!(
267 "ProtocolConfig field \"{}\" has been overridden with the value: {value:?}",
268 stringify!(#field_name),
269 );
270 config.#field_name = value;
271 }
272 }
273 });
274
275 let output = quote! {
277 #[derive(serde::Deserialize, Debug)]
278 pub struct #optional_struct_name {
279 #(#optional_fields,)*
280 }
281
282 impl #optional_struct_name {
283 pub fn apply_to(self, config: &mut #struct_name) {
284 #(#update_fields)*
285 }
286 }
287 };
288
289 TokenStream::from(output)
290}
291
292#[proc_macro_derive(ProtocolConfigFeatureFlagsGetters)]
293pub fn feature_flag_getters_macro(input: TokenStream) -> TokenStream {
294 let ast = parse_macro_input!(input as DeriveInput);
295
296 let struct_name = &ast.ident;
297 let data = &ast.data;
298
299 let getters = match data {
300 Data::Struct(data_struct) => match &data_struct.fields {
301 Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
303 let field_name = field.ident.as_ref().expect("Field must be named");
305 let field_type = &field.ty;
306 match field_type {
308 Type::Path(type_path)
309 if type_path
310 .path
311 .segments
312 .last()
313 .is_some_and(|segment| segment.ident == "bool") =>
314 {
315 Some((
316 quote! {
317 pub fn #field_name(&self) -> #field_type {
319 self.#field_name
320 }
321 },
322 (
323 quote! {
324 stringify!(#field_name) => Some(self.#field_name),
325 },
326 quote! {
327 stringify!(#field_name)
328 },
329 ),
330 ))
331 }
332 _ => None,
333 }
334 }),
335 _ => panic!("Only named fields are supported."),
336 },
337 _ => panic!("Only structs supported."),
338 };
339
340 let (by_fn_getters, (string_name_getters, field_names)): (Vec<_>, (Vec<_>, Vec<_>)) =
341 getters.unzip();
342
343 let output = quote! {
344 impl #struct_name {
346 #(#by_fn_getters)*
347
348 pub fn lookup_attr(&self, value: String) -> Option<bool> {
350 match value.as_str() {
351 #(#string_name_getters)*
352 _ => None,
353 }
354 }
355
356 pub fn attr_map(&self) -> std::collections::BTreeMap<String, bool> {
358 vec![
359 #(((#field_names).to_owned(), self.lookup_attr((#field_names).to_owned()).unwrap()),)*
361 ].into_iter().collect()
362 }
363 }
364 };
365
366 TokenStream::from(output)
367}