sui_proc_macros/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use proc_macro::TokenStream;
5use quote::{ToTokens, quote, quote_spanned};
6use syn::{
7    Attribute, BinOp, Data, DataEnum, DeriveInput, Expr, ExprBinary, ExprMacro, Item, ItemMacro,
8    Stmt, StmtMacro, Token, UnOp,
9    fold::{Fold, fold_expr, fold_item_macro, fold_stmt},
10    parse::Parser,
11    parse_macro_input, parse2,
12    punctuated::Punctuated,
13    spanned::Spanned,
14};
15
/// Wraps a test function so that lazily-initialized, process-wide static state is
/// initialized on a separate, short-lived thread before the original body runs.
///
/// The generated prelude is spawned with `std::thread::spawn` and joined before `#body`
/// executes. It:
/// * calls `init_for_testing` for telemetry and touches several `sui_simulator`
///   re-exports (denied certificates for sim tests, built-in framework package ids,
///   unmetered gas status);
/// * constructs an `LruCache` and inserts into it once;
/// * inside a simulator `Runtime`, builds the Move package at the path given by the
///   compile-time `SIMTEST_STATIC_INIT_MOVE` environment variable, then starts two anemo
///   networks (ports 80 and 81) and connects the first to the second.
///
/// Every fallible step in the prelude is `unwrap`ped, and the final `.join().unwrap()`
/// propagates any such panic into the test itself. The attribute arguments are ignored.
///
/// NOTE(review): the generated code refers to `prometheus`, `tracing` and
/// `sui_protocol_config` without a `::sui_simulator::` prefix. It presumably relies on the
/// crate using this attribute depending on them directly — confirm.
#[proc_macro_attribute]
pub fn init_static_initializers(_args: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = parse_macro_input!(item as syn::ItemFn);

    let body = &input.block;
    // Replace the function body with `{ <static-init prelude>; <original body> }`.
    input.block = syn::parse2(quote! {
        {
            // We have some lazily-initialized static state in the program. The initializers
            // alter the thread-local hash container state any time they create a new hash
            // container. Therefore, we need to ensure that these initializers are run in a
            // separate thread before the first test thread is launched. Otherwise, they would
            // run inside of the first test thread, but not subsequent ones.
            //
            // Note that none of this has any effect on process-level determinism. Without this
            // code, we can still get the same test results from two processes started with the
            // same seed.
            //
            // However, when using sim_test(check_determinism) or MSIM_TEST_CHECK_DETERMINISM=1,
            // we want the same test invocation to be deterministic when run twice
            // _in the same process_, so we need to take care of this. This will also
            // be very important for being able to reproduce a failure that occurs in the Nth
            // iteration of a multi-iteration test run.
            std::thread::spawn(|| {
                // NOTE(review): `ProtocolConfig` is not referenced below; this import appears unused.
                use sui_protocol_config::ProtocolConfig;
                ::sui_simulator::telemetry_subscribers::init_for_testing();
                ::sui_simulator::sui_types::execution_params::get_denied_certificates_for_sim_test();
                ::sui_simulator::sui_framework::BuiltInFramework::all_package_ids();
                ::sui_simulator::sui_types::gas::SuiGasStatus::new_unmetered();

                // For reasons I can't understand, LruCache causes divergent behavior the second
                // time one is constructed and inserted into, so construct one before the first
                // test run for determinism.
                let mut cache = ::sui_simulator::lru::LruCache::new(1.try_into().unwrap());
                cache.put(1, 1);

                use std::sync::Arc;

                use ::sui_simulator::anemo_tower::callback::CallbackLayer;
                use ::sui_simulator::anemo_tower::trace::DefaultMakeSpan;
                use ::sui_simulator::anemo_tower::trace::DefaultOnFailure;
                use ::sui_simulator::anemo_tower::trace::TraceLayer;
                use ::sui_simulator::fastcrypto::traits::KeyPair;
                use ::sui_simulator::mysten_network::metrics::MetricsMakeCallbackHandler;
                use ::sui_simulator::mysten_network::metrics::NetworkMetrics;
                use ::sui_simulator::rand_crate::rngs::{StdRng, OsRng};
                use ::sui_simulator::rand::SeedableRng;
                use ::sui_simulator::tower::ServiceBuilder;

                // anemo uses x509-parser, which has many lazy static variables. start a network to
                // initialize all that static state before the first test.
                let rt = ::sui_simulator::runtime::Runtime::new();
                rt.block_on(async move {
                    use ::sui_simulator::anemo::{Network, Request};

                    {
                        // Initialize the static initializers here:
                        // https://github.com/move-language/move/blob/652badf6fd67e1d4cc2aa6dc69d63ad14083b673/language/tools/move-package/src/package_lock.rs#L12
                        use std::path::PathBuf;
                        use sui_simulator::sui_move_build::BuildConfig;
                        use sui_simulator::tempfile::TempDir;

                        // `env!` resolves when the test crate is compiled, not at run time.
                        // NOTE(review): `path` is never mutated; the `mut` looks unnecessary.
                        let mut path = PathBuf::from(env!("SIMTEST_STATIC_INIT_MOVE"));
                        let mut build_config = BuildConfig::new_for_testing();
                        // `keep()` detaches the temp dir from its guard so it is not deleted here.
                        build_config.config.install_dir = Some(TempDir::new().unwrap().keep());
                        let _all_module_bytes = build_config
                            .build_async(&path)
                            .await
                            .unwrap()
                            .get_package_bytes(/* with_unpublished_deps */ false);
                    }

                    // Builds and starts an anemo network on 127.0.0.1:`port`, with tracing and
                    // metrics layers on both the inbound service and the outbound layer, and a
                    // freshly generated Ed25519 key.
                    let make_network = |port: u16| {
                        let registry = prometheus::Registry::new();
                        let inbound_network_metrics =
                            NetworkMetrics::new("sui", "inbound", &registry);
                        let outbound_network_metrics =
                            NetworkMetrics::new("sui", "outbound", &registry);

                        let service = ServiceBuilder::new()
                            .layer(
                                TraceLayer::new_for_server_errors()
                                    .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
                                    .on_failure(DefaultOnFailure::new().level(tracing::Level::WARN)),
                            )
                            .layer(CallbackLayer::new(MetricsMakeCallbackHandler::new(
                                Arc::new(inbound_network_metrics),
                                usize::MAX,
                            )))
                            .service(::sui_simulator::anemo::Router::new());

                        let outbound_layer = ServiceBuilder::new()
                            .layer(
                                TraceLayer::new_for_client_and_server_errors()
                                    .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
                                    .on_failure(DefaultOnFailure::new().level(tracing::Level::WARN)),
                            )
                            .layer(CallbackLayer::new(MetricsMakeCallbackHandler::new(
                                Arc::new(outbound_network_metrics),
                                usize::MAX,
                            )))
                            .into_inner();


                        Network::bind(format!("127.0.0.1:{}", port))
                            .server_name("static-init-network")
                            .private_key(
                                ::sui_simulator::fastcrypto::ed25519::Ed25519KeyPair::generate(&mut StdRng::from_rng(OsRng).unwrap())
                                    .private()
                                    .0
                                    .to_bytes(),
                            )
                            .start(service)
                            .unwrap()
                    };
                    let n1 = make_network(80);
                    let n2 = make_network(81);

                    // Establishing a connection exercises the TLS/x509 code paths.
                    let _peer = n1.connect(n2.local_addr()).await.unwrap();
                });
            }).join().unwrap();

            #body
        }
    })
    // The template is a fixed `{ ... }` block around the already-parsed body, so a parse
    // failure here would indicate a bug in this macro.
    .expect("Parsing failure");

    let result = quote! {
        #input
    };

    result.into()
}
148
149/// The sui_test macro will invoke either `#[msim::test]` or `#[tokio::test]`,
150/// depending on whether the simulator config var is enabled.
151///
152/// This should be used for tests that can meaningfully run in either environment.
153#[proc_macro_attribute]
154pub fn sui_test(args: TokenStream, item: TokenStream) -> TokenStream {
155    let input = parse_macro_input!(item as syn::ItemFn);
156    let arg_parser = Punctuated::<syn::Meta, Token![,]>::parse_terminated;
157    let args = arg_parser.parse(args).unwrap().into_iter();
158
159    let header = if cfg!(msim) {
160        quote! {
161            #[::sui_simulator::sim_test(crate = "sui_simulator", #(#args)* )]
162        }
163    } else {
164        quote! {
165            #[::tokio::test(#(#args)*)]
166        }
167    };
168
169    let result = quote! {
170        #header
171        #[::sui_macros::init_static_initializers]
172        #input
173    };
174
175    result.into()
176}
177
178/// The sim_test macro will invoke `#[msim::test]` if the simulator config var is enabled.
179///
180/// Otherwise, it will emit an ignored test - if forcibly run, the ignored test will panic.
181///
182/// This macro must be used in order to pass any simulator-specific arguments, such as
183/// `check_determinism`, which is not understood by tokio.
184#[proc_macro_attribute]
185pub fn sim_test(args: TokenStream, item: TokenStream) -> TokenStream {
186    let input = parse_macro_input!(item as syn::ItemFn);
187    let arg_parser = Punctuated::<syn::Meta, Token![,]>::parse_terminated;
188    let args = arg_parser.parse(args).unwrap().into_iter();
189
190    let ignore = input
191        .attrs
192        .iter()
193        .find(|attr| attr.path().is_ident("ignore"))
194        .map_or(quote! {}, |_| quote! { #[ignore] });
195
196    let result = if cfg!(msim) {
197        let sig = &input.sig;
198        let return_type = &sig.output;
199        let body = &input.block;
200        quote! {
201            #[::sui_simulator::sim_test(crate = "sui_simulator", #(#args),*)]
202            #[::sui_macros::init_static_initializers]
203            #ignore
204            #sig {
205                async fn body_fn() #return_type { #body }
206
207                let timeout_secs: u64 = std::env::var("SUI_SIM_TEST_TIMEOUT_SECS")
208                    .ok()
209                    .and_then(|s| s.parse().ok())
210                    .unwrap_or(1000);
211                let timeout_duration = tokio::time::Duration::from_secs(timeout_secs);
212
213                let ret = tokio::time::timeout(timeout_duration, body_fn())
214                    .await
215                    .expect("sim_test timed out");
216
217                ::sui_simulator::task::shutdown_all_nodes();
218
219                // all node handles should have been dropped after the above block exits, but task
220                // shutdown is asynchronous, so we need a brief delay before checking for leaks.
221                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
222
223                assert_eq!(
224                    sui_simulator::NodeLeakDetector::get_current_node_count(),
225                    0,
226                    "SuiNode leak detected"
227                );
228
229                ret
230            }
231        }
232    } else {
233        let fn_name = &input.sig.ident;
234        let sig = &input.sig;
235        let body = &input.block;
236        quote! {
237            #[allow(clippy::needless_return)]
238            #[tokio::test]
239            #ignore
240            #sig {
241                if std::env::var("SUI_SKIP_SIMTESTS").is_ok() {
242                    println!("not running test {} in `cargo test`: SUI_SKIP_SIMTESTS is set", stringify!(#fn_name));
243
244                    struct Ret;
245
246                    impl From<Ret> for () {
247                        fn from(_ret: Ret) -> Self {
248                        }
249                    }
250
251                    impl<E> From<Ret> for Result<(), E> {
252                        fn from(_ret: Ret) -> Self {
253                            Ok(())
254                        }
255                    }
256
257                    return Ret.into();
258                }
259
260                #body
261            }
262        }
263    };
264
265    result.into()
266}
267
268#[proc_macro]
269pub fn checked_arithmetic(input: TokenStream) -> TokenStream {
270    let input_file = CheckArithmetic.fold_file(parse_macro_input!(input));
271
272    let output_items = input_file.items;
273
274    let output = quote! {
275        #(#output_items)*
276    };
277
278    TokenStream::from(output)
279}
280
281#[proc_macro_attribute]
282pub fn with_checked_arithmetic(_attr: TokenStream, item: TokenStream) -> TokenStream {
283    let input_item = parse_macro_input!(item as Item);
284    match input_item {
285        Item::Fn(input_fn) => {
286            let transformed_fn = CheckArithmetic.fold_item_fn(input_fn);
287            TokenStream::from(quote! { #transformed_fn })
288        }
289        Item::Impl(input_impl) => {
290            let transformed_impl = CheckArithmetic.fold_item_impl(input_impl);
291            TokenStream::from(quote! { #transformed_impl })
292        }
293        item => {
294            let transformed_impl = CheckArithmetic.fold_item(item);
295            TokenStream::from(quote! { #transformed_impl })
296        }
297    }
298}
299
/// Stateless `syn` folder that rewrites arithmetic expressions (and the arguments of
/// macros that contain them) into panicking checked arithmetic.
///
/// Used by `checked_arithmetic!` and `#[with_checked_arithmetic]`. Macros annotated with
/// `#[skip_checked_arithmetic]` are left untouched.
struct CheckArithmetic;
301
302impl CheckArithmetic {
303    fn maybe_skip_macro(&self, attrs: &mut Vec<Attribute>) -> bool {
304        if let Some(idx) = attrs
305            .iter()
306            .position(|attr| attr.path().is_ident("skip_checked_arithmetic"))
307        {
308            // Skip processing macro because it is annotated with
309            // #[skip_checked_arithmetic]
310            attrs.remove(idx);
311            true
312        } else {
313            false
314        }
315    }
316
317    fn process_macro_contents(
318        &mut self,
319        tokens: proc_macro2::TokenStream,
320    ) -> syn::Result<proc_macro2::TokenStream> {
321        // Parse the macro's contents as a comma-separated list of expressions.
322        let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
323        let Ok(exprs) = parser.parse(tokens.clone().into()) else {
324            return Err(syn::Error::new_spanned(
325                tokens,
326                "could not process macro contents - use #[skip_checked_arithmetic] to skip this macro",
327            ));
328        };
329
330        // Fold each sub expression.
331        let folded_exprs = exprs
332            .into_iter()
333            .map(|expr| self.fold_expr(expr))
334            .collect::<Vec<_>>();
335
336        // Convert the folded expressions back into tokens and reconstruct the macro.
337        let mut folded_tokens = proc_macro2::TokenStream::new();
338        for (i, folded_expr) in folded_exprs.into_iter().enumerate() {
339            if i > 0 {
340                folded_tokens.extend(std::iter::once::<proc_macro2::TokenTree>(
341                    proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone).into(),
342                ));
343            }
344            folded_expr.to_tokens(&mut folded_tokens);
345        }
346
347        Ok(folded_tokens)
348    }
349}
350
351impl Fold for CheckArithmetic {
352    fn fold_stmt(&mut self, stmt: Stmt) -> Stmt {
353        let stmt = fold_stmt(self, stmt);
354        if let Stmt::Macro(stmt_macro) = stmt {
355            let StmtMacro {
356                mut attrs,
357                mut mac,
358                semi_token,
359            } = stmt_macro;
360
361            if self.maybe_skip_macro(&mut attrs) {
362                Stmt::Macro(StmtMacro {
363                    attrs,
364                    mac,
365                    semi_token,
366                })
367            } else {
368                match self.process_macro_contents(mac.tokens.clone()) {
369                    Ok(folded_tokens) => {
370                        mac.tokens = folded_tokens;
371                        Stmt::Macro(StmtMacro {
372                            attrs,
373                            mac,
374                            semi_token,
375                        })
376                    }
377                    Err(error) => parse2(error.to_compile_error()).unwrap(),
378                }
379            }
380        } else {
381            stmt
382        }
383    }
384
385    fn fold_item_macro(&mut self, mut item_macro: ItemMacro) -> ItemMacro {
386        if !self.maybe_skip_macro(&mut item_macro.attrs) {
387            let err = syn::Error::new_spanned(
388                item_macro.to_token_stream(),
389                "cannot process macros - use #[skip_checked_arithmetic] to skip \
390                    processing this macro",
391            );
392
393            return parse2(err.to_compile_error()).unwrap();
394        }
395        fold_item_macro(self, item_macro)
396    }
397
398    fn fold_expr(&mut self, expr: Expr) -> Expr {
399        let span = expr.span();
400        let expr = fold_expr(self, expr);
401        let expr = match expr {
402            Expr::Macro(expr_macro) => {
403                let ExprMacro { mut attrs, mut mac } = expr_macro;
404
405                if self.maybe_skip_macro(&mut attrs) {
406                    return Expr::Macro(ExprMacro { attrs, mac });
407                } else {
408                    match self.process_macro_contents(mac.tokens.clone()) {
409                        Ok(folded_tokens) => {
410                            mac.tokens = folded_tokens;
411                            let expr_macro = Expr::Macro(ExprMacro { attrs, mac });
412                            quote!(#expr_macro)
413                        }
414                        Err(error) => {
415                            return Expr::Verbatim(error.to_compile_error());
416                        }
417                    }
418                }
419            }
420
421            Expr::Binary(expr_binary) => {
422                let ExprBinary {
423                    attrs,
424                    mut left,
425                    op,
426                    mut right,
427                } = expr_binary;
428
429                fn remove_parens(expr: &mut Expr) {
430                    if let Expr::Paren(paren) = expr {
431                        // i don't even think rust allows this, but just in case
432                        assert!(paren.attrs.is_empty(), "TODO: attrs on parenthesized");
433                        *expr = *paren.expr.clone();
434                    }
435                }
436
437                macro_rules! wrap_op {
438                    ($left: expr, $right: expr, $method: ident, $span: expr) => {{
439                        // Remove parens from exprs since both sides get assigned to tmp variables.
440                        // otherwise we get lint errors
441                        remove_parens(&mut $left);
442                        remove_parens(&mut $right);
443
444                        quote_spanned!($span => {
445                            // assign in one stmt in case either #left or #right contains
446                            // references to `left` or `right` symbols.
447                            let (left, right) = (#left, #right);
448                            left.$method(right)
449                                .unwrap_or_else(||
450                                    panic!(
451                                        "Overflow or underflow in {} {} + {}",
452                                        stringify!($method),
453                                        left,
454                                        right,
455                                    )
456                                )
457                        })
458                    }};
459                }
460
461                macro_rules! wrap_op_assign {
462                    ($left: expr, $right: expr, $method: ident, $span: expr) => {{
463                        // Remove parens from exprs since both sides get assigned to tmp variables.
464                        // otherwise we get lint errors
465                        remove_parens(&mut $left);
466                        remove_parens(&mut $right);
467
468                        quote_spanned!($span => {
469                            // assign in one stmt in case either #left or #right contains
470                            // references to `left` or `right` symbols.
471                            let (left, right) = (&mut #left, #right);
472                            *left = (*left).$method(right)
473                                .unwrap_or_else(||
474                                    panic!(
475                                        "Overflow or underflow in {} {} + {}",
476                                        stringify!($method),
477                                        *left,
478                                        right
479                                    )
480                                )
481                        })
482                    }};
483                }
484
485                match op {
486                    BinOp::Add(_) => {
487                        wrap_op!(left, right, checked_add, span)
488                    }
489                    BinOp::Sub(_) => {
490                        wrap_op!(left, right, checked_sub, span)
491                    }
492                    BinOp::Mul(_) => {
493                        wrap_op!(left, right, checked_mul, span)
494                    }
495                    BinOp::Div(_) => {
496                        wrap_op!(left, right, checked_div, span)
497                    }
498                    BinOp::Rem(_) => {
499                        wrap_op!(left, right, checked_rem, span)
500                    }
501                    BinOp::AddAssign(_) => {
502                        wrap_op_assign!(left, right, checked_add, span)
503                    }
504                    BinOp::SubAssign(_) => {
505                        wrap_op_assign!(left, right, checked_sub, span)
506                    }
507                    BinOp::MulAssign(_) => {
508                        wrap_op_assign!(left, right, checked_mul, span)
509                    }
510                    BinOp::DivAssign(_) => {
511                        wrap_op_assign!(left, right, checked_div, span)
512                    }
513                    BinOp::RemAssign(_) => {
514                        wrap_op_assign!(left, right, checked_rem, span)
515                    }
516                    _ => {
517                        let expr_binary = ExprBinary {
518                            attrs,
519                            left,
520                            op,
521                            right,
522                        };
523                        quote_spanned!(span => #expr_binary)
524                    }
525                }
526            }
527            Expr::Unary(expr_unary) => {
528                let op = &expr_unary.op;
529                let operand = &expr_unary.expr;
530                match op {
531                    UnOp::Neg(_) => {
532                        quote_spanned!(span => #operand.checked_neg().expect("Overflow or underflow in negation"))
533                    }
534                    _ => quote_spanned!(span => #expr_unary),
535                }
536            }
537            _ => quote_spanned!(span => #expr),
538        };
539
540        parse2(expr).unwrap()
541    }
542}
543
544/// This proc macro generates a function `order_to_variant_map` which returns a map
545/// of the position of each variant to the name of the variant.
546/// It is intended to catch changes in enum order when backward compat is required.
547/// ```rust,ignore
548///    /// Example for this enum
549///    #[derive(EnumVariantOrder)]
550///    pub enum MyEnum {
551///         A,
552///         B(u64),
553///         C{x: bool, y: i8},
554///     }
555///     let order_map = MyEnum::order_to_variant_map();
556///     assert!(order_map.get(0).unwrap() == "A");
557///     assert!(order_map.get(1).unwrap() == "B");
558///     assert!(order_map.get(2).unwrap() == "C");
559/// ```
560#[proc_macro_derive(EnumVariantOrder)]
561pub fn enum_variant_order_derive(input: TokenStream) -> TokenStream {
562    let ast = parse_macro_input!(input as DeriveInput);
563    let name = &ast.ident;
564
565    if let Data::Enum(DataEnum { variants, .. }) = ast.data {
566        let variant_entries = variants
567            .iter()
568            .enumerate()
569            .map(|(index, variant)| {
570                let variant_name = variant.ident.to_string();
571                quote! {
572                    map.insert( #index as u64, (#variant_name).to_string());
573                }
574            })
575            .collect::<Vec<_>>();
576
577        let deriv = quote! {
578            impl sui_enum_compat_util::EnumOrderMap for #name {
579                fn order_to_variant_map() -> std::collections::BTreeMap<u64, String > {
580                    let mut map = std::collections::BTreeMap::new();
581                    #(#variant_entries)*
582                    map
583                }
584            }
585        };
586
587        deriv.into()
588    } else {
589        panic!("EnumVariantOrder can only be used with enums.");
590    }
591}