// sui_protocol_config_macros/lib.rs
extern crate proc_macro;

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{Data, DeriveInput, Fields, Type, parse_macro_input};

10#[proc_macro_derive(ProtocolConfigAccessors)]
43pub fn accessors_macro(input: TokenStream) -> TokenStream {
44 let ast = parse_macro_input!(input as DeriveInput);
45
46 let struct_name = &ast.ident;
47 let data = &ast.data;
48 let mut inner_types = vec![];
49
50 let tokens = match data {
51 Data::Struct(data_struct) => match &data_struct.fields {
52 Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
54 let field_name = field.ident.as_ref().expect("Field must be named");
56 let field_type = &field.ty;
57 match field_type {
59 Type::Path(type_path)
60 if type_path
61 .path
62 .segments
63 .last().is_some_and(|segment| segment.ident == "Option") =>
64 {
65 let inner_type = if let syn::PathArguments::AngleBracketed(
67 angle_bracketed_generic_arguments,
68 ) = &type_path.path.segments.last().unwrap().arguments
69 {
70 if let Some(syn::GenericArgument::Type(ty)) =
71 angle_bracketed_generic_arguments.args.first()
72 {
73 ty.clone()
74 } else {
75 panic!("Expected a type argument.");
76 }
77 } else {
78 panic!("Expected angle bracketed arguments.");
79 };
80
81 let as_option_name = format!("{field_name}_as_option");
82 let as_option_name: proc_macro2::TokenStream =
83 as_option_name.parse().unwrap();
84 let test_setter_name: proc_macro2::TokenStream =
85 format!("set_{field_name}_for_testing").parse().unwrap();
86 let test_un_setter_name: proc_macro2::TokenStream =
87 format!("disable_{field_name}_for_testing").parse().unwrap();
88 let test_setter_from_str_name: proc_macro2::TokenStream =
89 format!("set_{field_name}_from_str_for_testing").parse().unwrap();
90
91 let getter = quote! {
92 pub fn #field_name(&self) -> #inner_type {
94 self.#field_name.expect(Self::CONSTANT_ERR_MSG)
95 }
96
97 pub fn #as_option_name(&self) -> #field_type {
98 self.#field_name
99 }
100 };
101
102 let test_setter = quote! {
103 pub fn #test_setter_name(&mut self, val: #inner_type) {
105 self.#field_name = Some(val);
106 }
107
108 pub fn #test_setter_from_str_name(&mut self, val: String) {
110 use std::str::FromStr;
111 self.#test_setter_name(#inner_type::from_str(&val).unwrap());
112 }
113
114 pub fn #test_un_setter_name(&mut self) {
116 self.#field_name = None;
117 }
118 };
119
120 let value_setter = quote! {
121 stringify!(#field_name) => self.#test_setter_from_str_name(val),
122 };
123
124
125 let value_lookup = quote! {
126 stringify!(#field_name) => self.#field_name.map(|v| ProtocolConfigValue::#inner_type(v)),
127 };
128
129 let field_name_str = quote! {
130 stringify!(#field_name)
131 };
132
133 if inner_types.contains(&inner_type) {
135 None
136 } else {
137 inner_types.push(inner_type.clone());
138 Some(quote! {
139 #inner_type
140 })
141 };
142
143 Some(((getter, (test_setter, value_setter)), (value_lookup, field_name_str)))
144 }
145 _ => None,
146 }
147 }),
148 _ => panic!("Only named fields are supported."),
149 },
150 _ => panic!("Only structs supported."),
151 };
152
153 #[allow(clippy::type_complexity)]
154 let ((getters, (test_setters, value_setters)), (value_lookup, field_names_str)): (
155 (Vec<_>, (Vec<_>, Vec<_>)),
156 (Vec<_>, Vec<_>),
157 ) = tokens.unzip();
158 let output = quote! {
159 impl #struct_name {
161 const CONSTANT_ERR_MSG: &'static str = "protocol constant not present in current protocol version";
162 #(#getters)*
163
164 pub fn lookup_attr(&self, value: String) -> Option<ProtocolConfigValue> {
166 match value.as_str() {
167 #(#value_lookup)*
168 _ => None,
169 }
170 }
171
172 pub fn attr_map(&self) -> std::collections::BTreeMap<String, Option<ProtocolConfigValue>> {
174 vec![
175 #(((#field_names_str).to_owned(), self.lookup_attr((#field_names_str).to_owned())),)*
176 ].into_iter().collect()
177 }
178
179 pub fn lookup_feature(&self, value: String) -> Option<bool> {
181 self.feature_flags.lookup_attr(value)
182 }
183
184 pub fn feature_map(&self) -> std::collections::BTreeMap<String, bool> {
185 self.feature_flags.attr_map()
186 }
187 }
188
189 impl #struct_name {
191 #(#test_setters)*
192
193 pub fn set_attr_for_testing(&mut self, attr: String, val: String) {
194 match attr.as_str() {
195 #(#value_setters)*
196 _ => panic!("Attempting to set unknown attribute: {}", attr),
197 }
198 }
199 }
200
201 #[allow(non_camel_case_types)]
202 #[derive(Clone, Serialize, Debug, PartialEq, Deserialize, schemars::JsonSchema)]
203 pub enum ProtocolConfigValue {
204 #(#inner_types(#inner_types),)*
205 }
206
207 impl std::fmt::Display for ProtocolConfigValue {
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 use std::fmt::Write;
210 let mut writer = String::new();
211 match self {
212 #(
213 ProtocolConfigValue::#inner_types(x) => {
214 write!(writer, "{}", x)?;
215 }
216 )*
217 }
218 write!(f, "{}", writer)
219 }
220 }
221 };
222
223 TokenStream::from(output)
224}
225
226#[proc_macro_derive(ProtocolConfigOverride)]
227pub fn protocol_config_override_macro(input: TokenStream) -> TokenStream {
228 let ast = parse_macro_input!(input as DeriveInput);
229
230 let struct_name = &ast.ident;
232 let optional_struct_name =
233 syn::Ident::new(&format!("{}Optional", struct_name), struct_name.span());
234
235 let fields = match &ast.data {
237 Data::Struct(data_struct) => match &data_struct.fields {
238 Fields::Named(fields_named) => &fields_named.named,
239 _ => panic!("ProtocolConfig must have named fields"),
240 },
241 _ => panic!("ProtocolConfig must be a struct"),
242 };
243
244 let optional_fields = fields.iter().map(|field| {
246 let field_name = &field.ident;
247 let field_type = &field.ty;
248 quote! {
249 #field_name: Option<#field_type>
250 }
251 });
252
253 let update_fields = fields.iter().map(|field| {
255 let field_name = &field.ident;
256 quote! {
257 if let Some(value) = self.#field_name {
258 tracing::warn!(
259 "ProtocolConfig field \"{}\" has been overridden with the value: {value:?}",
260 stringify!(#field_name),
261 );
262 config.#field_name = value;
263 }
264 }
265 });
266
267 let output = quote! {
269 #[derive(serde::Deserialize, Debug)]
270 pub struct #optional_struct_name {
271 #(#optional_fields,)*
272 }
273
274 impl #optional_struct_name {
275 pub fn apply_to(self, config: &mut #struct_name) {
276 #(#update_fields)*
277 }
278 }
279 };
280
281 TokenStream::from(output)
282}
283
284#[proc_macro_derive(ProtocolConfigFeatureFlagsGetters)]
285pub fn feature_flag_getters_macro(input: TokenStream) -> TokenStream {
286 let ast = parse_macro_input!(input as DeriveInput);
287
288 let struct_name = &ast.ident;
289 let data = &ast.data;
290
291 let getters = match data {
292 Data::Struct(data_struct) => match &data_struct.fields {
293 Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
295 let field_name = field.ident.as_ref().expect("Field must be named");
297 let field_type = &field.ty;
298 match field_type {
300 Type::Path(type_path)
301 if type_path
302 .path
303 .segments
304 .last()
305 .is_some_and(|segment| segment.ident == "bool") =>
306 {
307 Some((
308 quote! {
309 pub fn #field_name(&self) -> #field_type {
311 self.#field_name
312 }
313 },
314 (
315 quote! {
316 stringify!(#field_name) => Some(self.#field_name),
317 },
318 quote! {
319 stringify!(#field_name)
320 },
321 ),
322 ))
323 }
324 _ => None,
325 }
326 }),
327 _ => panic!("Only named fields are supported."),
328 },
329 _ => panic!("Only structs supported."),
330 };
331
332 let (by_fn_getters, (string_name_getters, field_names)): (Vec<_>, (Vec<_>, Vec<_>)) =
333 getters.unzip();
334
335 let output = quote! {
336 impl #struct_name {
338 #(#by_fn_getters)*
339
340 pub fn lookup_attr(&self, value: String) -> Option<bool> {
342 match value.as_str() {
343 #(#string_name_getters)*
344 _ => None,
345 }
346 }
347
348 pub fn attr_map(&self) -> std::collections::BTreeMap<String, bool> {
350 vec![
351 #(((#field_names).to_owned(), self.lookup_attr((#field_names).to_owned()).unwrap()),)*
353 ].into_iter().collect()
354 }
355 }
356 };
357
358 TokenStream::from(output)
359}