sui_proc_macros/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use proc_macro::TokenStream;
5use quote::{ToTokens, quote, quote_spanned};
6use syn::{
7    Attribute, BinOp, Data, DataEnum, DeriveInput, Expr, ExprBinary, ExprMacro, Item, ItemMacro,
8    Stmt, StmtMacro, Token, UnOp,
9    fold::{Fold, fold_expr, fold_item_macro, fold_stmt},
10    parse::Parser,
11    parse_macro_input, parse2,
12    punctuated::Punctuated,
13    spanned::Spanned,
14};
15
16#[proc_macro_attribute]
17pub fn init_static_initializers(_args: TokenStream, item: TokenStream) -> TokenStream {
18    let mut input = parse_macro_input!(item as syn::ItemFn);
19
20    let body = &input.block;
21    input.block = syn::parse2(quote! {
22        {
23            // We have some lazily-initialized static state in the program. The initializers
24            // alter the thread-local hash container state any time they create a new hash
25            // container. Therefore, we need to ensure that these initializers are run in a
26            // separate thread before the first test thread is launched. Otherwise, they would
27            // run inside of the first test thread, but not subsequent ones.
28            //
29            // Note that none of this has any effect on process-level determinism. Without this
30            // code, we can still get the same test results from two processes started with the
31            // same seed.
32            //
33            // However, when using sim_test(check_determinism) or MSIM_TEST_CHECK_DETERMINISM=1,
34            // we want the same test invocation to be deterministic when run twice
35            // _in the same process_, so we need to take care of this. This will also
36            // be very important for being able to reproduce a failure that occurs in the Nth
37            // iteration of a multi-iteration test run.
38            std::thread::spawn(|| {
39                use sui_protocol_config::ProtocolConfig;
40                ::sui_simulator::telemetry_subscribers::init_for_testing();
41                ::sui_simulator::sui_types::execution_params::get_denied_certificates_for_sim_test();
42                ::sui_simulator::sui_framework::BuiltInFramework::all_package_ids();
43                ::sui_simulator::sui_types::gas::SuiGasStatus::new_unmetered();
44
45                // For reasons I can't understand, LruCache causes divergent behavior the second
46                // time one is constructed and inserted into, so construct one before the first
47                // test run for determinism.
48                let mut cache = ::sui_simulator::lru::LruCache::new(1.try_into().unwrap());
49                cache.put(1, 1);
50
51                use std::sync::Arc;
52
53                use ::sui_simulator::anemo_tower::callback::CallbackLayer;
54                use ::sui_simulator::anemo_tower::trace::DefaultMakeSpan;
55                use ::sui_simulator::anemo_tower::trace::DefaultOnFailure;
56                use ::sui_simulator::anemo_tower::trace::TraceLayer;
57                use ::sui_simulator::fastcrypto::traits::KeyPair;
58                use ::sui_simulator::mysten_network::metrics::MetricsMakeCallbackHandler;
59                use ::sui_simulator::mysten_network::metrics::NetworkMetrics;
60                use ::sui_simulator::rand_crate::rngs::{StdRng, OsRng};
61                use ::sui_simulator::rand::SeedableRng;
62                use ::sui_simulator::tower::ServiceBuilder;
63
64                // anemo uses x509-parser, which has many lazy static variables. start a network to
65                // initialize all that static state before the first test.
66                let rt = ::sui_simulator::runtime::Runtime::new();
67                rt.block_on(async move {
68                    use ::sui_simulator::anemo::{Network, Request};
69
70                    {
71                        // Initialize the static initializers here:
72                        // https://github.com/move-language/move/blob/652badf6fd67e1d4cc2aa6dc69d63ad14083b673/language/tools/move-package/src/package_lock.rs#L12
73                        use std::path::PathBuf;
74                        use sui_simulator::sui_move_build::BuildConfig;
75                        use sui_simulator::tempfile::TempDir;
76
77                        let mut path = PathBuf::from(env!("SIMTEST_STATIC_INIT_MOVE"));
78                        let mut build_config = BuildConfig::new_for_testing();
79                        build_config.config.install_dir = Some(TempDir::new().unwrap().keep());
80                        let _all_module_bytes = build_config
81                            .build_async(&path)
82                            .await
83                            .unwrap()
84                            .get_package_bytes(/* with_unpublished_deps */ false);
85                    }
86
87                    let make_network = |port: u16| {
88                        let registry = prometheus::Registry::new();
89                        let inbound_network_metrics =
90                            NetworkMetrics::new("sui", "inbound", &registry);
91                        let outbound_network_metrics =
92                            NetworkMetrics::new("sui", "outbound", &registry);
93
94                        let service = ServiceBuilder::new()
95                            .layer(
96                                TraceLayer::new_for_server_errors()
97                                    .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
98                                    .on_failure(DefaultOnFailure::new().level(tracing::Level::WARN)),
99                            )
100                            .layer(CallbackLayer::new(MetricsMakeCallbackHandler::new(
101                                Arc::new(inbound_network_metrics),
102                                usize::MAX,
103                            )))
104                            .service(::sui_simulator::anemo::Router::new());
105
106                        let outbound_layer = ServiceBuilder::new()
107                            .layer(
108                                TraceLayer::new_for_client_and_server_errors()
109                                    .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
110                                    .on_failure(DefaultOnFailure::new().level(tracing::Level::WARN)),
111                            )
112                            .layer(CallbackLayer::new(MetricsMakeCallbackHandler::new(
113                                Arc::new(outbound_network_metrics),
114                                usize::MAX,
115                            )))
116                            .into_inner();
117
118
119                        Network::bind(format!("127.0.0.1:{}", port))
120                            .server_name("static-init-network")
121                            .private_key(
122                                ::sui_simulator::fastcrypto::ed25519::Ed25519KeyPair::generate(&mut StdRng::from_rng(OsRng).unwrap())
123                                    .private()
124                                    .0
125                                    .to_bytes(),
126                            )
127                            .start(service)
128                            .unwrap()
129                    };
130                    let n1 = make_network(80);
131                    let n2 = make_network(81);
132
133                    let _peer = n1.connect(n2.local_addr()).await.unwrap();
134                });
135            }).join().unwrap();
136
137            #body
138        }
139    })
140    .expect("Parsing failure");
141
142    let result = quote! {
143        #input
144    };
145
146    result.into()
147}
148
149/// The sui_test macro will invoke either `#[msim::test]` or `#[tokio::test]`,
150/// depending on whether the simulator config var is enabled.
151///
152/// This should be used for tests that can meaningfully run in either environment.
153#[proc_macro_attribute]
154pub fn sui_test(args: TokenStream, item: TokenStream) -> TokenStream {
155    let input = parse_macro_input!(item as syn::ItemFn);
156    let arg_parser = Punctuated::<syn::Meta, Token![,]>::parse_terminated;
157    let args = arg_parser.parse(args).unwrap().into_iter();
158
159    let header = if cfg!(msim) {
160        quote! {
161            #[::sui_simulator::sim_test(crate = "sui_simulator", #(#args)* )]
162        }
163    } else {
164        quote! {
165            #[::tokio::test(#(#args)*)]
166        }
167    };
168
169    let result = quote! {
170        #header
171        #[::sui_macros::init_static_initializers]
172        #input
173    };
174
175    result.into()
176}
177
178/// The sim_test macro will invoke `#[msim::test]` if the simulator config var is enabled.
179///
180/// Otherwise, it will emit an ignored test - if forcibly run, the ignored test will panic.
181///
182/// This macro must be used in order to pass any simulator-specific arguments, such as
183/// `check_determinism`, which is not understood by tokio.
184#[proc_macro_attribute]
185pub fn sim_test(args: TokenStream, item: TokenStream) -> TokenStream {
186    let input = parse_macro_input!(item as syn::ItemFn);
187    let arg_parser = Punctuated::<syn::Meta, Token![,]>::parse_terminated;
188    let args = arg_parser.parse(args).unwrap().into_iter();
189
190    let ignore = input
191        .attrs
192        .iter()
193        .find(|attr| attr.path().is_ident("ignore"))
194        .map_or(quote! {}, |_| quote! { #[ignore] });
195
196    let result = if cfg!(msim) {
197        let sig = &input.sig;
198        let return_type = &sig.output;
199        let body = &input.block;
200        quote! {
201            #[::sui_simulator::sim_test(crate = "sui_simulator", #(#args),*)]
202            #[::sui_macros::init_static_initializers]
203            #ignore
204            #sig {
205                async fn body_fn() #return_type { #body }
206
207                let ret = body_fn().await;
208
209                ::sui_simulator::task::shutdown_all_nodes();
210
211                // all node handles should have been dropped after the above block exits, but task
212                // shutdown is asynchronous, so we need a brief delay before checking for leaks.
213                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
214
215                assert_eq!(
216                    sui_simulator::NodeLeakDetector::get_current_node_count(),
217                    0,
218                    "SuiNode leak detected"
219                );
220
221                ret
222            }
223        }
224    } else {
225        let fn_name = &input.sig.ident;
226        let sig = &input.sig;
227        let body = &input.block;
228        quote! {
229            #[allow(clippy::needless_return)]
230            #[tokio::test]
231            #ignore
232            #sig {
233                if std::env::var("SUI_SKIP_SIMTESTS").is_ok() {
234                    println!("not running test {} in `cargo test`: SUI_SKIP_SIMTESTS is set", stringify!(#fn_name));
235
236                    struct Ret;
237
238                    impl From<Ret> for () {
239                        fn from(_ret: Ret) -> Self {
240                        }
241                    }
242
243                    impl<E> From<Ret> for Result<(), E> {
244                        fn from(_ret: Ret) -> Self {
245                            Ok(())
246                        }
247                    }
248
249                    return Ret.into();
250                }
251
252                #body
253            }
254        }
255    };
256
257    result.into()
258}
259
260#[proc_macro]
261pub fn checked_arithmetic(input: TokenStream) -> TokenStream {
262    let input_file = CheckArithmetic.fold_file(parse_macro_input!(input));
263
264    let output_items = input_file.items;
265
266    let output = quote! {
267        #(#output_items)*
268    };
269
270    TokenStream::from(output)
271}
272
273#[proc_macro_attribute]
274pub fn with_checked_arithmetic(_attr: TokenStream, item: TokenStream) -> TokenStream {
275    let input_item = parse_macro_input!(item as Item);
276    match input_item {
277        Item::Fn(input_fn) => {
278            let transformed_fn = CheckArithmetic.fold_item_fn(input_fn);
279            TokenStream::from(quote! { #transformed_fn })
280        }
281        Item::Impl(input_impl) => {
282            let transformed_impl = CheckArithmetic.fold_item_impl(input_impl);
283            TokenStream::from(quote! { #transformed_impl })
284        }
285        item => {
286            let transformed_impl = CheckArithmetic.fold_item(item);
287            TokenStream::from(quote! { #transformed_impl })
288        }
289    }
290}
291
292struct CheckArithmetic;
293
294impl CheckArithmetic {
295    fn maybe_skip_macro(&self, attrs: &mut Vec<Attribute>) -> bool {
296        if let Some(idx) = attrs
297            .iter()
298            .position(|attr| attr.path().is_ident("skip_checked_arithmetic"))
299        {
300            // Skip processing macro because it is annotated with
301            // #[skip_checked_arithmetic]
302            attrs.remove(idx);
303            true
304        } else {
305            false
306        }
307    }
308
309    fn process_macro_contents(
310        &mut self,
311        tokens: proc_macro2::TokenStream,
312    ) -> syn::Result<proc_macro2::TokenStream> {
313        // Parse the macro's contents as a comma-separated list of expressions.
314        let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
315        let Ok(exprs) = parser.parse(tokens.clone().into()) else {
316            return Err(syn::Error::new_spanned(
317                tokens,
318                "could not process macro contents - use #[skip_checked_arithmetic] to skip this macro",
319            ));
320        };
321
322        // Fold each sub expression.
323        let folded_exprs = exprs
324            .into_iter()
325            .map(|expr| self.fold_expr(expr))
326            .collect::<Vec<_>>();
327
328        // Convert the folded expressions back into tokens and reconstruct the macro.
329        let mut folded_tokens = proc_macro2::TokenStream::new();
330        for (i, folded_expr) in folded_exprs.into_iter().enumerate() {
331            if i > 0 {
332                folded_tokens.extend(std::iter::once::<proc_macro2::TokenTree>(
333                    proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone).into(),
334                ));
335            }
336            folded_expr.to_tokens(&mut folded_tokens);
337        }
338
339        Ok(folded_tokens)
340    }
341}
342
343impl Fold for CheckArithmetic {
344    fn fold_stmt(&mut self, stmt: Stmt) -> Stmt {
345        let stmt = fold_stmt(self, stmt);
346        if let Stmt::Macro(stmt_macro) = stmt {
347            let StmtMacro {
348                mut attrs,
349                mut mac,
350                semi_token,
351            } = stmt_macro;
352
353            if self.maybe_skip_macro(&mut attrs) {
354                Stmt::Macro(StmtMacro {
355                    attrs,
356                    mac,
357                    semi_token,
358                })
359            } else {
360                match self.process_macro_contents(mac.tokens.clone()) {
361                    Ok(folded_tokens) => {
362                        mac.tokens = folded_tokens;
363                        Stmt::Macro(StmtMacro {
364                            attrs,
365                            mac,
366                            semi_token,
367                        })
368                    }
369                    Err(error) => parse2(error.to_compile_error()).unwrap(),
370                }
371            }
372        } else {
373            stmt
374        }
375    }
376
377    fn fold_item_macro(&mut self, mut item_macro: ItemMacro) -> ItemMacro {
378        if !self.maybe_skip_macro(&mut item_macro.attrs) {
379            let err = syn::Error::new_spanned(
380                item_macro.to_token_stream(),
381                "cannot process macros - use #[skip_checked_arithmetic] to skip \
382                    processing this macro",
383            );
384
385            return parse2(err.to_compile_error()).unwrap();
386        }
387        fold_item_macro(self, item_macro)
388    }
389
390    fn fold_expr(&mut self, expr: Expr) -> Expr {
391        let span = expr.span();
392        let expr = fold_expr(self, expr);
393        let expr = match expr {
394            Expr::Macro(expr_macro) => {
395                let ExprMacro { mut attrs, mut mac } = expr_macro;
396
397                if self.maybe_skip_macro(&mut attrs) {
398                    return Expr::Macro(ExprMacro { attrs, mac });
399                } else {
400                    match self.process_macro_contents(mac.tokens.clone()) {
401                        Ok(folded_tokens) => {
402                            mac.tokens = folded_tokens;
403                            let expr_macro = Expr::Macro(ExprMacro { attrs, mac });
404                            quote!(#expr_macro)
405                        }
406                        Err(error) => {
407                            return Expr::Verbatim(error.to_compile_error());
408                        }
409                    }
410                }
411            }
412
413            Expr::Binary(expr_binary) => {
414                let ExprBinary {
415                    attrs,
416                    mut left,
417                    op,
418                    mut right,
419                } = expr_binary;
420
421                fn remove_parens(expr: &mut Expr) {
422                    if let Expr::Paren(paren) = expr {
423                        // i don't even think rust allows this, but just in case
424                        assert!(paren.attrs.is_empty(), "TODO: attrs on parenthesized");
425                        *expr = *paren.expr.clone();
426                    }
427                }
428
429                macro_rules! wrap_op {
430                    ($left: expr, $right: expr, $method: ident, $span: expr) => {{
431                        // Remove parens from exprs since both sides get assigned to tmp variables.
432                        // otherwise we get lint errors
433                        remove_parens(&mut $left);
434                        remove_parens(&mut $right);
435
436                        quote_spanned!($span => {
437                            // assign in one stmt in case either #left or #right contains
438                            // references to `left` or `right` symbols.
439                            let (left, right) = (#left, #right);
440                            left.$method(right)
441                                .unwrap_or_else(||
442                                    panic!(
443                                        "Overflow or underflow in {} {} + {}",
444                                        stringify!($method),
445                                        left,
446                                        right,
447                                    )
448                                )
449                        })
450                    }};
451                }
452
453                macro_rules! wrap_op_assign {
454                    ($left: expr, $right: expr, $method: ident, $span: expr) => {{
455                        // Remove parens from exprs since both sides get assigned to tmp variables.
456                        // otherwise we get lint errors
457                        remove_parens(&mut $left);
458                        remove_parens(&mut $right);
459
460                        quote_spanned!($span => {
461                            // assign in one stmt in case either #left or #right contains
462                            // references to `left` or `right` symbols.
463                            let (left, right) = (&mut #left, #right);
464                            *left = (*left).$method(right)
465                                .unwrap_or_else(||
466                                    panic!(
467                                        "Overflow or underflow in {} {} + {}",
468                                        stringify!($method),
469                                        *left,
470                                        right
471                                    )
472                                )
473                        })
474                    }};
475                }
476
477                match op {
478                    BinOp::Add(_) => {
479                        wrap_op!(left, right, checked_add, span)
480                    }
481                    BinOp::Sub(_) => {
482                        wrap_op!(left, right, checked_sub, span)
483                    }
484                    BinOp::Mul(_) => {
485                        wrap_op!(left, right, checked_mul, span)
486                    }
487                    BinOp::Div(_) => {
488                        wrap_op!(left, right, checked_div, span)
489                    }
490                    BinOp::Rem(_) => {
491                        wrap_op!(left, right, checked_rem, span)
492                    }
493                    BinOp::AddAssign(_) => {
494                        wrap_op_assign!(left, right, checked_add, span)
495                    }
496                    BinOp::SubAssign(_) => {
497                        wrap_op_assign!(left, right, checked_sub, span)
498                    }
499                    BinOp::MulAssign(_) => {
500                        wrap_op_assign!(left, right, checked_mul, span)
501                    }
502                    BinOp::DivAssign(_) => {
503                        wrap_op_assign!(left, right, checked_div, span)
504                    }
505                    BinOp::RemAssign(_) => {
506                        wrap_op_assign!(left, right, checked_rem, span)
507                    }
508                    _ => {
509                        let expr_binary = ExprBinary {
510                            attrs,
511                            left,
512                            op,
513                            right,
514                        };
515                        quote_spanned!(span => #expr_binary)
516                    }
517                }
518            }
519            Expr::Unary(expr_unary) => {
520                let op = &expr_unary.op;
521                let operand = &expr_unary.expr;
522                match op {
523                    UnOp::Neg(_) => {
524                        quote_spanned!(span => #operand.checked_neg().expect("Overflow or underflow in negation"))
525                    }
526                    _ => quote_spanned!(span => #expr_unary),
527                }
528            }
529            _ => quote_spanned!(span => #expr),
530        };
531
532        parse2(expr).unwrap()
533    }
534}
535
536/// This proc macro generates a function `order_to_variant_map` which returns a map
537/// of the position of each variant to the name of the variant.
538/// It is intended to catch changes in enum order when backward compat is required.
539/// ```rust,ignore
540///    /// Example for this enum
541///    #[derive(EnumVariantOrder)]
542///    pub enum MyEnum {
543///         A,
544///         B(u64),
545///         C{x: bool, y: i8},
546///     }
547///     let order_map = MyEnum::order_to_variant_map();
548///     assert!(order_map.get(0).unwrap() == "A");
549///     assert!(order_map.get(1).unwrap() == "B");
550///     assert!(order_map.get(2).unwrap() == "C");
551/// ```
552#[proc_macro_derive(EnumVariantOrder)]
553pub fn enum_variant_order_derive(input: TokenStream) -> TokenStream {
554    let ast = parse_macro_input!(input as DeriveInput);
555    let name = &ast.ident;
556
557    if let Data::Enum(DataEnum { variants, .. }) = ast.data {
558        let variant_entries = variants
559            .iter()
560            .enumerate()
561            .map(|(index, variant)| {
562                let variant_name = variant.ident.to_string();
563                quote! {
564                    map.insert( #index as u64, (#variant_name).to_string());
565                }
566            })
567            .collect::<Vec<_>>();
568
569        let deriv = quote! {
570            impl sui_enum_compat_util::EnumOrderMap for #name {
571                fn order_to_variant_map() -> std::collections::BTreeMap<u64, String > {
572                    let mut map = std::collections::BTreeMap::new();
573                    #(#variant_entries)*
574                    map
575                }
576            }
577        };
578
579        deriv.into()
580    } else {
581        panic!("EnumVariantOrder can only be used with enums.");
582    }
583}